优化fed_run函数中的进度条显示和训练过程中的日志记录
This commit is contained in:
@@ -3,6 +3,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
from torch.amp.autocast_mode import autocast
|
from torch.amp.autocast_mode import autocast
|
||||||
|
from tqdm import tqdm
|
||||||
from utils.fed_util import init_model
|
from utils.fed_util import init_model
|
||||||
from utils import util
|
from utils import util
|
||||||
from utils.dataset import Dataset
|
from utils.dataset import Dataset
|
||||||
@@ -152,7 +153,6 @@ class FedYoloClient(object):
|
|||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_steps = max(1, len(loader))
|
num_steps = max(1, len(loader))
|
||||||
# print(len(loader))
|
|
||||||
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
||||||
# DDP mode
|
# DDP mode
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
@@ -167,7 +167,12 @@ class FedYoloClient(object):
|
|||||||
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
||||||
criterion = util.ComputeLoss(self.model, self.params)
|
criterion = util.ComputeLoss(self.model, self.params)
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
# log
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
||||||
|
# print("\n" + header)
|
||||||
|
# p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
|
||||||
|
# p_bar.set_description(f"{self.name:>10}")
|
||||||
|
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
@@ -180,10 +185,20 @@ class FedYoloClient(object):
|
|||||||
ds = cast(Dataset, loader.dataset)
|
ds = cast(Dataset, loader.dataset)
|
||||||
ds.mosaic = False
|
ds.mosaic = False
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
avg_box_loss = util.AverageMeter()
|
avg_box_loss = util.AverageMeter()
|
||||||
avg_cls_loss = util.AverageMeter()
|
avg_cls_loss = util.AverageMeter()
|
||||||
avg_dfl_loss = util.AverageMeter()
|
avg_dfl_loss = util.AverageMeter()
|
||||||
|
|
||||||
|
# # --- header (once per epoch, YOLO-style) ---
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
||||||
|
# print("\n" + header)
|
||||||
|
|
||||||
|
# p_bar = enumerate(loader)
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# p_bar = tqdm(p_bar, total=num_steps, ncols=120)
|
||||||
|
|
||||||
for i, (samples, targets) in enumerate(loader):
|
for i, (samples, targets) in enumerate(loader):
|
||||||
global_step = i + num_steps * epoch
|
global_step = i + num_steps * epoch
|
||||||
scheduler.step(step=global_step, optimizer=optimizer)
|
scheduler.step(step=global_step, optimizer=optimizer)
|
||||||
@@ -202,9 +217,9 @@ class FedYoloClient(object):
|
|||||||
avg_dfl_loss.update(dfl_loss.item(), bs)
|
avg_dfl_loss.update(dfl_loss.item(), bs)
|
||||||
|
|
||||||
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
||||||
box_loss = box_loss * self._batch_size * args.world_size
|
# box_loss = box_loss * self._batch_size * args.world_size
|
||||||
cls_loss = cls_loss * self._batch_size * args.world_size
|
# cls_loss = cls_loss * self._batch_size * args.world_size
|
||||||
dfl_loss = dfl_loss * self._batch_size * args.world_size
|
# dfl_loss = dfl_loss * self._batch_size * args.world_size
|
||||||
|
|
||||||
total_loss = box_loss + cls_loss + dfl_loss
|
total_loss = box_loss + cls_loss + dfl_loss
|
||||||
|
|
||||||
@@ -213,6 +228,8 @@ class FedYoloClient(object):
|
|||||||
|
|
||||||
# Optimize
|
# Optimize
|
||||||
if (i + 1) % accumulate == 0:
|
if (i + 1) % accumulate == 0:
|
||||||
|
amp_scale.unscale_(optimizer) # unscale gradients
|
||||||
|
util.clip_gradients(model=self.model, max_norm=10.0) # clip gradients
|
||||||
amp_scale.step(optimizer)
|
amp_scale.step(optimizer)
|
||||||
amp_scale.update()
|
amp_scale.update()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
@@ -221,13 +238,28 @@ class FedYoloClient(object):
|
|||||||
|
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# tqdm update
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||||
|
# desc = ("%10s" * 2 + "%10.4g" * 3) % (
|
||||||
|
# self.name,
|
||||||
|
# mem,
|
||||||
|
# avg_box_loss.avg,
|
||||||
|
# avg_cls_loss.avg,
|
||||||
|
# avg_dfl_loss.avg,
|
||||||
|
# )
|
||||||
|
# cast(tqdm, p_bar).set_description(desc)
|
||||||
|
# p_bar.update(1)
|
||||||
|
|
||||||
|
# p_bar.close()
|
||||||
|
|
||||||
# clean
|
# clean
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.model.state_dict(),
|
self.model.state_dict() if not ema else ema.ema.state_dict(),
|
||||||
self.n_data,
|
self.n_data,
|
||||||
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
||||||
)
|
)
|
||||||
|
41
fed_run.py
41
fed_run.py
@@ -13,8 +13,8 @@ import matplotlib.pyplot as plt
|
|||||||
from utils.dataset import Dataset
|
from utils.dataset import Dataset
|
||||||
from fed_algo_cs.client_base import FedYoloClient
|
from fed_algo_cs.client_base import FedYoloClient
|
||||||
from fed_algo_cs.server_base import FedYoloServer
|
from fed_algo_cs.server_base import FedYoloServer
|
||||||
from utils.args import args_parser # your args parser
|
from utils.args import args_parser # args parser
|
||||||
from utils.fed_util import divide_trainset # divide_trainset is yours
|
from utils.fed_util import divide_trainset # divide_trainset
|
||||||
|
|
||||||
|
|
||||||
def _read_list_file(txt_path: str):
|
def _read_list_file(txt_path: str):
|
||||||
@@ -132,7 +132,7 @@ def fed_run():
|
|||||||
num_client=int(cfg.get("num_client", 64)),
|
num_client=int(cfg.get("num_client", 64)),
|
||||||
min_data=int(cfg.get("min_data", 100)),
|
min_data=int(cfg.get("min_data", 100)),
|
||||||
max_data=int(cfg.get("max_data", 100)),
|
max_data=int(cfg.get("max_data", 100)),
|
||||||
mode=str(cfg.get("partition_mode", "disjoint")), # "overlap" or "disjoint"
|
mode=str(cfg.get("partition_mode", "overlap")), # "overlap" or "disjoint"
|
||||||
seed=int(cfg.get("i_seed", 0)),
|
seed=int(cfg.get("i_seed", 0)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ def fed_run():
|
|||||||
model_name = cfg.get("model_name", "yolo_v11_n")
|
model_name = cfg.get("model_name", "yolo_v11_n")
|
||||||
clients = {}
|
clients = {}
|
||||||
|
|
||||||
for uid in tqdm(users, desc="Building clients", leave=True, unit="client"):
|
for uid in users:
|
||||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
||||||
c.load_trainset(user_data[uid]["filename"])
|
c.load_trainset(user_data[uid]["filename"])
|
||||||
clients[uid] = c
|
clients[uid] = c
|
||||||
@@ -177,11 +177,16 @@ def fed_run():
|
|||||||
res_root = cfg.get("res_root", "results")
|
res_root = cfg.get("res_root", "results")
|
||||||
os.makedirs(res_root, exist_ok=True)
|
os.makedirs(res_root, exist_ok=True)
|
||||||
|
|
||||||
for rnd in tqdm(range(num_round), desc="main federal loop round:"):
|
# tqdm logging
|
||||||
t0 = time.time()
|
header = ("%10s" * 2) % ("Round", "client")
|
||||||
|
tqdm.write("\n" + header)
|
||||||
|
p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
|
||||||
|
|
||||||
|
for rnd in range(num_round):
|
||||||
|
t0 = time.time()
|
||||||
# Local training (sequential over all users)
|
# Local training (sequential over all users)
|
||||||
for uid in tqdm(users, desc=f"Round {rnd + 1} local training: ", leave=False):
|
for uid in users:
|
||||||
|
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
||||||
client = clients[uid] # FedYoloClient instance
|
client = clients[uid] # FedYoloClient instance
|
||||||
client.update(global_state) # load global weights
|
client.update(global_state) # load global weights
|
||||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
||||||
@@ -214,12 +219,18 @@ def fed_run():
|
|||||||
history["train_loss"].append(scalar_train_loss)
|
history["train_loss"].append(scalar_train_loss)
|
||||||
history["round_time_sec"].append(time.time() - t0)
|
history["round_time_sec"].append(time.time() - t0)
|
||||||
|
|
||||||
tqdm.write(
|
# Log GPU memory usage
|
||||||
f"[round {rnd + 1:04d}] "
|
# gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||||
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
# tqdm update
|
||||||
f"P={precision:.4f} R={recall:.4f}"
|
desc = {
|
||||||
f"\n"
|
"loss": f"{scalar_train_loss:.6g}",
|
||||||
)
|
"mAP50": f"{mAP50:.6g}",
|
||||||
|
"mAP": f"{mAP:.6g}",
|
||||||
|
"precision": f"{precision:.6g}",
|
||||||
|
"recall": f"{recall:.6g}",
|
||||||
|
# "gpu_mem": gpu_mem,
|
||||||
|
}
|
||||||
|
p_bar.set_postfix(desc)
|
||||||
|
|
||||||
# Save running JSON (resumable logs)
|
# Save running JSON (resumable logs)
|
||||||
save_name = (
|
save_name = (
|
||||||
@@ -232,6 +243,10 @@ def fed_run():
|
|||||||
with open(out_json, "w", encoding="utf-8") as f:
|
with open(out_json, "w", encoding="utf-8") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=2)
|
||||||
|
|
||||||
|
p_bar.update(1)
|
||||||
|
|
||||||
|
p_bar.close()
|
||||||
|
|
||||||
# --- final plot ---
|
# --- final plot ---
|
||||||
_plot_curves(res_root, history)
|
_plot_curves(res_root, history)
|
||||||
print("[done] training complete.")
|
print("[done] training complete.")
|
||||||
|
Reference in New Issue
Block a user