优化fed_run函数中的进度条显示和训练过程中的日志记录
This commit is contained in:
41
fed_run.py
41
fed_run.py
@@ -13,8 +13,8 @@ import matplotlib.pyplot as plt
|
||||
from utils.dataset import Dataset
|
||||
from fed_algo_cs.client_base import FedYoloClient
|
||||
from fed_algo_cs.server_base import FedYoloServer
|
||||
from utils.args import args_parser # your args parser
|
||||
from utils.fed_util import divide_trainset # divide_trainset is yours
|
||||
from utils.args import args_parser # args parser
|
||||
from utils.fed_util import divide_trainset # divide_trainset
|
||||
|
||||
|
||||
def _read_list_file(txt_path: str):
|
||||
@@ -132,7 +132,7 @@ def fed_run():
|
||||
num_client=int(cfg.get("num_client", 64)),
|
||||
min_data=int(cfg.get("min_data", 100)),
|
||||
max_data=int(cfg.get("max_data", 100)),
|
||||
mode=str(cfg.get("partition_mode", "disjoint")), # "overlap" or "disjoint"
|
||||
mode=str(cfg.get("partition_mode", "overlap")), # "overlap" or "disjoint"
|
||||
seed=int(cfg.get("i_seed", 0)),
|
||||
)
|
||||
|
||||
@@ -143,7 +143,7 @@ def fed_run():
|
||||
model_name = cfg.get("model_name", "yolo_v11_n")
|
||||
clients = {}
|
||||
|
||||
for uid in tqdm(users, desc="Building clients", leave=True, unit="client"):
|
||||
for uid in users:
|
||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
||||
c.load_trainset(user_data[uid]["filename"])
|
||||
clients[uid] = c
|
||||
@@ -177,11 +177,16 @@ def fed_run():
|
||||
res_root = cfg.get("res_root", "results")
|
||||
os.makedirs(res_root, exist_ok=True)
|
||||
|
||||
for rnd in tqdm(range(num_round), desc="main federal loop round:"):
|
||||
t0 = time.time()
|
||||
# tqdm logging
|
||||
header = ("%10s" * 2) % ("Round", "client")
|
||||
tqdm.write("\n" + header)
|
||||
p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
|
||||
|
||||
for rnd in range(num_round):
|
||||
t0 = time.time()
|
||||
# Local training (sequential over all users)
|
||||
for uid in tqdm(users, desc=f"Round {rnd + 1} local training: ", leave=False):
|
||||
for uid in users:
|
||||
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
||||
client = clients[uid] # FedYoloClient instance
|
||||
client.update(global_state) # load global weights
|
||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
||||
@@ -214,12 +219,18 @@ def fed_run():
|
||||
history["train_loss"].append(scalar_train_loss)
|
||||
history["round_time_sec"].append(time.time() - t0)
|
||||
|
||||
tqdm.write(
|
||||
f"[round {rnd + 1:04d}] "
|
||||
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
||||
f"P={precision:.4f} R={recall:.4f}"
|
||||
f"\n"
|
||||
)
|
||||
# Log GPU memory usage
|
||||
# gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||
# tqdm update
|
||||
desc = {
|
||||
"loss": f"{scalar_train_loss:.6g}",
|
||||
"mAP50": f"{mAP50:.6g}",
|
||||
"mAP": f"{mAP:.6g}",
|
||||
"precision": f"{precision:.6g}",
|
||||
"recall": f"{recall:.6g}",
|
||||
# "gpu_mem": gpu_mem,
|
||||
}
|
||||
p_bar.set_postfix(desc)
|
||||
|
||||
# Save running JSON (resumable logs)
|
||||
save_name = (
|
||||
@@ -232,6 +243,10 @@ def fed_run():
|
||||
with open(out_json, "w", encoding="utf-8") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
|
||||
p_bar.update(1)
|
||||
|
||||
p_bar.close()
|
||||
|
||||
# --- final plot ---
|
||||
_plot_curves(res_root, history)
|
||||
print("[done] training complete.")
|
||||
|
Reference in New Issue
Block a user