import glob
import os
from pathlib import Path
import json

import yaml
from ultralytics import YOLO
import copy
import torch


# ------------ 新增联邦学习工具函数 ------------
def federated_avg(global_model, client_weights):
    """联邦平均核心算法"""
    # 计算总样本数
    total_samples = sum(n for _, n in client_weights)
    if total_samples == 0:
        raise ValueError("Total number of samples must be positive.")

    # 获取YOLO底层PyTorch模型参数
    global_dict = global_model.model.state_dict()
    # 提取所有客户端的 state_dict 和对应样本数
    state_dicts, sample_counts = zip(*client_weights)

    for key in global_dict:
        # 对每一层参数取平均
        # if global_dict[key].data.dtype == torch.float32:
        #     global_dict[key].data = torch.stack(
        #         [w[key].float() for w in client_weights], 0
        #     ).mean(0)

        # 加权平均
        if global_dict[key].dtype == torch.float32:  # 只聚合浮点型参数
            # 跳过 BatchNorm 层的统计量
            if any(
                x in key for x in ["running_mean", "running_var", "num_batches_tracked"]
            ):
                continue
            # 按照样本数加权求和
            weighted_tensors = [
                sd[key].float() * (n / total_samples)
                for sd, n in zip(state_dicts, sample_counts)
            ]
            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)

        # 解决模型参数不匹配问题
        # try:
        #     # 加载回YOLO模型
        #     global_model.model.load_state_dict(global_dict)
        # except RuntimeError as e:
        #     print('Ignoring "' + str(e) + '"')

        # 加载回YOLO模型
        global_model.model.load_state_dict(global_dict)

    # 随机选取一个非统计量层进行对比
    # sample_key = next(k for k in global_dict if 'running_' not in k)
    # aggregated_mean = global_dict[sample_key].mean().item()
    # client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
    # print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
    # print(f"The average value of the layer for each client: {client_means}")

    # 定义多个关键层
    MONITOR_KEYS = [
        "model.0.conv.weight",  # 输入层卷积
        "model.10.conv.weight",  # 中间层卷积
        "model.22.dfl.conv.weight",  # 输出层分类头
    ]

    with open("aggregation_check.txt", "a") as f:
        f.write("\n=== 参数聚合检查 ===\n")
    for key in MONITOR_KEYS:
        if key not in global_dict:
            continue
        # 计算聚合后均值
        aggregated_mean = global_dict[key].mean().item()
        # 计算各客户端均值
        client_means = [sd[key].float().mean().item() for sd in state_dicts]

        with open("aggregation_check.txt", "a") as f:
            f.write(f"层 '{key}' 聚合后均值: {aggregated_mean:.6f}\n")
            f.write(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n")
            f.write(f"客户端最大差异: {max(client_means) - min(client_means):.6f}\n\n")

    return global_model


# ------------ 修改训练流程 ------------
def federated_train(num_rounds, clients_data):
    # ========== 新增:初始化指标记录 ==========
    metrics = {
        "round": [],
        "val_mAP": [],  # 每轮验证集mAP
        "train_loss": [],  # 每轮平均训练损失
        "client_mAPs": [],  # 各客户端本地模型在验证集上的mAP
        "communication_cost": [],  # 每轮通信开销(MB)
    }

    # 初始化全局模型
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = YOLO("../yolov8n.yaml").to(device)
    # 设置类别数
    # global_model.model.nc = 1

    for _ in range(num_rounds):
        client_weights = []
        client_losses = []  # 记录各客户端的训练损失

        # 每个客户端本地训练
        for data_path in clients_data:
            # 统计本地训练样本数
            with open(data_path, "r") as f:
                config = yaml.safe_load(f)
            #  Resolve img_dir relative to the YAML file's location
            yaml_dir = os.path.dirname(data_path)
            img_dir = os.path.join(
                yaml_dir, config.get("train", data_path)
            )  # 从配置文件中获取图像目录

            # print(f"Image directory: {img_dir}")
            num_samples = (len(glob.glob(os.path.join(img_dir, "*.jpg")))
                + len(glob.glob(os.path.join(img_dir, "*.png")))
                + len(glob.glob(os.path.join(img_dir, "*.jpeg")))
            )
            # print(f"Number of images: {num_samples}")

            # 克隆全局模型
            local_model = copy.deepcopy(global_model)

            # 本地训练(保持你的原有参数设置)
            results = local_model.train(
                data=data_path,
                epochs=4,  # 每轮本地训练多少个epoch
                # save_period=16,
                imgsz=640,  # 图像大小
                verbose=False,  # 关闭冗余输出
                batch=-1,
            )

            # 记录客户端训练损失
            # client_loss = results.results_dict['train_loss']
            # client_losses.append(client_loss)

            # 收集模型参数及样本数
            client_weights.append(
                (copy.deepcopy(local_model.model.state_dict()), num_samples)
            )

        # 聚合参数更新全局模型
        global_model = federated_avg(global_model, client_weights)

        # ========== 评估全局模型 ==========
        # 评估全局模型在验证集上的性能
        val_results = global_model.val(
            data="/mnt/DATA/UAVdataset/data.yaml",  # 指定验证集配置文件
            imgsz=640,
            batch=-1,
            verbose=False,
        )
        val_mAP = val_results.box.map  # 获取mAP@0.5

        # 计算平均训练损失
        # avg_train_loss = sum(client_losses) / len(client_losses)

        # 计算通信开销(假设传输全部模型参数)
        model_size = sum(p.numel() * 4 for p in global_model.model.parameters()) / (
            1024**2
        )  # MB

        # 记录到指标容器
        metrics["round"].append(_ + 1)
        metrics["val_mAP"].append(val_mAP)
        # metrics['train_loss'].append(avg_train_loss)
        metrics["communication_cost"].append(model_size)
        # 打印当前轮次结果
        with open("aggregation_check.txt", "a") as f:
            f.write(f"\n[Round {_ + 1}/{num_rounds}]")
            f.write(f"Validation mAP@0.5: {val_mAP:.4f}")
            # f.write(f"Average Train Loss: {avg_train_loss:.4f}")
            f.write(f"Communication Cost: {model_size:.2f} MB\n")

    return global_model, metrics


# ------------ 使用示例 ------------
if __name__ == "__main__":
    # 联邦训练配置
    clients_config = [
        "/mnt/DATA/uav_dataset_fed/train1/train1.yaml",  # 客户端1数据路径
        "/mnt/DATA/uav_dataset_fed/train2/train2.yaml",  # 客户端2数据路径
    ]

    # 运行联邦训练
    final_model, metrics = federated_train(num_rounds=40, clients_data=clients_config)

    # 保存最终模型
    final_model.save("yolov8n_federated.pt")
    # final_model.export(format="onnx")  # 导出为ONNX格式

    with open("metrics.json", "w") as f:
        json.dump(metrics, f, indent=4)