import glob
import os
from pathlib import Path

import yaml
from ultralytics import YOLO
import copy
import torch


# ------------ 新增联邦学习工具函数 ------------
def federated_avg(global_model, client_weights):
    """联邦平均核心算法"""
    # 计算总样本数
    total_samples = sum(n for _, n in client_weights)
    if total_samples == 0:
        raise ValueError("Total number of samples must be positive.")
    
    # 获取YOLO底层PyTorch模型参数
    global_dict = global_model.model.state_dict()
    # 提取所有客户端的 state_dict 和对应样本数
    state_dicts, sample_counts = zip(*client_weights)
    
    for key in global_dict:
        # 对每一层参数取平均
        # if global_dict[key].data.dtype == torch.float32:
        #     global_dict[key].data = torch.stack(
        #         [w[key].float() for w in client_weights], 0
        #     ).mean(0)
        
        # 加权平均
        if global_dict[key].dtype == torch.float32:  # 只聚合浮点型参数
            # 跳过 BatchNorm 层的统计量
            if any(x in key for x in ['running_mean', 'running_var', 'num_batches_tracked']):
                continue
            # 按照样本数加权求和
            weighted_tensors = [sd[key].float() * (n / total_samples)
                                for sd, n in zip(state_dicts, sample_counts)]
            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
    
    # 解决模型参数不匹配问题
    try:
        # 加载回YOLO模型
        global_model.model.load_state_dict(global_dict)
    except RuntimeError as e:
        print('Ignoring "' + str(e) + '"')
    
    # 添加调试输出
    print("\n=== 参数聚合检查 ===")
    
    # 选取一个典型参数层
    # sample_key = list(global_dict.keys())[10]
    # original = global_dict[sample_key].data.mean().item()
    # aggregated = torch.stack([w[sample_key] for w in client_weights]).mean().item()
    # print(f"参数层 '{sample_key}' 变化: {original:.4f} → {aggregated:.4f}")
    # print(f"客户端参数差异: {[w[sample_key].mean().item() for w in client_weights]}")
    
    # 随机选取一个非统计量层进行对比
    sample_key = next(k for k in global_dict if 'running_' not in k)
    aggregated_mean = global_dict[sample_key].mean().item()
    client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
    print(f"层 '{sample_key}' 聚合后均值: {aggregated_mean:.6f}")
    print(f"各客户端该层均值: {client_means}")
    
    return global_model


# ------------ 修改训练流程 ------------
def federated_train(num_rounds, clients_data):
    # 初始化全局模型
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = YOLO("yolov8n.pt").to(device)
    # 设置类别数
    global_model.model.nc = 1
    
    for _ in range(num_rounds):
        client_weights = []
        
        # 每个客户端本地训练
        for data_path in clients_data:
            # 统计本地训练样本数
            with open(data_path, 'r') as f:
                config = yaml.safe_load(f)
            #  Resolve img_dir relative to the YAML file's location
            yaml_dir = os.path.dirname(data_path)
            img_dir = os.path.join(yaml_dir, config.get('train', data_path))  # 从配置文件中获取图像目录
            
            print(f"Image directory: {img_dir}")
            num_samples = (len(glob.glob(os.path.join(img_dir, '*.jpg'))) +
                           len(glob.glob(os.path.join(img_dir, '*.png'))))
            print(f"Number of images: {num_samples}")
            
            # 克隆全局模型
            local_model = copy.deepcopy(global_model)
            
            # 本地训练(保持你的原有参数设置)
            local_model.train(
                data=data_path,
                epochs=1,  # 每轮本地训练1个epoch
                imgsz=640,  # 图像大小
                verbose=False  # 关闭冗余输出
            )
            
            # 收集模型参数及样本数
            client_weights.append((copy.deepcopy(local_model.model.state_dict()), num_samples))
        
        # 聚合参数更新全局模型
        global_model = federated_avg(global_model, client_weights)
    
    return global_model


# ------------ 使用示例 ------------
if __name__ == "__main__":
    # 联邦训练配置
    clients_config = [
        "../dataset/train1/train1.yaml",  # 客户端1数据路径
        "../dataset/train2/train2.yaml"  # 客户端2数据路径
    ]
    
    # 运行联邦训练
    final_model = federated_train(num_rounds=1, clients_data=clients_config)
    
    # 保存最终模型
    # final_model.export(format="onnx")  # 导出为ONNX格式
    
    # 检查1:确认模型保存
    # assert Path("yolov8n_federated.onnx").exists(), "模型导出失败"
    
    # 检查2:验证预测功能
    # results = final_model.predict("../dataset/val/images/VS_P65.jpg", save=True)
    # assert len(results[0].boxes) > 0, "预测结果异常"