Graduation-Project/federated_learning/yolov8_fed.py

132 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import os
from pathlib import Path
import yaml
from ultralytics import YOLO
import copy
import torch
# ------------ 新增联邦学习工具函数 ------------
def federated_avg(global_model, client_weights, *, aggregate_bn_stats=False):
    """Federated averaging (FedAvg) of client parameters into ``global_model``.

    Each floating-point entry of the state dict of ``global_model.model`` is
    replaced by the sample-weighted mean of the matching client entries::

        w_global = sum_k (n_k / N) * w_k,    N = sum_k n_k

    Entries that cannot be averaged keep the global model's current value:

    * non-floating-point buffers (e.g. BatchNorm ``num_batches_tracked``);
    * BatchNorm running statistics (``running_mean`` / ``running_var``),
      unless ``aggregate_bn_stats`` is True;
    * entries missing from any client state dict, or whose shape differs from
      the global tensor. These are listed in the debug output instead of
      making ``load_state_dict`` fail.

    Args:
        global_model: Object whose ``.model`` attribute is a
            ``torch.nn.Module`` (the ultralytics ``YOLO`` wrapper in this
            file). Its parameters are updated in place.
        client_weights: Sequence of ``(state_dict, num_samples)`` pairs, one
            per client.
        aggregate_bn_stats: Also average BatchNorm running mean/variance.
            The default (False) keeps the global model's own statistics.

    Returns:
        ``global_model`` itself, after its parameters have been updated.

    Raises:
        ValueError: If the total sample count is not positive. This includes
            an empty ``client_weights``.
    """
    # Total sample count N; client k contributes with weight n_k / N.
    total_samples = sum(n for _, n in client_weights)
    if total_samples <= 0:
        raise ValueError("Total number of samples must be positive.")
    state_dicts, sample_counts = zip(*client_weights)
    factors = [n / total_samples for n in sample_counts]

    # Parameters of the underlying PyTorch module wrapped by the YOLO object.
    global_dict = global_model.model.state_dict()

    averaged = {}  # key -> aggregated tensor, in state-dict order
    skipped = []   # keys missing from a client or with a mismatched shape
    for key, global_tensor in global_dict.items():
        # Integer buffers (e.g. num_batches_tracked) cannot be averaged.
        if not global_tensor.is_floating_point():
            continue
        if not aggregate_bn_stats and ("running_mean" in key or "running_var" in key):
            continue
        client_tensors = [sd.get(key) for sd in state_dicts]
        if any(t is None or t.shape != global_tensor.shape for t in client_tensors):
            # Loading a differently-shaped tensor would make load_state_dict
            # raise, so keep the global value and report the key instead.
            skipped.append(key)
            continue
        # Accumulate in at least float32 on the global tensor's device, then
        # cast back so the model's dtype is preserved.
        acc_dtype = torch.promote_types(global_tensor.dtype, torch.float32)
        weighted_sum = torch.stack([
            t.to(device=global_tensor.device, dtype=acc_dtype) * factor
            for t, factor in zip(client_tensors, factors)
        ]).sum(dim=0)
        averaged[key] = weighted_sum.to(global_tensor.dtype)

    global_dict.update(averaged)
    # Every entry still has the global model's own key set and shapes, so a
    # strict load succeeds; any error raised here is a genuine problem.
    global_model.model.load_state_dict(global_dict)

    # Debug output: compare one aggregated layer against the clients.
    print("\n=== 参数聚合检查 ===")
    if skipped:
        print(f"Skipped {len(skipped)} tensor(s) missing from a client "
              f"or with a mismatched shape: {skipped}")
    if averaged:
        sample_key = next(iter(averaged))
        # Read back from the model to confirm the load actually took effect.
        aggregated_mean = global_model.model.state_dict()[sample_key].float().mean().item()
        client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
        print(f"'{sample_key}' 聚合后均值: {aggregated_mean:.6f}")
        print(f"各客户端该层均值: {client_means}")
    return global_model
# ------------ 修改训练流程 ------------
# Image file suffixes counted as training samples (compared case-insensitively).
# NOTE(review): common image formats; ultralytics' own accepted-format list may
# differ between versions.
_IMG_SUFFIXES = frozenset({
    ".bmp", ".dng", ".jpeg", ".jpg", ".mpo", ".png",
    ".tif", ".tiff", ".webp", ".pfm",
})


