优化训练过程,添加 CSV 日志记录,改进模型权重保存机制
This commit is contained in:
124
fed_run.py
124
fed_run.py
@@ -5,12 +5,16 @@ import yaml
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import csv
|
||||
import copy
|
||||
|
||||
from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
|
||||
from fed_algo_cs.client_base import FedYoloClient
|
||||
from fed_algo_cs.server_base import FedYoloServer
|
||||
from utils.args import args_parser # args parser
|
||||
from utils.fed_util import divide_trainset # divide_trainset
|
||||
from utils import util
|
||||
from utils.fed_util import prepare_result_dir
|
||||
|
||||
|
||||
def fed_run():
|
||||
@@ -26,11 +30,6 @@ def fed_run():
|
||||
with open(args_cli.config, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
# --- params / config normalization ---
|
||||
# For convenience we pass the same `params` dict used by Dataset/model/loss.
|
||||
# Here we re-use the top-level cfg directly as params.
|
||||
# params = dict(cfg)
|
||||
|
||||
if "names" in cfg and isinstance(cfg["names"], dict):
|
||||
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
|
||||
# but we can leave dict; your utils appear to accept dict
|
||||
@@ -39,6 +38,9 @@ def fed_run():
|
||||
# seeds
|
||||
seed_everything(int(cfg.get("i_seed", 0)))
|
||||
|
||||
# result directory
|
||||
res_root, weights_root = prepare_result_dir(base_root=cfg.get("res_root", "results"))
|
||||
|
||||
# --- split clients' train data from a global train list ---
|
||||
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
|
||||
train_txt = cfg.get("train_txt", "")
|
||||
@@ -67,7 +69,7 @@ def fed_run():
|
||||
|
||||
# --- build clients ---
|
||||
model_name = cfg.get("model_name", "yolo_v11_n")
|
||||
clients = {}
|
||||
clients: dict[str, FedYoloClient] = {}
|
||||
|
||||
for uid in users:
|
||||
c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
|
||||
@@ -84,9 +86,6 @@ def fed_run():
|
||||
# --- push initial global weights ---
|
||||
global_state = server.state_dict()
|
||||
|
||||
# --- args object for client.train() ---
|
||||
# args_train = _make_args_for_client(cfg, args_cli)
|
||||
|
||||
# --- history recorder ---
|
||||
history = {
|
||||
"mAP": [],
|
||||
@@ -98,16 +97,16 @@ def fed_run():
|
||||
}
|
||||
|
||||
# --- main FL loop ---
|
||||
best = 0.0 # best mAP
|
||||
num_round = int(cfg.get("num_round", 50))
|
||||
connection_ratio = float(cfg.get("connection_ratio", 1.0)) # e.g., 1.0 = all clients
|
||||
res_root = cfg.get("res_root", "results")
|
||||
os.makedirs(res_root, exist_ok=True)
|
||||
|
||||
# tqdm logging
|
||||
header = ("%10s" * 2) % ("Round", "client")
|
||||
tqdm.write("\n" + header)
|
||||
p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
|
||||
|
||||
# train loop
|
||||
for rnd in range(num_round):
|
||||
t0 = time.time()
|
||||
# Local training (sequential over all users)
|
||||
@@ -115,7 +114,7 @@ def fed_run():
|
||||
# tqdm desc update
|
||||
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
||||
|
||||
client = clients[uid] # FedYoloClient instance
|
||||
client: FedYoloClient = clients[uid] # FedYoloClient instance
|
||||
client.update(global_state) # load global weights
|
||||
state_dict, n_data, train_loss = client.train(args_cli) # local training
|
||||
server.rec(uid, state_dict, n_data, train_loss)
|
||||
@@ -129,51 +128,82 @@ def fed_run():
|
||||
# Compute a scalar train loss for plotting (sum of components)
|
||||
scalar_train_loss = avg_loss if avg_loss else 0.0
|
||||
|
||||
# Test (if valset provided)
|
||||
mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
|
||||
if args_cli.local_rank == 0:
|
||||
# Test (if valset provided)
|
||||
mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
|
||||
|
||||
# Flush per-round client caches
|
||||
server.flush()
|
||||
if mAP > best:
|
||||
best = mAP
|
||||
|
||||
# Record & log
|
||||
history["mAP"].append(mAP)
|
||||
history["mAP50"].append(mAP50)
|
||||
history["precision"].append(precision)
|
||||
history["recall"].append(recall)
|
||||
history["train_loss"].append(scalar_train_loss)
|
||||
history["round_time_sec"].append(time.time() - t0)
|
||||
# Flush per-round client caches
|
||||
server.flush()
|
||||
|
||||
# Log GPU memory usage
|
||||
# gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||
# tqdm update
|
||||
desc = {
|
||||
"loss": f"{scalar_train_loss:.6g}",
|
||||
"mAP50": f"{mAP50:.6g}",
|
||||
"mAP": f"{mAP:.6g}",
|
||||
"precision": f"{precision:.6g}",
|
||||
"recall": f"{recall:.6g}",
|
||||
# "gpu_mem": gpu_mem,
|
||||
}
|
||||
p_bar.set_postfix(desc)
|
||||
# Record & log
|
||||
history["mAP"].append(mAP)
|
||||
history["mAP50"].append(mAP50)
|
||||
history["precision"].append(precision)
|
||||
history["recall"].append(recall)
|
||||
history["train_loss"].append(scalar_train_loss)
|
||||
history["round_time_sec"].append(time.time() - t0)
|
||||
|
||||
# Save running JSON (resumable logs)
|
||||
save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
|
||||
out_json = os.path.join(res_root, save_name + ".json")
|
||||
with open(out_json, "w", encoding="utf-8") as f:
|
||||
json.dump(history, f, indent=4)
|
||||
# Log GPU memory usage
|
||||
# gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||
# tqdm update
|
||||
desc = {
|
||||
"loss": f"{scalar_train_loss:.6g}",
|
||||
"mAP50": f"{mAP50:.6g}",
|
||||
"mAP": f"{mAP:.6g}",
|
||||
"precision": f"{precision:.6g}",
|
||||
"recall": f"{recall:.6g}",
|
||||
# "gpu_mem": gpu_mem,
|
||||
}
|
||||
p_bar.set_postfix(desc)
|
||||
|
||||
# Save running JSON (resumable logs)
|
||||
# save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
|
||||
|
||||
# out_json = os.path.join(res_root, save_name + ".json")
|
||||
# with open(out_json, "w", encoding="utf-8") as f:
|
||||
# json.dump(history, f, indent=4)
|
||||
|
||||
# Use csv file to save running metrics
|
||||
row = {
|
||||
"round": rnd + 1,
|
||||
"loss": f"{scalar_train_loss:.3f}",
|
||||
"mAP": f"{mAP:.3f}",
|
||||
"mAP50": f"{mAP50:.3f}",
|
||||
"precision": f"{precision:.3f}",
|
||||
"recall": f"{recall:.3f}",
|
||||
"sec": f"{time.time() - t0:.1f}",
|
||||
}
|
||||
|
||||
# log to csv
|
||||
out_csv = os.path.join(res_root, "step.csv")
|
||||
fieldnames = ["round", "loss", "mAP", "mAP50", "precision", "recall", "sec"]
|
||||
mode = "w" if rnd == 0 else "a"
|
||||
with open(file=out_csv, mode=mode, newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
if rnd == 0:
|
||||
writer.writeheader() # write header only once
|
||||
writer.writerow(row)
|
||||
|
||||
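# Save a checkpoint of the current global state as last.pt every round; also write best.pt when this round's mAP equals the best so far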
|
||||
# FIXME: checkpoint saving here is not yet adapted to the YOLOv11-pt-specific format
|
||||
save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)}
|
||||
torch.save(save_model, f"{weights_root}/last.pt")
|
||||
if best == mAP:
|
||||
torch.save(save_model, f"{weights_root}/best.pt")
|
||||
del save_model
|
||||
# print(f"[save] final global model weights: {weights_root}/last.pt")
|
||||
p_bar.update(1)
|
||||
|
||||
p_bar.close()
|
||||
|
||||
# Save final global model weights
|
||||
if not os.path.exists("./weights"):
|
||||
os.makedirs("./weights", exist_ok=True)
|
||||
torch.save(global_state, f"./weights/{save_name}_final.pth")
|
||||
print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
|
||||
if args_cli.local_rank == 0:
|
||||
util.strip_optimizer(f"{weights_root}/best.pt")
|
||||
util.strip_optimizer(f"{weights_root}/last.pt")
|
||||
|
||||
# --- final plot ---
|
||||
plot_curves(res_root, history, savename=f"{save_name}_curve.png")
|
||||
plot_curves(res_root, history, savename="train_curve.png")
|
||||
print("[done] training complete.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user