联邦平均算法:结合yolov8
This commit is contained in:
		
							
								
								
									
										131
									
								
								federated_learning/yolov8_fed.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								federated_learning/yolov8_fed.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | |||||||
|  | import glob | ||||||
|  | import os | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | import yaml | ||||||
|  | from ultralytics import YOLO | ||||||
|  | import copy | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------ 新增联邦学习工具函数 ------------ | ||||||
|  | def federated_avg(global_model, client_weights): | ||||||
|  |     """联邦平均核心算法""" | ||||||
|  |     # 计算总样本数 | ||||||
|  |     total_samples = sum(n for _, n in client_weights) | ||||||
|  |     if total_samples == 0: | ||||||
|  |         raise ValueError("Total number of samples must be positive.") | ||||||
|  |      | ||||||
|  |     # 获取YOLO底层PyTorch模型参数 | ||||||
|  |     global_dict = global_model.model.state_dict() | ||||||
|  |     # 提取所有客户端的 state_dict 和对应样本数 | ||||||
|  |     state_dicts, sample_counts = zip(*client_weights) | ||||||
|  |      | ||||||
|  |     for key in global_dict: | ||||||
|  |         # 对每一层参数取平均 | ||||||
|  |         # if global_dict[key].data.dtype == torch.float32: | ||||||
|  |         #     global_dict[key].data = torch.stack( | ||||||
|  |         #         [w[key].float() for w in client_weights], 0 | ||||||
|  |         #     ).mean(0) | ||||||
|  |          | ||||||
|  |         # 加权平均 | ||||||
|  |         if global_dict[key].dtype == torch.float32:  # 只聚合浮点型参数 | ||||||
|  |             # 跳过 BatchNorm 层的统计量 | ||||||
|  |             if any(x in key for x in ['running_mean', 'running_var', 'num_batches_tracked']): | ||||||
|  |                 continue | ||||||
|  |             # 按照样本数加权求和 | ||||||
|  |             weighted_tensors = [sd[key].float() * (n / total_samples) | ||||||
|  |                                 for sd, n in zip(state_dicts, sample_counts)] | ||||||
|  |             global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0) | ||||||
|  |      | ||||||
|  |     # 解决模型参数不匹配问题 | ||||||
|  |     try: | ||||||
|  |         # 加载回YOLO模型 | ||||||
|  |         global_model.model.load_state_dict(global_dict) | ||||||
|  |     except RuntimeError as e: | ||||||
|  |         print('Ignoring "' + str(e) + '"') | ||||||
|  |      | ||||||
|  |     # 添加调试输出 | ||||||
|  |     print("\n=== 参数聚合检查 ===") | ||||||
|  |      | ||||||
|  |     # 选取一个典型参数层 | ||||||
|  |     # sample_key = list(global_dict.keys())[10] | ||||||
|  |     # original = global_dict[sample_key].data.mean().item() | ||||||
|  |     # aggregated = torch.stack([w[sample_key] for w in client_weights]).mean().item() | ||||||
|  |     # print(f"参数层 '{sample_key}' 变化: {original:.4f} → {aggregated:.4f}") | ||||||
|  |     # print(f"客户端参数差异: {[w[sample_key].mean().item() for w in client_weights]}") | ||||||
|  |      | ||||||
|  |     # 随机选取一个非统计量层进行对比 | ||||||
|  |     sample_key = next(k for k in global_dict if 'running_' not in k) | ||||||
|  |     aggregated_mean = global_dict[sample_key].mean().item() | ||||||
|  |     client_means = [sd[sample_key].float().mean().item() for sd in state_dicts] | ||||||
|  |     print(f"层 '{sample_key}' 聚合后均值: {aggregated_mean:.6f}") | ||||||
|  |     print(f"各客户端该层均值: {client_means}") | ||||||
|  |      | ||||||
|  |     return global_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------ 修改训练流程 ------------ | ||||||
|  | def federated_train(num_rounds, clients_data): | ||||||
|  |     # 初始化全局模型 | ||||||
|  |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||||
|  |     global_model = YOLO("yolov8n.pt").to(device) | ||||||
|  |     # 设置类别数 | ||||||
|  |     global_model.model.nc = 2 | ||||||
|  |      | ||||||
|  |     for _ in range(num_rounds): | ||||||
|  |         client_weights = [] | ||||||
|  |          | ||||||
|  |         # 每个客户端本地训练 | ||||||
|  |         for data_path in clients_data: | ||||||
|  |             # 统计本地训练样本数 | ||||||
|  |             with open(data_path, 'r') as f: | ||||||
|  |                 config = yaml.safe_load(f) | ||||||
|  |             #  Resolve img_dir relative to the YAML file's location | ||||||
|  |             yaml_dir = os.path.dirname(data_path) | ||||||
|  |             img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # 从配置文件中获取图像目录 | ||||||
|  |              | ||||||
|  |             print(f"Image directory: {img_dir}") | ||||||
|  |             num_samples = len(glob.glob(os.path.join(img_dir, '*.jpg'))) | ||||||
|  |             print(f"Number of images: {num_samples}") | ||||||
|  |              | ||||||
|  |             # 克隆全局模型 | ||||||
|  |             local_model = copy.deepcopy(global_model) | ||||||
|  |              | ||||||
|  |             # 本地训练(保持你的原有参数设置) | ||||||
|  |             local_model.train( | ||||||
|  |                 data=data_path, | ||||||
|  |                 epochs=1,  # 每轮本地训练1个epoch | ||||||
|  |                 imgsz=128,  # 图像大小 | ||||||
|  |                 verbose=False  # 关闭冗余输出 | ||||||
|  |             ) | ||||||
|  |              | ||||||
|  |             # 收集模型参数及样本数 | ||||||
|  |             client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples)) | ||||||
|  |          | ||||||
|  |         # 聚合参数更新全局模型 | ||||||
|  |         global_model = federated_avg(global_model, client_weights) | ||||||
|  |      | ||||||
|  |     return global_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------ 使用示例 ------------ | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # 联邦训练配置 | ||||||
|  |     clients_config = [ | ||||||
|  |         "./config/client1_data.yaml",  # 客户端1数据路径 | ||||||
|  |         "./config/client2_data.yaml"  # 客户端2数据路径 | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     # 运行联邦训练 | ||||||
|  |     final_model = federated_train(num_rounds=1, clients_data=clients_config) | ||||||
|  |      | ||||||
|  |     # 保存最终模型 | ||||||
|  |     # final_model.export(format="onnx")  # 导出为ONNX格式 | ||||||
|  |      | ||||||
|  |     # 检查1:确认模型保存 | ||||||
|  |     # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败" | ||||||
|  |      | ||||||
|  |     # 检查2:验证预测功能 | ||||||
|  |     # results = final_model.predict("test_data/client1/train/images/img1.jpg") | ||||||
|  |     # assert len(results[0].boxes) > 0, "预测结果异常" | ||||||
		Reference in New Issue
	
	Block a user