2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#!/usr/bin/env python3
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import os
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import json
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import yaml
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import time
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from tqdm import tqdm
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import torch
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from fed_algo_cs.client_base import FedYoloClient
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from fed_algo_cs.server_base import FedYoloServer
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from utils.args import args_parser  # args parser
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils.fed_util import divide_trainset  # divide_trainset
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								def fed_run():
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    Main FL process:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        - Initialize clients & server
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        - For each round: sequential local training -> record -> select -> aggregate
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        - Test & flush
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        - Record & save results, plot curves
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    args_cli = args_parser()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # TODO: cfg and params should not be separately defined
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    with open(args_cli.config, "r", encoding="utf-8") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        cfg = yaml.safe_load(f)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- params / config normalization ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # For convenience we pass the same `params` dict used by Dataset/model/loss.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # Here we re-use the top-level cfg directly as params.
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # params = dict(cfg)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if "names" in cfg and isinstance(cfg["names"], dict):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # but we can leave dict; your utils appear to accept dict
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        pass
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # seeds
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    seed_everything(int(cfg.get("i_seed", 0)))
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- split clients' train data from a global train list ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # Expect either cfg["train_txt"] or <dataset_path>/train.txt
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    train_txt = cfg.get("train_txt", "")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if not train_txt:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        ds_root = cfg.get("dataset_path", "")
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        guess = os.path.join(ds_root, "train2017.txt") if ds_root else ""
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        train_txt = guess
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if not train_txt or not os.path.exists(train_txt):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        raise FileNotFoundError(
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            f"train2017.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    split = divide_trainset(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        trainset_path=train_txt,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_local_class=int(cfg.get("num_local_class", 1)),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_client=int(cfg.get("num_client", 64)),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        min_data=int(cfg.get("min_data", 100)),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        max_data=int(cfg.get("max_data", 100)),
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        mode=str(cfg.get("partition_mode", "overlap")),  # "overlap" or "disjoint"
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        seed=int(cfg.get("i_seed", 0)),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    users = split["users"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    user_data = split["user_data"]  # mapping: id -> {"filename": [...]}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- build clients ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    model_name = cfg.get("model_name", "yolo_v11_n")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    clients = {}
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 22:37:50 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    for uid in users:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        c.load_trainset(user_data[uid]["filename"])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        clients[uid] = c
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- build server & optional validation set ---
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:04 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    valset = build_valset_if_available(cfg, params=cfg, args=args_cli, val_name="val2017")
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # valset is a Dataset class, not data loader
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if valset is not None:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        server.load_valset(valset)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- push initial global weights ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    global_state = server.state_dict()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- args object for client.train() ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # args_train = _make_args_for_client(cfg, args_cli)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- history recorder ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    history = {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        "mAP": [],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        "mAP50": [],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        "precision": [],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        "recall": [],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        "train_loss": [],  # scalar sum of client-weighted dict losses
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        "round_time_sec": [],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- main FL loop ---
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    num_round = int(cfg.get("num_round", 50))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    connection_ratio = float(cfg.get("connection_ratio", 1.0))  # e.g., 1.0 = all clients
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    res_root = cfg.get("res_root", "results")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    os.makedirs(res_root, exist_ok=True)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # tqdm logging
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    header = ("%10s" * 2) % ("Round", "client")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    tqdm.write("\n" + header)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    for rnd in range(num_round):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        t0 = time.time()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Local training (sequential over all users)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        for uid in users:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            # tqdm desc update
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            client = clients[uid]  # FedYoloClient instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            client.update(global_state)  # load global weights
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            state_dict, n_data, train_loss = client.train(args_cli)  # local training
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            server.rec(uid, state_dict, n_data, train_loss)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Select a fraction for aggregation (FedAvg subset if desired)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        server.select_clients(connection_ratio=connection_ratio)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Aggregate
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        global_state, avg_loss, _ = server.agg()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Compute a scalar train loss for plotting (sum of components)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        scalar_train_loss = avg_loss if avg_loss else 0.0
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Test (if valset provided)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Flush per-round client caches
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        server.flush()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Record & log
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        history["mAP"].append(mAP)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        history["mAP50"].append(mAP50)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        history["precision"].append(precision)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        history["recall"].append(recall)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        history["train_loss"].append(scalar_train_loss)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        history["round_time_sec"].append(time.time() - t0)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # Log GPU memory usage
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # tqdm update
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        desc = {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "loss": f"{scalar_train_loss:.6g}",
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "mAP50": f"{mAP50:.6g}",
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "mAP": f"{mAP:.6g}",
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "precision": f"{precision:.6g}",
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "recall": f"{recall:.6g}",
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            # "gpu_mem": gpu_mem,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        p_bar.set_postfix(desc)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # Save running JSON (resumable logs)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        out_json = os.path.join(res_root, save_name + ".json")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        with open(out_json, "w", encoding="utf-8") as f:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            json.dump(history, f, indent=4)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        p_bar.update(1)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    p_bar.close()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # Save final global model weights
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if not os.path.exists("./weights"):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        os.makedirs("./weights", exist_ok=True)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    torch.save(global_state, f"./weights/{save_name}_final.pth")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # --- final plot ---
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:30:45 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    print("[done] training complete.")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								if __name__ == "__main__":
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    fed_run()
							 |