增加联邦学习指标;fix:Pytorch 加载模型不匹配
This commit is contained in:
		@@ -1,6 +1,7 @@
 | 
				
			|||||||
import glob
 | 
					import glob
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import yaml
 | 
					import yaml
 | 
				
			||||||
from ultralytics import YOLO
 | 
					from ultralytics import YOLO
 | 
				
			||||||
@@ -31,105 +32,170 @@ def federated_avg(global_model, client_weights):
 | 
				
			|||||||
        # 加权平均
 | 
					        # 加权平均
 | 
				
			||||||
        if global_dict[key].dtype == torch.float32:  # 只聚合浮点型参数
 | 
					        if global_dict[key].dtype == torch.float32:  # 只聚合浮点型参数
 | 
				
			||||||
            # 跳过 BatchNorm 层的统计量
 | 
					            # 跳过 BatchNorm 层的统计量
 | 
				
			||||||
            if any(x in key for x in ['running_mean', 'running_var', 'num_batches_tracked']):
 | 
					            if any(
 | 
				
			||||||
 | 
					                x in key for x in ["running_mean", "running_var", "num_batches_tracked"]
 | 
				
			||||||
 | 
					            ):
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            # 按照样本数加权求和
 | 
					            # 按照样本数加权求和
 | 
				
			||||||
            weighted_tensors = [sd[key].float() * (n / total_samples)
 | 
					            weighted_tensors = [
 | 
				
			||||||
                                for sd, n in zip(state_dicts, sample_counts)]
 | 
					                sd[key].float() * (n / total_samples)
 | 
				
			||||||
 | 
					                for sd, n in zip(state_dicts, sample_counts)
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
 | 
					            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 解决模型参数不匹配问题
 | 
					        # 解决模型参数不匹配问题
 | 
				
			||||||
    try:
 | 
					        # try:
 | 
				
			||||||
 | 
					        #     # 加载回YOLO模型
 | 
				
			||||||
 | 
					        #     global_model.model.load_state_dict(global_dict)
 | 
				
			||||||
 | 
					        # except RuntimeError as e:
 | 
				
			||||||
 | 
					        #     print('Ignoring "' + str(e) + '"')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 加载回YOLO模型
 | 
					        # 加载回YOLO模型
 | 
				
			||||||
        global_model.model.load_state_dict(global_dict)
 | 
					        global_model.model.load_state_dict(global_dict)
 | 
				
			||||||
    except RuntimeError as e:
 | 
					 | 
				
			||||||
        print('Ignoring "' + str(e) + '"')
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # 添加调试输出
 | 
					 | 
				
			||||||
    print("\n=== 参数聚合检查 ===")
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # 选取一个典型参数层
 | 
					 | 
				
			||||||
    # sample_key = list(global_dict.keys())[10]
 | 
					 | 
				
			||||||
    # original = global_dict[sample_key].data.mean().item()
 | 
					 | 
				
			||||||
    # aggregated = torch.stack([w[sample_key] for w in client_weights]).mean().item()
 | 
					 | 
				
			||||||
    # print(f"参数层 '{sample_key}' 变化: {original:.4f} → {aggregated:.4f}")
 | 
					 | 
				
			||||||
    # print(f"客户端参数差异: {[w[sample_key].mean().item() for w in client_weights]}")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 随机选取一个非统计量层进行对比
 | 
					    # 随机选取一个非统计量层进行对比
 | 
				
			||||||
    sample_key = next(k for k in global_dict if 'running_' not in k)
 | 
					    # sample_key = next(k for k in global_dict if 'running_' not in k)
 | 
				
			||||||
    aggregated_mean = global_dict[sample_key].mean().item()
 | 
					    # aggregated_mean = global_dict[sample_key].mean().item()
 | 
				
			||||||
    client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
 | 
					    # client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
 | 
				
			||||||
    print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
 | 
					    # print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
 | 
				
			||||||
    print(f"The average value of the layer for each client: {client_means}")
 | 
					    # print(f"The average value of the layer for each client: {client_means}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 定义多个关键层
 | 
				
			||||||
 | 
					    MONITOR_KEYS = [
 | 
				
			||||||
 | 
					        "model.0.conv.weight",  # 输入层卷积
 | 
				
			||||||
 | 
					        "model.10.conv.weight",  # 中间层卷积
 | 
				
			||||||
 | 
					        "model.22.dfl.conv.weight",  # 输出层分类头
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open("aggregation_check.txt", "a") as f:
 | 
				
			||||||
 | 
					        f.write("\n=== 参数聚合检查 ===\n")
 | 
				
			||||||
 | 
					    for key in MONITOR_KEYS:
 | 
				
			||||||
 | 
					        if key not in global_dict:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        # 计算聚合后均值
 | 
				
			||||||
 | 
					        aggregated_mean = global_dict[key].mean().item()
 | 
				
			||||||
 | 
					        # 计算各客户端均值
 | 
				
			||||||
 | 
					        client_means = [sd[key].float().mean().item() for sd in state_dicts]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with open("aggregation_check.txt", "a") as f:
 | 
				
			||||||
 | 
					            f.write(f"层 '{key}' 聚合后均值: {aggregated_mean:.6f}\n")
 | 
				
			||||||
 | 
					            f.write(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n")
 | 
				
			||||||
 | 
					            f.write(f"客户端最大差异: {max(client_means) - min(client_means):.6f}\n\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return global_model
 | 
					    return global_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# ------------ 修改训练流程 ------------
 | 
					# ------------ 修改训练流程 ------------
 | 
				
			||||||
def federated_train(num_rounds, clients_data):
 | 
					def federated_train(num_rounds, clients_data):
 | 
				
			||||||
 | 
					    # ========== 新增:初始化指标记录 ==========
 | 
				
			||||||
 | 
					    metrics = {
 | 
				
			||||||
 | 
					        "round": [],
 | 
				
			||||||
 | 
					        "val_mAP": [],  # 每轮验证集mAP
 | 
				
			||||||
 | 
					        "train_loss": [],  # 每轮平均训练损失
 | 
				
			||||||
 | 
					        "client_mAPs": [],  # 各客户端本地模型在验证集上的mAP
 | 
				
			||||||
 | 
					        "communication_cost": [],  # 每轮通信开销(MB)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 初始化全局模型
 | 
					    # 初始化全局模型
 | 
				
			||||||
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
					    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
    global_model = YOLO("../yolov8n.pt").to(device)
 | 
					    global_model = YOLO("../yolov8n.yaml").to(device)
 | 
				
			||||||
    # 设置类别数
 | 
					    # 设置类别数
 | 
				
			||||||
    global_model.model.nc = 1
 | 
					    # global_model.model.nc = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for _ in range(num_rounds):
 | 
					    for _ in range(num_rounds):
 | 
				
			||||||
        client_weights = []
 | 
					        client_weights = []
 | 
				
			||||||
 | 
					        client_losses = []  # 记录各客户端的训练损失
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 每个客户端本地训练
 | 
					        # 每个客户端本地训练
 | 
				
			||||||
        for data_path in clients_data:
 | 
					        for data_path in clients_data:
 | 
				
			||||||
            # 统计本地训练样本数
 | 
					            # 统计本地训练样本数
 | 
				
			||||||
            with open(data_path, 'r') as f:
 | 
					            with open(data_path, "r") as f:
 | 
				
			||||||
                config = yaml.safe_load(f)
 | 
					                config = yaml.safe_load(f)
 | 
				
			||||||
            #  Resolve img_dir relative to the YAML file's location
 | 
					            #  Resolve img_dir relative to the YAML file's location
 | 
				
			||||||
            yaml_dir = os.path.dirname(data_path)
 | 
					            yaml_dir = os.path.dirname(data_path)
 | 
				
			||||||
            img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # 从配置文件中获取图像目录
 | 
					            img_dir = os.path.join(
 | 
				
			||||||
 | 
					                yaml_dir, config.get("train", data_path)
 | 
				
			||||||
 | 
					            )  # 从配置文件中获取图像目录
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # print(f"Image directory: {img_dir}")
 | 
					            # print(f"Image directory: {img_dir}")
 | 
				
			||||||
            num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
 | 
					            num_samples = (len(glob.glob(os.path.join(img_dir, "*.jpg")))
 | 
				
			||||||
                           len(glob.glob(os.path.join(img_dir, '*.png'))))
 | 
					                + len(glob.glob(os.path.join(img_dir, "*.png")))
 | 
				
			||||||
 | 
					                + len(glob.glob(os.path.join(img_dir, "*.jpeg")))
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
            # print(f"Number of images: {num_samples}")
 | 
					            # print(f"Number of images: {num_samples}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # 克隆全局模型
 | 
					            # 克隆全局模型
 | 
				
			||||||
            local_model = copy.deepcopy(global_model)
 | 
					            local_model = copy.deepcopy(global_model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # 本地训练(保持你的原有参数设置)
 | 
					            # 本地训练(保持你的原有参数设置)
 | 
				
			||||||
            local_model.train(
 | 
					            results = local_model.train(
 | 
				
			||||||
                data=data_path,
 | 
					                data=data_path,
 | 
				
			||||||
                epochs=16,  # 每轮本地训练1个epoch
 | 
					                epochs=4,  # 每轮本地训练多少个epoch
 | 
				
			||||||
                save_period=16,
 | 
					                # save_period=16,
 | 
				
			||||||
                imgsz=640,  # 图像大小
 | 
					                imgsz=640,  # 图像大小
 | 
				
			||||||
                verbose=False,  # 关闭冗余输出
 | 
					                verbose=False,  # 关闭冗余输出
 | 
				
			||||||
                batch=-1
 | 
					                batch=-1,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # 记录客户端训练损失
 | 
				
			||||||
 | 
					            # client_loss = results.results_dict['train_loss']
 | 
				
			||||||
 | 
					            # client_losses.append(client_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # 收集模型参数及样本数
 | 
					            # 收集模型参数及样本数
 | 
				
			||||||
            client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples))
 | 
					            client_weights.append(
 | 
				
			||||||
 | 
					                (copy.deepcopy(local_model.model.state_dict()), num_samples)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 聚合参数更新全局模型
 | 
					        # 聚合参数更新全局模型
 | 
				
			||||||
        global_model = federated_avg(global_model, client_weights)
 | 
					        global_model = federated_avg(global_model, client_weights)
 | 
				
			||||||
        print(f"Round {_ + 1}/{num_rounds} completed.")
 | 
					
 | 
				
			||||||
    return global_model
 | 
					        # ========== 评估全局模型 ==========
 | 
				
			||||||
 | 
					        # 评估全局模型在验证集上的性能
 | 
				
			||||||
 | 
					        val_results = global_model.val(
 | 
				
			||||||
 | 
					            data="/mnt/DATA/UAVdataset/data.yaml",  # 指定验证集配置文件
 | 
				
			||||||
 | 
					            imgsz=640,
 | 
				
			||||||
 | 
					            batch=-1,
 | 
				
			||||||
 | 
					            verbose=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        val_mAP = val_results.box.map  # 获取mAP@0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 计算平均训练损失
 | 
				
			||||||
 | 
					        # avg_train_loss = sum(client_losses) / len(client_losses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 计算通信开销(假设传输全部模型参数)
 | 
				
			||||||
 | 
					        model_size = sum(p.numel() * 4 for p in global_model.model.parameters()) / (
 | 
				
			||||||
 | 
					            1024**2
 | 
				
			||||||
 | 
					        )  # MB
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 记录到指标容器
 | 
				
			||||||
 | 
					        metrics["round"].append(_ + 1)
 | 
				
			||||||
 | 
					        metrics["val_mAP"].append(val_mAP)
 | 
				
			||||||
 | 
					        # metrics['train_loss'].append(avg_train_loss)
 | 
				
			||||||
 | 
					        metrics["communication_cost"].append(model_size)
 | 
				
			||||||
 | 
					        # 打印当前轮次结果
 | 
				
			||||||
 | 
					        with open("aggregation_check.txt", "a") as f:
 | 
				
			||||||
 | 
					            f.write(f"\n[Round {_ + 1}/{num_rounds}]")
 | 
				
			||||||
 | 
					            f.write(f"Validation mAP@0.5: {val_mAP:.4f}")
 | 
				
			||||||
 | 
					            # f.write(f"Average Train Loss: {avg_train_loss:.4f}")
 | 
				
			||||||
 | 
					            f.write(f"Communication Cost: {model_size:.2f} MB\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return global_model, metrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# ------------ 使用示例 ------------
 | 
					# ------------ 使用示例 ------------
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # 联邦训练配置
 | 
					    # 联邦训练配置
 | 
				
			||||||
    clients_config = [
 | 
					    clients_config = [
 | 
				
			||||||
        "/root/autodl-tmp/dataset/train1/train1.yaml",  # 客户端1数据路径
 | 
					        "/mnt/DATA/uav_dataset_fed/train1/train1.yaml",  # 客户端1数据路径
 | 
				
			||||||
        "/root/autodl-tmp/dataset/train2/train2.yaml"  # 客户端2数据路径
 | 
					        "/mnt/DATA/uav_dataset_fed/train2/train2.yaml",  # 客户端2数据路径
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 运行联邦训练
 | 
					    # 运行联邦训练
 | 
				
			||||||
    final_model = federated_train(num_rounds=10, clients_data=clients_config)
 | 
					    final_model, metrics = federated_train(num_rounds=40, clients_data=clients_config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 保存最终模型
 | 
					    # 保存最终模型
 | 
				
			||||||
    final_model.save("yolov8n_federated.pt")
 | 
					    final_model.save("yolov8n_federated.pt")
 | 
				
			||||||
    # final_model.export(format="onnx")  # 导出为ONNX格式
 | 
					    # final_model.export(format="onnx")  # 导出为ONNX格式
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 检查1:确认模型保存
 | 
					    with open("metrics.json", "w") as f:
 | 
				
			||||||
    # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
 | 
					        json.dump(metrics, f, indent=4)
 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    # 检查2:验证预测功能
 | 
					 | 
				
			||||||
    # results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True)
 | 
					 | 
				
			||||||
    # assert len(results[0].boxes) > 0, "预测结果异常"
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										49
									
								
								yolov8.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								yolov8.yaml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs
 | 
				
			||||||
 | 
					# Model docs: https://docs.ultralytics.com/models/yolov8
 | 
				
			||||||
 | 
					# Task docs: https://docs.ultralytics.com/tasks/detect
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Parameters
 | 
				
			||||||
 | 
					nc: 1 # number of classes
 | 
				
			||||||
 | 
					scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
 | 
				
			||||||
 | 
					  # [depth, width, max_channels]
 | 
				
			||||||
 | 
					  n: [0.33, 0.25, 1024] # YOLOv8n summary: 129 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPS
 | 
				
			||||||
 | 
					  s: [0.33, 0.50, 1024] # YOLOv8s summary: 129 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPS
 | 
				
			||||||
 | 
					  m: [0.67, 0.75, 768] # YOLOv8m summary: 169 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPS
 | 
				
			||||||
 | 
					  l: [1.00, 1.00, 512] # YOLOv8l summary: 209 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPS
 | 
				
			||||||
 | 
					  x: [1.00, 1.25, 512] # YOLOv8x summary: 209 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# YOLOv8.0n backbone
 | 
				
			||||||
 | 
					backbone:
 | 
				
			||||||
 | 
					  # [from, repeats, module, args]
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
 | 
				
			||||||
 | 
					  - [-1, 3, C2f, [128, True]]
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
 | 
				
			||||||
 | 
					  - [-1, 6, C2f, [256, True]]
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
 | 
				
			||||||
 | 
					  - [-1, 6, C2f, [512, True]]
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
 | 
				
			||||||
 | 
					  - [-1, 3, C2f, [1024, True]]
 | 
				
			||||||
 | 
					  - [-1, 1, SPPF, [1024, 5]] # 9
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# YOLOv8.0n head
 | 
				
			||||||
 | 
					head:
 | 
				
			||||||
 | 
					  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
 | 
				
			||||||
 | 
					  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
 | 
				
			||||||
 | 
					  - [-1, 3, C2f, [512]] # 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
 | 
				
			||||||
 | 
					  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
 | 
				
			||||||
 | 
					  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [256, 3, 2]]
 | 
				
			||||||
 | 
					  - [[-1, 12], 1, Concat, [1]] # cat head P4
 | 
				
			||||||
 | 
					  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  - [-1, 1, Conv, [512, 3, 2]]
 | 
				
			||||||
 | 
					  - [[-1, 9], 1, Concat, [1]] # cat head P5
 | 
				
			||||||
 | 
					  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
 | 
				
			||||||
		Reference in New Issue
	
	Block a user