2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils.fed_util import init_model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from fed_algo_cs.server_base import test
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import os
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import yaml
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils.args import args_parser  # args parser
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from fed_algo_cs.client_base import FedYoloClient  # FedYoloClient
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from fed_algo_cs.server_base import FedYoloServer  # FedYoloServer
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils import Dataset  # Dataset
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								if __name__ == "__main__":
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    if not os.path.exists("model.txt"):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # model structure test
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        model = init_model("yolo_v11_n", num_classes=1)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        with open("model.txt", "w", encoding="utf-8") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            print(model, file=f)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    if not os.path.exists("model_key_value.txt"):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # loop over model key and values
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        with open("model_key_value.txt", "w", encoding="utf-8") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            for k, v in model.state_dict().items():
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                print(f"{k}: {v.shape}", file=f)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # test agg function
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # from fed_algo_cs.server_base import FedYoloServer
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # import torch
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # import yaml
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    #     cfg = yaml.safe_load(f)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # # params = dict(cfg)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    # server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=cfg)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # server.rec("c1", state1, n_data=20, loss=0.1)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # server.rec("c2", state2, n_data=30, loss=0.2)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # server.rec("c3", state3, n_data=50, loss=0.3)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # server.select_clients(connection_ratio=1.0)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # model_state, avg_loss, n_data = server.agg()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # with open("agg_model.txt", "w", encoding="utf-8") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    #     for k, v in model_state.items():
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    #         print(f"{k}: {v.float().mean()}", file=f)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # print(f"avg_loss: {avg_loss}, n_data: {n_data}")
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # test single client training (should be the same as standalone training)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    args = args_parser()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    with open(args.config, "r", encoding="utf-8") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        cfg = yaml.safe_load(f)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    # params = dict(cfg)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    filenames = []
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:07:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    data_dir = "/mnt/DATA/COCO128"
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:31:08 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    with open(f"{data_dir}/train.txt") as f:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        for filename in f.readlines():
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            filename = os.path.basename(filename.rstrip())
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            filenames.append(f"{data_dir}/images/train2017/" + filename)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    client.load_trainset(train_dataset=filenames)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    model_state, n_data, avg_loss = client.train(args=args)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    model = init_model("yolo_v11_n", num_classes=80)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    model.load_state_dict(model_state)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    valset = Dataset(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        filenames=filenames,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        input_size=640,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        params=cfg,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        augment=False,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    if valset is not None:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        precision, recall, map50, map = test(valset=valset, params=cfg, model=model, batch_size=128)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        print(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            f"precision: {precision}, recall: {recall}, map50: {map50}, map: {map}, loss: {avg_loss}, n_data: {n_data}"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    else:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        raise ValueError("valset is None, please provide a valid valset in config file.")
							 |