From 194ca8ee31fadbe87d3ffe0edb88aa134e7499e6 Mon Sep 17 00:00:00 2001 From: Yunhao Meng Date: Fri, 31 Oct 2025 13:14:21 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=96=87=E6=A1=A3=E5=AD=97?= =?UTF-8?q?=E7=AC=A6=E4=B8=B2=EF=BC=8C=E6=98=8E=E7=A1=AE=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fed_algo_cs/client_base.py | 15 ++++++++++++--- fed_algo_cs/server_base.py | 23 ++++++++++++++++------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/fed_algo_cs/client_base.py b/fed_algo_cs/client_base.py index 85ea227..f89907a 100644 --- a/fed_algo_cs/client_base.py +++ b/fed_algo_cs/client_base.py @@ -64,7 +64,7 @@ class FedYoloClient(object): """ Load the local training dataset Args: - :param train_dataset: Training dataset + train_dataset: Training dataset """ self.train_dataset = train_dataset self.n_data = len(self.train_dataset) @@ -72,8 +72,9 @@ class FedYoloClient(object): def update(self, Global_model_state_dict): """ Update the local model with the global model parameters + Args: - :param Global_model_state_dict: State dictionary of the global model + Global_model_state_dict: State dictionary of the global model """ if not hasattr(self, "model") or self.model is None: @@ -85,7 +86,15 @@ class FedYoloClient(object): def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]: """ Train the local model. - Returns: (state_dict, n_data, avg_loss_per_image) + + Args: + args: training arguments including + + Returns: + (state_dict, n_data, avg_loss_per_image): A tuple including: + - state_dict: State dictionary of the trained local model + - n_data: Number of training data samples + - avg_loss_per_image: Average training loss per image over all epochs """ # ---- Dist init (if any) ---- diff --git a/fed_algo_cs/server_base.py b/fed_algo_cs/server_base.py index ad213a1..73ce8be 100644 --- a/fed_algo_cs/server_base.py +++ b/fed_algo_cs/server_base.py @@ -11,13 +11,13 @@ class FedYoloServer(object): def __init__(self, client_list, model_name, params): """ Federated YOLO Server - Args: + Attributes: client_list: list of connected clients model_name: YOLO model architecture name params: dict of hyperparameters (must include 'names') """ # Track client updates - self.client_state = {} + self.client_state: dict[str, dict[str, torch.Tensor]] = {} self.client_loss = {} self.client_n_data = {} self.selected_clients = [] @@ -64,14 +64,19 @@ class FedYoloServer(object): self.selected_clients.append(client_id) self.n_data += self.client_n_data[client_id] + # TODO: skip the layer which can not be learnted locally @torch.no_grad() - def agg(self): + def agg(self, skip_bn_layer: bool = False): """ Server aggregates the local updates from selected clients using FedAvg. - :return: model_state: aggregated model weights - :return: avg_loss: weighted average training loss across selected clients - :return: n_data: total number of data points across selected clients + Args: + skip_bn_layer: whether to skip batch normalization layers during aggregation + + Returns: + :model_state: aggregated model weights + :avg_loss: weighted average training loss across selected clients + :n_data: total number of data points across selected clients """ if len(self.selected_clients) == 0 or self.n_data == 0: import warnings @@ -144,11 +149,13 @@ class FedYoloServer(object): def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]: """ Evaluate the model on the validation dataset. + Args: valset: validation dataset params: dict of parameters (must include 'names') model: YOLO model to evaluate batch_size: batch size for evaluation + Returns: dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap) """ @@ -214,7 +221,9 @@ def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[f # Compute metrics metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy if len(metrics) and metrics[0].any(): - tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"]) + tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap( + *metrics, plot=False, names=params["names"] + ) # set plot=True to plot metric curve # Print results # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap)) # Return results