diff --git a/federated_learning/yolov8_fed.py b/federated_learning/yolov8_fed.py new file mode 100644 index 0000000..c374fa4 --- /dev/null +++ b/federated_learning/yolov8_fed.py @@ -0,0 +1,131 @@ +import glob +import os +from pathlib import Path + +import yaml +from ultralytics import YOLO +import copy +import torch + + +# ------------ 新增联邦学习工具函数 ------------ +def federated_avg(global_model, client_weights): + """联邦平均核心算法""" + # 计算总样本数 + total_samples = sum(n for _, n in client_weights) + if total_samples == 0: + raise ValueError("Total number of samples must be positive.") + + # 获取YOLO底层PyTorch模型参数 + global_dict = global_model.model.state_dict() + # 提取所有客户端的 state_dict 和对应样本数 + state_dicts, sample_counts = zip(*client_weights) + + for key in global_dict: + # 对每一层参数取平均 + # if global_dict[key].data.dtype == torch.float32: + # global_dict[key].data = torch.stack( + # [w[key].float() for w in client_weights], 0 + # ).mean(0) + + # 加权平均 + if global_dict[key].dtype == torch.float32: # 只聚合浮点型参数 + # 跳过 BatchNorm 层的统计量 + if any(x in key for x in ['running_mean', 'running_var', 'num_batches_tracked']): + continue + # 按照样本数加权求和 + weighted_tensors = [sd[key].float() * (n / total_samples) + for sd, n in zip(state_dicts, sample_counts)] + global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0) + + # 解决模型参数不匹配问题 + try: + # 加载回YOLO模型 + global_model.model.load_state_dict(global_dict) + except RuntimeError as e: + print('Ignoring "' + str(e) + '"') + + # 添加调试输出 + print("\n=== 参数聚合检查 ===") + + # 选取一个典型参数层 + # sample_key = list(global_dict.keys())[10] + # original = global_dict[sample_key].data.mean().item() + # aggregated = torch.stack([w[sample_key] for w in client_weights]).mean().item() + # print(f"参数层 '{sample_key}' 变化: {original:.4f} → {aggregated:.4f}") + # print(f"客户端参数差异: {[w[sample_key].mean().item() for w in client_weights]}") + + # 随机选取一个非统计量层进行对比 + sample_key = next(k for k in global_dict if 'running_' not in k) + aggregated_mean = global_dict[sample_key].mean().item() + client_means = [sd[sample_key].float().mean().item() for sd in state_dicts] + print(f"层 '{sample_key}' 聚合后均值: {aggregated_mean:.6f}") + print(f"各客户端该层均值: {client_means}") + + return global_model + + +# ------------ 修改训练流程 ------------ +def federated_train(num_rounds, clients_data): + # 初始化全局模型 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + global_model = YOLO("yolov8n.pt").to(device) + # 设置类别数 + global_model.model.nc = 2 + + for _ in range(num_rounds): + client_weights = [] + + # 每个客户端本地训练 + for data_path in clients_data: + # 统计本地训练样本数 + with open(data_path, 'r') as f: + config = yaml.safe_load(f) + # Resolve img_dir relative to the YAML file's location + yaml_dir = os.path.dirname(data_path) + img_dir = os.path.join(yaml_dir, config.get('train', data_path)) # 从配置文件中获取图像目录 + + print(f"Image directory: {img_dir}") + num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg'))) + print(f"Number of images: {num_samples}") + + # 克隆全局模型 + local_model = copy.deepcopy(global_model) + + # 本地训练(保持你的原有参数设置) + local_model.train( + data=data_path, + epochs=1, # 每轮本地训练1个epoch + imgsz=128, # 图像大小 + verbose=False # 关闭冗余输出 + ) + + # 收集模型参数及样本数 + client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples)) + + # 聚合参数更新全局模型 + global_model = federated_avg(global_model, client_weights) + + return global_model + + +# ------------ 使用示例 ------------ +if __name__ == "__main__": + # 联邦训练配置 + clients_config = [ + "./config/client1_data.yaml", # 客户端1数据路径 + "./config/client2_data.yaml" # 客户端2数据路径 + ] + + # 运行联邦训练 + final_model = federated_train(num_rounds=1, clients_data=clients_config) + + # 保存最终模型 + # final_model.export(format="onnx") # 导出为ONNX格式 + + # 检查1:确认模型保存 + # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败" + + # 检查2:验证预测功能 + # results = final_model.predict("test_data/client1/train/images/img1.jpg") + # assert len(results[0].boxes) > 0, "预测结果异常"