def _count_train_images(config, yaml_dir):
    """Return the number of training images referenced by a parsed data YAML.

    ``config["train"]`` may be a single entry or a list of entries. Each entry
    is either an image directory, searched recursively, or a file (e.g.
    ``.txt``) listing one image path per non-empty line. Relative entries are
    resolved under the optional ``path`` key, which in turn is resolved
    relative to ``yaml_dir``.

    NOTE(review): ultralytics may resolve a relative ``path`` differently
    (e.g. against its datasets directory); confirm the count printed here
    matches the trainer's own log.

    Raises:
        ValueError: If the YAML has no ``train`` entry.
    """
    train = config.get("train")
    if not train:
        raise ValueError("data YAML has no 'train' entry")
    # Path() ignores earlier components when a later one is absolute.
    root = Path(yaml_dir, str(config.get("path") or ""))
    entries = train if isinstance(train, (list, tuple)) else [train]

    total = 0
    for entry in entries:
        target = root / str(entry)
        print(f"Image directory: {target}")
        if target.is_dir():
            total += sum(
                1 for p in target.rglob("*")
                if p.is_file() and p.suffix.lower() in _IMG_SUFFIXES
            )
        elif target.is_file():
            with open(target, encoding="utf-8") as f:
                total += sum(1 for line in f if line.strip())
        else:
            print(f"Warning: training path not found: {target}")
    return total


def federated_train(num_rounds, clients_data, *, local_epochs=1, imgsz=128):
    """Run ``num_rounds`` rounds of FedAvg over YOLOv8 clients.

    In each round, every client starts from a deep copy of the current global
    model and trains locally on its own data YAML. It then reports its
    weights together with its training-image count. The global model is
    replaced by the sample-weighted average computed by ``federated_avg``.

    Args:
        num_rounds: Number of federated rounds.
        clients_data: Paths to the per-client ultralytics data YAML files.
        local_epochs: Local training epochs per client per round (default 1).
        imgsz: Training image size passed to ``YOLO.train`` (default 128).

    Returns:
        The aggregated global ``YOLO`` model.

    Raises:
        ValueError: If a client's data YAML has no ``train`` entry, or if no
            client has any training images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = YOLO("yolov8n.pt").to(device)
    # NOTE(review): this only sets an attribute; it does not rebuild the
    # detection head, which keeps the checkpoint's class count. Client heads
    # trained for another class count then have mismatched shapes and are not
    # aggregated by federated_avg. Confirm this is intended.
    global_model.model.nc = 2

    for _ in range(num_rounds):
        client_weights = []
        for data_path in clients_data:
            # Sample count n_k used as this client's FedAvg weight.
            with open(data_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f) or {}
            num_samples = _count_train_images(config, os.path.dirname(data_path))
            print(f"Number of images: {num_samples}")
            if num_samples == 0:
                print(f"Warning: no training images found for {data_path}; "
                      "this client will get zero weight.")

            # Each client trains an independent copy of the global model.
            local_model = copy.deepcopy(global_model)
            local_model.train(
                data=data_path,
                epochs=local_epochs,
                imgsz=imgsz,
                verbose=False,
            )
            # Snapshot the weights so later rounds cannot mutate them.
            client_weights.append(
                (copy.deepcopy(local_model.model.state_dict()), num_samples)
            )
        global_model = federated_avg(global_model, client_weights)
    return global_model
# ------------ 使用示例 ------------
if __name__ == "__main__":
    # Federated setup: one ultralytics data YAML per participating client.
    client_yaml_paths = [
        "./config/client1_data.yaml",  # client 1
        "./config/client2_data.yaml",  # client 2
    ]

    # Run a single federated round across all clients.
    final_model = federated_train(clients_data=client_yaml_paths, num_rounds=1)

    # Optional post-training steps, currently disabled:
    #   1. export to ONNX and check the exported file exists;
    #   2. run one prediction and check it detects at least one box.
    # NOTE(review): export() picks its own output filename; confirm it matches
    # the path asserted below before re-enabling.
    # final_model.export(format="onnx")
    # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
    # results = final_model.predict("test_data/client1/train/images/img1.jpg")
    # assert len(results[0].boxes) > 0, "预测结果异常"