2025-04-20 07:20:16 +00:00
|
|
|
|
import glob
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
import copy
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------ 新增联邦学习工具函数 ------------
|
|
|
|
|
def federated_avg(global_model, client_weights, *, aggregate_bn_stats=False):
    """Federated averaging (FedAvg): load the sample-weighted mean of client weights into the global model.

    Args:
        global_model: Object exposing the underlying ``torch.nn.Module`` as ``.model``
            (e.g. an Ultralytics ``YOLO`` instance). Its parameters are updated in place.
        client_weights: Iterable of ``(state_dict, num_samples)`` pairs, one per client.
            Each client's contribution is weighted by ``num_samples / total_samples``.
        aggregate_bn_stats: When False (default), BatchNorm ``running_mean`` /
            ``running_var`` buffers are NOT averaged and keep whatever values the global
            model already holds. When True they are averaged like any other float tensor.

    Returns:
        The same ``global_model`` object, for convenient re-assignment by the caller.

    Raises:
        ValueError: if ``client_weights`` is empty, a sample count is negative, or the
            total sample count is not positive.
    """
    client_weights = list(client_weights)
    if not client_weights:
        raise ValueError("client_weights must contain at least one client.")

    # Split into the client state_dicts and their sample counts.
    state_dicts, sample_counts = zip(*client_weights)
    if any(n < 0 for n in sample_counts):
        raise ValueError("Sample counts must be non-negative.")
    total_samples = sum(sample_counts)
    if total_samples <= 0:
        raise ValueError("Total number of samples must be positive.")

    # state_dict() of the underlying PyTorch model of the YOLO wrapper.
    global_dict = global_model.model.state_dict()
    bn_stat_markers = ('running_mean', 'running_var')
    mismatched_keys = []

    for key, global_tensor in global_dict.items():
        # Integer buffers (e.g. BatchNorm num_batches_tracked) cannot be meaningfully
        # averaged; they keep the global model's value.
        if not torch.is_floating_point(global_tensor):
            continue
        if not aggregate_bn_stats and any(m in key for m in bn_stat_markers):
            continue

        client_tensors = [sd[key] for sd in state_dicts]
        # A shape mismatch means the client network differs from the global one
        # (e.g. a different number of classes); such layers cannot be averaged into it.
        if any(t.shape != global_tensor.shape for t in client_tensors):
            mismatched_keys.append(key)
            continue

        # Weighted sum computed in float32 on the global tensor's device;
        # load_state_dict casts back to the parameter's own dtype.
        weighted = [
            t.detach().to(device=global_tensor.device, dtype=torch.float32) * (n / total_samples)
            for t, n in zip(client_tensors, sample_counts)
        ]
        global_dict[key] = torch.stack(weighted, dim=0).sum(dim=0)

    if mismatched_keys:
        print(f"Warning: {len(mismatched_keys)} tensor(s) have client shapes that differ from "
              f"the global model and were left unchanged: {mismatched_keys}")

    # Every entry now has the global model's own shape, so a strict load is safe.
    global_model.model.load_state_dict(global_dict)

    # Debug output: compare one representative float layer before/after aggregation.
    print("\n=== 参数聚合检查 ===")
    sample_key = next(
        (k for k, v in global_dict.items()
         if torch.is_floating_point(v) and 'running_' not in k),
        None,
    )
    if sample_key is not None:
        aggregated_mean = global_dict[sample_key].float().mean().item()
        client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
        print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
        print(f"The average value of the layer for each client: {client_means}")

    return global_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------ 修改训练流程 ------------
|
|
|
|
|
# Common image file suffixes counted as training samples (compared case-insensitively).
_IMAGE_SUFFIXES = frozenset({
    '.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff', '.webp',
})


