From 0b52cfc4f50ffeed56435b5ec33d430967f83d9c Mon Sep 17 00:00:00 2001 From: Yunhao Meng Date: Thu, 2 Oct 2025 22:37:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=81=9A=E5=90=88=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E6=AF=8F=E6=AC=A1=E8=81=9A=E5=90=88=E4=B8=8D?= =?UTF-8?q?=E5=86=8D=E5=88=9B=E5=BB=BA=E6=96=B0=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fed_algo_cs/server_base.py | 39 +++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/fed_algo_cs/server_base.py b/fed_algo_cs/server_base.py index 5dd82b2..5043675 100644 --- a/fed_algo_cs/server_base.py +++ b/fed_algo_cs/server_base.py @@ -121,40 +121,53 @@ class FedYoloServer(object): return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)} def select_clients(self, connection_ratio=1.0): - """Randomly select a fraction of clients.""" + """ + Randomly select a fraction of clients. + Args: + connection_ratio: fraction of clients to select (0 < connection_ratio <= 1) + """ self.selected_clients = [] self.n_data = 0 for client_id in self.client_list: + # Random selection based on connection ratio if np.random.rand() <= connection_ratio: self.selected_clients.append(client_id) self.n_data += self.client_n_data.get(client_id, 0) def agg(self): - """Aggregate client updates (FedAvg).""" + """ + Aggregate client updates (FedAvg). + Returns: + global_state: aggregated model state dictionary + avg_loss: dict of averaged losses + n_data: total number of data classes samples used in this round + """ if len(self.selected_clients) == 0 or self.n_data == 0: return self.model.state_dict(), {}, 0 - model = init_model(self.model_name, self._num_classes) - model_state = model.state_dict() + # start from current global model + global_state = self.model.state_dict() + + # zero buffer for accumulation + new_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()} avg_loss = {} - for i, name in enumerate(self.selected_clients): + for name in self.selected_clients: if name not in self.client_state: continue weight = self.client_n_data[name] / self.n_data - for key in model_state.keys(): - if i == 0: - model_state[key] = self.client_state[name][key] * weight - else: - model_state[key] += self.client_state[name][key] * weight + for k in new_state.keys(): + # accumulate in float32 to avoid fp16 issues + new_state[k] += self.client_state[name][k].to(torch.float32) * weight - # Weighted average losses + # losses for k, v in self.client_loss[name].items(): avg_loss[k] = avg_loss.get(k, 0.0) + v * weight - self.model.load_state_dict(model_state, strict=True) + # load aggregated params back into global model + self.model.load_state_dict(new_state, strict=True) self.round += 1 - return model_state, avg_loss, self.n_data + return self.model.state_dict(), avg_loss, self.n_data def rec(self, name, state_dict, n_data, loss_dict): """