diff --git a/fed_run.py b/fed_run.py index 97cb9d1..a333379 100644 --- a/fed_run.py +++ b/fed_run.py @@ -14,6 +14,7 @@ from fed_algo_cs.server_base import FedYoloServer from utils.args import args_parser # args parser from utils.fed_util import divide_trainset # divide_trainset from utils import util +from utils import fed_util from utils.fed_util import prepare_result_dir @@ -189,7 +190,9 @@ def fed_run(): # Save final global model weights # FIXME: save model not adaptive YOLOv11-pt specific - save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)} + global_model = fed_util.init_model(model_name, num_classes=len(cfg["names"])) + global_model.load_state_dict(global_state) + save_model = {"config": cfg, "model": copy.deepcopy(global_model if global_model else None)} torch.save(save_model, f"{weights_root}/last.pt") if best == mAP: torch.save(save_model, f"{weights_root}/best.pt")