增加联邦学习指标;fix:Pytorch 加载模型不匹配
This commit is contained in:
		@@ -1,6 +1,7 @@
 | 
			
		||||
import glob
 | 
			
		||||
import os
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
import yaml
 | 
			
		||||
from ultralytics import YOLO
 | 
			
		||||
@@ -31,105 +32,170 @@ def federated_avg(global_model, client_weights):
 | 
			
		||||
        # 加权平均
 | 
			
		||||
        if global_dict[key].dtype == torch.float32:  # 只聚合浮点型参数
 | 
			
		||||
            # 跳过 BatchNorm 层的统计量
 | 
			
		||||
            if any(x in key for x in ['running_mean', 'running_var', 'num_batches_tracked']):
 | 
			
		||||
            if any(
 | 
			
		||||
                x in key for x in ["running_mean", "running_var", "num_batches_tracked"]
 | 
			
		||||
            ):
 | 
			
		||||
                continue
 | 
			
		||||
            # 按照样本数加权求和
 | 
			
		||||
            weighted_tensors = [sd[key].float() * (n / total_samples)
 | 
			
		||||
                                for sd, n in zip(state_dicts, sample_counts)]
 | 
			
		||||
            weighted_tensors = [
 | 
			
		||||
                sd[key].float() * (n / total_samples)
 | 
			
		||||
                for sd, n in zip(state_dicts, sample_counts)
 | 
			
		||||
            ]
 | 
			
		||||
            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
 | 
			
		||||
 | 
			
		||||
        # 解决模型参数不匹配问题
 | 
			
		||||
    try:
 | 
			
		||||
        # try:
 | 
			
		||||
        #     # 加载回YOLO模型
 | 
			
		||||
        #     global_model.model.load_state_dict(global_dict)
 | 
			
		||||
        # except RuntimeError as e:
 | 
			
		||||
        #     print('Ignoring "' + str(e) + '"')
 | 
			
		||||
 | 
			
		||||
        # 加载回YOLO模型
 | 
			
		||||
        global_model.model.load_state_dict(global_dict)
 | 
			
		||||
    except RuntimeError as e:
 | 
			
		||||
        print('Ignoring "' + str(e) + '"')
 | 
			
		||||
    
 | 
			
		||||
    # 添加调试输出
 | 
			
		||||
    print("\n=== 参数聚合检查 ===")
 | 
			
		||||
    
 | 
			
		||||
    # 选取一个典型参数层
 | 
			
		||||
    # sample_key = list(global_dict.keys())[10]
 | 
			
		||||
    # original = global_dict[sample_key].data.mean().item()
 | 
			
		||||
    # aggregated = torch.stack([w[sample_key] for w in client_weights]).mean().item()
 | 
			
		||||
    # print(f"参数层 '{sample_key}' 变化: {original:.4f} → {aggregated:.4f}")
 | 
			
		||||
    # print(f"客户端参数差异: {[w[sample_key].mean().item() for w in client_weights]}")
 | 
			
		||||
 | 
			
		||||
    # 随机选取一个非统计量层进行对比
 | 
			
		||||
    sample_key = next(k for k in global_dict if 'running_' not in k)
 | 
			
		||||
    aggregated_mean = global_dict[sample_key].mean().item()
 | 
			
		||||
    client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
 | 
			
		||||
    print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
 | 
			
		||||
    print(f"The average value of the layer for each client: {client_means}")
 | 
			
		||||
    # sample_key = next(k for k in global_dict if 'running_' not in k)
 | 
			
		||||
    # aggregated_mean = global_dict[sample_key].mean().item()
 | 
			
		||||
    # client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
 | 
			
		||||
    # print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
 | 
			
		||||
    # print(f"The average value of the layer for each client: {client_means}")
 | 
			
		||||
 | 
			
		||||
    # 定义多个关键层
 | 
			
		||||
    MONITOR_KEYS = [
 | 
			
		||||
        "model.0.conv.weight",  # 输入层卷积
 | 
			
		||||
        "model.10.conv.weight",  # 中间层卷积
 | 
			
		||||
        "model.22.dfl.conv.weight",  # 输出层分类头
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    with open("aggregation_check.txt", "a") as f:
 | 
			
		||||
        f.write("\n=== 参数聚合检查 ===\n")
 | 
			
		||||
    for key in MONITOR_KEYS:
 | 
			
		||||
        if key not in global_dict:
 | 
			
		||||
            continue
 | 
			
		||||
        # 计算聚合后均值
 | 
			
		||||
        aggregated_mean = global_dict[key].mean().item()
 | 
			
		||||
        # 计算各客户端均值
 | 
			
		||||
        client_means = [sd[key].float().mean().item() for sd in state_dicts]
 | 
			
		||||
 | 
			
		||||
        with open("aggregation_check.txt", "a") as f:
 | 
			
		||||
            f.write(f"层 '{key}' 聚合后均值: {aggregated_mean:.6f}\n")
 | 
			
		||||
            f.write(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n")
 | 
			
		||||
            f.write(f"客户端最大差异: {max(client_means) - min(client_means):.6f}\n\n")
 | 
			
		||||
 | 
			
		||||
    return global_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ------------ 修改训练流程 ------------
 | 
			
		||||
def federated_train(num_rounds, clients_data):
 | 
			
		||||
    # ========== 新增:初始化指标记录 ==========
 | 
			
		||||
    metrics = {
 | 
			
		||||
        "round": [],
 | 
			
		||||
        "val_mAP": [],  # 每轮验证集mAP
 | 
			
		||||
        "train_loss": [],  # 每轮平均训练损失
 | 
			
		||||
        "client_mAPs": [],  # 各客户端本地模型在验证集上的mAP
 | 
			
		||||
        "communication_cost": [],  # 每轮通信开销(MB)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # 初始化全局模型
 | 
			
		||||
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
			
		||||
    global_model = YOLO("../yolov8n.pt").to(device)
 | 
			
		||||
    global_model = YOLO("../yolov8n.yaml").to(device)
 | 
			
		||||
    # 设置类别数
 | 
			
		||||
    global_model.model.nc = 1
 | 
			
		||||
    # global_model.model.nc = 1
 | 
			
		||||
 | 
			
		||||
    for _ in range(num_rounds):
 | 
			
		||||
        client_weights = []
 | 
			
		||||
        client_losses = []  # 记录各客户端的训练损失
 | 
			
		||||
 | 
			
		||||
        # 每个客户端本地训练
 | 
			
		||||
        for data_path in clients_data:
 | 
			
		||||
            # 统计本地训练样本数
 | 
			
		||||
            with open(data_path, 'r') as f:
 | 
			
		||||
            with open(data_path, "r") as f:
 | 
			
		||||
                config = yaml.safe_load(f)
 | 
			
		||||
            #  Resolve img_dir relative to the YAML file's location
 | 
			
		||||
            yaml_dir = os.path.dirname(data_path)
 | 
			
		||||
            img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # 从配置文件中获取图像目录
 | 
			
		||||
            img_dir = os.path.join(
 | 
			
		||||
                yaml_dir, config.get("train", data_path)
 | 
			
		||||
            )  # 从配置文件中获取图像目录
 | 
			
		||||
 | 
			
		||||
            # print(f"Image directory: {img_dir}")
 | 
			
		||||
            num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
 | 
			
		||||
                           len(glob.glob(os.path.join(img_dir, '*.png'))))
 | 
			
		||||
            num_samples = (len(glob.glob(os.path.join(img_dir, "*.jpg")))
 | 
			
		||||
                + len(glob.glob(os.path.join(img_dir, "*.png")))
 | 
			
		||||
                + len(glob.glob(os.path.join(img_dir, "*.jpeg")))
 | 
			
		||||
            )
 | 
			
		||||
            # print(f"Number of images: {num_samples}")
 | 
			
		||||
 | 
			
		||||
            # 克隆全局模型
 | 
			
		||||
            local_model = copy.deepcopy(global_model)
 | 
			
		||||
 | 
			
		||||
            # 本地训练(保持你的原有参数设置)
 | 
			
		||||
            local_model.train(
 | 
			
		||||
            results = local_model.train(
 | 
			
		||||
                data=data_path,
 | 
			
		||||
                epochs=16,  # 每轮本地训练1个epoch
 | 
			
		||||
                save_period=16,
 | 
			
		||||
                epochs=4,  # 每轮本地训练多少个epoch
 | 
			
		||||
                # save_period=16,
 | 
			
		||||
                imgsz=640,  # 图像大小
 | 
			
		||||
                verbose=False,  # 关闭冗余输出
 | 
			
		||||
                batch=-1
 | 
			
		||||
                batch=-1,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # 记录客户端训练损失
 | 
			
		||||
            # client_loss = results.results_dict['train_loss']
 | 
			
		||||
            # client_losses.append(client_loss)
 | 
			
		||||
 | 
			
		||||
            # 收集模型参数及样本数
 | 
			
		||||
            client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples))
 | 
			
		||||
            client_weights.append(
 | 
			
		||||
                (copy.deepcopy(local_model.model.state_dict()), num_samples)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # 聚合参数更新全局模型
 | 
			
		||||
        global_model = federated_avg(global_model, client_weights)
 | 
			
		||||
        print(f"Round {_ + 1}/{num_rounds} completed.")
 | 
			
		||||
    return global_model
 | 
			
		||||
 | 
			
		||||
        # ========== 评估全局模型 ==========
 | 
			
		||||
        # 评估全局模型在验证集上的性能
 | 
			
		||||
        val_results = global_model.val(
 | 
			
		||||
            data="/mnt/DATA/UAVdataset/data.yaml",  # 指定验证集配置文件
 | 
			
		||||
            imgsz=640,
 | 
			
		||||
            batch=-1,
 | 
			
		||||
            verbose=False,
 | 
			
		||||
        )
 | 
			
		||||
        val_mAP = val_results.box.map  # 获取mAP@0.5
 | 
			
		||||
 | 
			
		||||
        # 计算平均训练损失
 | 
			
		||||
        # avg_train_loss = sum(client_losses) / len(client_losses)
 | 
			
		||||
 | 
			
		||||
        # 计算通信开销(假设传输全部模型参数)
 | 
			
		||||
        model_size = sum(p.numel() * 4 for p in global_model.model.parameters()) / (
 | 
			
		||||
            1024**2
 | 
			
		||||
        )  # MB
 | 
			
		||||
 | 
			
		||||
        # 记录到指标容器
 | 
			
		||||
        metrics["round"].append(_ + 1)
 | 
			
		||||
        metrics["val_mAP"].append(val_mAP)
 | 
			
		||||
        # metrics['train_loss'].append(avg_train_loss)
 | 
			
		||||
        metrics["communication_cost"].append(model_size)
 | 
			
		||||
        # 打印当前轮次结果
 | 
			
		||||
        with open("aggregation_check.txt", "a") as f:
 | 
			
		||||
            f.write(f"\n[Round {_ + 1}/{num_rounds}]")
 | 
			
		||||
            f.write(f"Validation mAP@0.5: {val_mAP:.4f}")
 | 
			
		||||
            # f.write(f"Average Train Loss: {avg_train_loss:.4f}")
 | 
			
		||||
            f.write(f"Communication Cost: {model_size:.2f} MB\n")
 | 
			
		||||
 | 
			
		||||
    return global_model, metrics
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ------------ 使用示例 ------------
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # 联邦训练配置
 | 
			
		||||
    clients_config = [
 | 
			
		||||
        "/root/autodl-tmp/dataset/train1/train1.yaml",  # 客户端1数据路径
 | 
			
		||||
        "/root/autodl-tmp/dataset/train2/train2.yaml"  # 客户端2数据路径
 | 
			
		||||
        "/mnt/DATA/uav_dataset_fed/train1/train1.yaml",  # 客户端1数据路径
 | 
			
		||||
        "/mnt/DATA/uav_dataset_fed/train2/train2.yaml",  # 客户端2数据路径
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    # 运行联邦训练
 | 
			
		||||
    final_model = federated_train(num_rounds=10, clients_data=clients_config)
 | 
			
		||||
    final_model, metrics = federated_train(num_rounds=40, clients_data=clients_config)
 | 
			
		||||
 | 
			
		||||
    # 保存最终模型
 | 
			
		||||
    final_model.save("yolov8n_federated.pt")
 | 
			
		||||
    # final_model.export(format="onnx")  # 导出为ONNX格式
 | 
			
		||||
 | 
			
		||||
    # 检查1:确认模型保存
 | 
			
		||||
    # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
 | 
			
		||||
    
 | 
			
		||||
    # 检查2:验证预测功能
 | 
			
		||||
    # results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True)
 | 
			
		||||
    # assert len(results[0].boxes) > 0, "预测结果异常"
 | 
			
		||||
    with open("metrics.json", "w") as f:
 | 
			
		||||
        json.dump(metrics, f, indent=4)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										49
									
								
								yolov8.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								yolov8.yaml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,49 @@
 | 
			
		||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 | 
			
		||||
 | 
			
		||||
# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs
 | 
			
		||||
# Model docs: https://docs.ultralytics.com/models/yolov8
 | 
			
		||||
# Task docs: https://docs.ultralytics.com/tasks/detect
 | 
			
		||||
 | 
			
		||||
# Parameters
 | 
			
		||||
nc: 1 # number of classes
 | 
			
		||||
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
 | 
			
		||||
  # [depth, width, max_channels]
 | 
			
		||||
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 129 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPS
 | 
			
		||||
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 129 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPS
 | 
			
		||||
  m: [0.67, 0.75, 768] # YOLOv8m summary: 169 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPS
 | 
			
		||||
  l: [1.00, 1.00, 512] # YOLOv8l summary: 209 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPS
 | 
			
		||||
  x: [1.00, 1.25, 512] # YOLOv8x summary: 209 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPS
 | 
			
		||||
 | 
			
		||||
# YOLOv8.0n backbone
 | 
			
		||||
backbone:
 | 
			
		||||
  # [from, repeats, module, args]
 | 
			
		||||
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
 | 
			
		||||
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
 | 
			
		||||
  - [-1, 3, C2f, [128, True]]
 | 
			
		||||
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
 | 
			
		||||
  - [-1, 6, C2f, [256, True]]
 | 
			
		||||
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
 | 
			
		||||
  - [-1, 6, C2f, [512, True]]
 | 
			
		||||
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
 | 
			
		||||
  - [-1, 3, C2f, [1024, True]]
 | 
			
		||||
  - [-1, 1, SPPF, [1024, 5]] # 9
 | 
			
		||||
 | 
			
		||||
# YOLOv8.0n head
 | 
			
		||||
head:
 | 
			
		||||
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
 | 
			
		||||
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
 | 
			
		||||
  - [-1, 3, C2f, [512]] # 12
 | 
			
		||||
 | 
			
		||||
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
 | 
			
		||||
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
 | 
			
		||||
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
 | 
			
		||||
 | 
			
		||||
  - [-1, 1, Conv, [256, 3, 2]]
 | 
			
		||||
  - [[-1, 12], 1, Concat, [1]] # cat head P4
 | 
			
		||||
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
 | 
			
		||||
 | 
			
		||||
  - [-1, 1, Conv, [512, 3, 2]]
 | 
			
		||||
  - [[-1, 9], 1, Concat, [1]] # cat head P5
 | 
			
		||||
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
 | 
			
		||||
 | 
			
		||||
  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
 | 
			
		||||
		Reference in New Issue
	
	Block a user