def _count_train_images(data_path):
    """Return the number of training images referenced by a YOLO data YAML.

    The ``train`` entry may be a single path or a list of paths. Each path is either an
    image directory (searched recursively) or a ``.txt`` file listing one image per line.
    Relative entries are resolved against the YAML's ``path`` key when present,
    otherwise against the YAML file's own directory.

    Raises:
        ValueError: if the YAML has no ``train`` entry.
    """
    data_file = Path(data_path)
    with open(data_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f) or {}

    train = config.get('train')
    if not train:
        raise ValueError(f"Data YAML '{data_path}' has no 'train' entry.")

    root = Path(config.get('path') or data_file.parent)
    if not root.is_absolute():
        # NOTE(review): Ultralytics may resolve a relative `path` against its own
        # datasets directory instead of the YAML's directory -- confirm for this setup.
        root = data_file.parent / root

    total = 0
    for entry in (train if isinstance(train, (list, tuple)) else [train]):
        source = root / str(entry)
        if source.is_dir():
            total += sum(1 for p in source.rglob('*')
                         if p.suffix.lower() in _IMAGE_SUFFIXES and p.is_file())
        elif source.is_file() and source.suffix.lower() == '.txt':
            with open(source, 'r', encoding='utf-8') as listing:
                total += sum(1 for line in listing if line.strip())
    return total


def _same_architecture(net_a, net_b):
    """Return True if two modules have identical state_dict keys and tensor shapes."""
    sd_a, sd_b = net_a.state_dict(), net_b.state_dict()
    return sd_a.keys() == sd_b.keys() and all(sd_a[k].shape == sd_b[k].shape for k in sd_a)


def federated_train(num_rounds, clients_data):
    """Run FedAvg training of a YOLOv8n detector over several client datasets.

    In every round, each client starts from a deep copy of the current global model and
    trains locally on its own data YAML. The client weights are then averaged into the
    global model (``federated_avg``), weighted by each client's training-image count.

    Args:
        num_rounds: Number of federated rounds.
        clients_data: Iterable of paths to Ultralytics data YAML files, one per client.

    Returns:
        The global ``YOLO`` model after the last round.
    """
    clients_data = list(clients_data)

    # Initialize the global model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = YOLO("yolov8n.pt").to(device)
    # Only sets the attribute; it does not resize the pretrained Detect head
    # (the architecture is reconciled after the first round below).
    global_model.model.nc = 1

    # Sample counts do not change between rounds, so count (and validate) them once.
    sample_counts = [_count_train_images(p) for p in clients_data]
    for data_path, n in zip(clients_data, sample_counts):
        if n == 0:
            print(f"Warning: no training images found for '{data_path}'; "
                  f"it will get zero weight in aggregation.")

    for round_idx in range(num_rounds):
        client_weights = []
        trained_net = None

        # Local training on each client, always starting from this round's global model.
        for data_path, num_samples in zip(clients_data, sample_counts):
            local_model = copy.deepcopy(global_model)
            local_model.train(
                data=data_path,
                epochs=16,  # 16 local epochs per round
                save_period=16,
                imgsz=640,  # input image size
                verbose=False,  # suppress verbose output
                batch=-1,  # -1 requests Ultralytics AutoBatch
            )
            # NOTE(review): after train() Ultralytics reloads a saved checkpoint into
            # local_model.model (best.pt when present), so this may be the best-epoch
            # rather than the last-epoch weights -- confirm this is the intended update.
            client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples))
            trained_net = local_model.model

        # Clients build their network for their dataset's class count. If that differs in
        # shape from the global network (e.g. the pretrained Detect head), adopt the client
        # architecture so every layer -- including the class head -- is actually federated.
        # Done after all clients trained so none of them starts from another client's weights.
        if trained_net is not None and not _same_architecture(global_model.model, trained_net):
            global_model.model = copy.deepcopy(trained_net).to(device)

        # Aggregate the client parameters into the global model.
        global_model = federated_avg(global_model, client_weights)
        print(f"Round {round_idx + 1}/{num_rounds} completed.")

    return global_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------ 使用示例 ------------
|
|
|
|
|
if __name__ == "__main__":
    # One Ultralytics data YAML per federated client.
    client_yaml_files = (
        "/root/autodl-tmp/dataset/train1/train1.yaml",  # client 1 dataset
        "/root/autodl-tmp/dataset/train2/train2.yaml",  # client 2 dataset
    )
    rounds = 10
    weights_file = "yolov8n_federated.pt"

    # Train across all clients, then write the aggregated global model to disk.
    trained = federated_train(num_rounds=rounds, clients_data=client_yaml_files)
    trained.save(weights_file)

    # Optional post-training checks, currently disabled:
    # trained.export(format="onnx")  # export to ONNX
    # Check 1: the export produced a file.
    # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
    # Check 2: inference on a validation image yields detections.
    # results = trained.predict("../dataset/val/images/VS_P65.jpg", save=True)
    # assert len(results[0].boxes) > 0, "预测结果异常"
|