From 291b82bec336eb933a29c801a2fc7a62eafd43e5 Mon Sep 17 00:00:00 2001 From: Yunhao Meng Date: Fri, 31 Oct 2025 13:14:37 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=AE=AD=E7=BB=83=E8=BF=87?= =?UTF-8?q?=E7=A8=8B=EF=BC=8C=E6=B7=BB=E5=8A=A0=20CSV=20=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=EF=BC=8C=E6=94=B9=E8=BF=9B=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=9D=83=E9=87=8D=E4=BF=9D=E5=AD=98=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fed_run.py | 124 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 77 insertions(+), 47 deletions(-) diff --git a/fed_run.py b/fed_run.py index 6dfb91e..97cb9d1 100644 --- a/fed_run.py +++ b/fed_run.py @@ -5,12 +5,16 @@ import yaml import time from tqdm import tqdm import torch +import csv +import copy from utils.fed_util import build_valset_if_available, seed_everything, plot_curves from fed_algo_cs.client_base import FedYoloClient from fed_algo_cs.server_base import FedYoloServer from utils.args import args_parser # args parser from utils.fed_util import divide_trainset # divide_trainset +from utils import util +from utils.fed_util import prepare_result_dir def fed_run(): @@ -26,11 +30,6 @@ def fed_run(): with open(args_cli.config, "r", encoding="utf-8") as f: cfg = yaml.safe_load(f) - # --- params / config normalization --- - # For convenience we pass the same `params` dict used by Dataset/model/loss. - # Here we re-use the top-level cfg directly as params. - # params = dict(cfg) - if "names" in cfg and isinstance(cfg["names"], dict): # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list # but we can leave dict; your utils appear to accept dict @@ -39,6 +38,9 @@ def fed_run(): # seeds seed_everything(int(cfg.get("i_seed", 0))) + # result directory + res_root, weights_root = prepare_result_dir(base_root=cfg.get("res_root", "results")) + # --- split clients' train data from a global train list --- # Expect either cfg["train_txt"] or /train.txt train_txt = cfg.get("train_txt", "") @@ -67,7 +69,7 @@ def fed_run(): # --- build clients --- model_name = cfg.get("model_name", "yolo_v11_n") - clients = {} + clients: dict[str, FedYoloClient] = {} for uid in users: c = FedYoloClient(name=uid, model_name=model_name, params=cfg) @@ -84,9 +86,6 @@ def fed_run(): # --- push initial global weights --- global_state = server.state_dict() - # --- args object for client.train() --- - # args_train = _make_args_for_client(cfg, args_cli) - # --- history recorder --- history = { "mAP": [], @@ -98,16 +97,16 @@ def fed_run(): } # --- main FL loop --- + best = 0.0 # best mAP num_round = int(cfg.get("num_round", 50)) connection_ratio = float(cfg.get("connection_ratio", 1.0)) # e.g., 1.0 = all clients - res_root = cfg.get("res_root", "results") - os.makedirs(res_root, exist_ok=True) # tqdm logging header = ("%10s" * 2) % ("Round", "client") tqdm.write("\n" + header) p_bar = tqdm(total=num_round, ncols=160, ascii="->>") + # train loop for rnd in range(num_round): t0 = time.time() # Local training (sequential over all users) @@ -115,7 +114,7 @@ def fed_run(): # tqdm desc update p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}")) - client = clients[uid] # FedYoloClient instance + client: FedYoloClient = clients[uid] # FedYoloClient instance client.update(global_state) # load global weights state_dict, n_data, train_loss = client.train(args_cli) # local training server.rec(uid, state_dict, n_data, train_loss) @@ -129,51 +128,82 @@ def fed_run(): # Compute a scalar train loss for plotting (sum of components) scalar_train_loss = avg_loss if avg_loss else 0.0 - # Test (if valset provided) - mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0) + if args_cli.local_rank == 0: + # Test (if valset provided) + mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0) - # Flush per-round client caches - server.flush() + if mAP > best: + best = mAP - # Record & log - history["mAP"].append(mAP) - history["mAP50"].append(mAP50) - history["precision"].append(precision) - history["recall"].append(recall) - history["train_loss"].append(scalar_train_loss) - history["round_time_sec"].append(time.time() - t0) + # Flush per-round client caches + server.flush() - # Log GPU memory usage - # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G" - # tqdm update - desc = { - "loss": f"{scalar_train_loss:.6g}", - "mAP50": f"{mAP50:.6g}", - "mAP": f"{mAP:.6g}", - "precision": f"{precision:.6g}", - "recall": f"{recall:.6g}", - # "gpu_mem": gpu_mem, - } - p_bar.set_postfix(desc) + # Record & log + history["mAP"].append(mAP) + history["mAP50"].append(mAP50) + history["precision"].append(precision) + history["recall"].append(recall) + history["train_loss"].append(scalar_train_loss) + history["round_time_sec"].append(time.time() - t0) - # Save running JSON (resumable logs) - save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s" - out_json = os.path.join(res_root, save_name + ".json") - with open(out_json, "w", encoding="utf-8") as f: - json.dump(history, f, indent=4) + # Log GPU memory usage + # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G" + # tqdm update + desc = { + "loss": f"{scalar_train_loss:.6g}", + "mAP50": f"{mAP50:.6g}", + "mAP": f"{mAP:.6g}", + "precision": f"{precision:.6g}", + "recall": f"{recall:.6g}", + # "gpu_mem": gpu_mem, + } + p_bar.set_postfix(desc) + # Save running JSON (resumable logs) + # save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s" + + # out_json = os.path.join(res_root, save_name + ".json") + # with open(out_json, "w", encoding="utf-8") as f: + # json.dump(history, f, indent=4) + + # Use csv file to save running metrics + row = { + "round": rnd + 1, + "loss": f"{scalar_train_loss:.3f}", + "mAP": f"{mAP:.3f}", + "mAP50": f"{mAP50:.3f}", + "precision": f"{precision:.3f}", + "recall": f"{recall:.3f}", + "sec": f"{time.time() - t0:.1f}", + } + + # log to csv + out_csv = os.path.join(res_root, "step.csv") + fieldnames = ["round", "loss", "mAP", "mAP50", "precision", "recall", "sec"] + mode = "w" if rnd == 0 else "a" + with open(file=out_csv, mode=mode, newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if rnd == 0: + writer.writeheader() # write header only once + writer.writerow(row) + + # Save final global model weights + # FIXME: save model not adaptive YOLOv11-pt specific + save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)} + torch.save(save_model, f"{weights_root}/last.pt") + if best == mAP: + torch.save(save_model, f"{weights_root}/best.pt") + del save_model + # print(f"[save] final global model weights: {weights_root}/last.pt") p_bar.update(1) - p_bar.close() - # Save final global model weights - if not os.path.exists("./weights"): - os.makedirs("./weights", exist_ok=True) - torch.save(global_state, f"./weights/{save_name}_final.pth") - print(f"[save] final global model weights: ./weights/{save_name}_final.pth") + if args_cli.local_rank == 0: + util.strip_optimizer(f"{weights_root}/best.pt") + util.strip_optimizer(f"{weights_root}/last.pt") # --- final plot --- - plot_curves(res_root, history, savename=f"{save_name}_curve.png") + plot_curves(res_root, history, savename="train_curve.png") print("[done] training complete.")