修改聚合逻辑,每次聚合不再创建新模型

This commit is contained in:
2025-10-02 22:37:22 +08:00
parent c2e538898c
commit 0b52cfc4f5

View File

@@ -121,40 +121,53 @@ class FedYoloServer(object):
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)} return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
def select_clients(self, connection_ratio=1.0): def select_clients(self, connection_ratio=1.0):
"""Randomly select a fraction of clients.""" """
Randomly select a fraction of clients.
Args:
connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
"""
self.selected_clients = [] self.selected_clients = []
self.n_data = 0 self.n_data = 0
for client_id in self.client_list: for client_id in self.client_list:
# Random selection based on connection ratio
if np.random.rand() <= connection_ratio: if np.random.rand() <= connection_ratio:
self.selected_clients.append(client_id) self.selected_clients.append(client_id)
self.n_data += self.client_n_data.get(client_id, 0) self.n_data += self.client_n_data.get(client_id, 0)
def agg(self): def agg(self):
"""Aggregate client updates (FedAvg).""" """
Aggregate client updates (FedAvg).
Returns:
global_state: aggregated model state dictionary
avg_loss: dict of averaged losses
n_data: total number of data classes samples used in this round
"""
if len(self.selected_clients) == 0 or self.n_data == 0: if len(self.selected_clients) == 0 or self.n_data == 0:
return self.model.state_dict(), {}, 0 return self.model.state_dict(), {}, 0
model = init_model(self.model_name, self._num_classes) # start from current global model
model_state = model.state_dict() global_state = self.model.state_dict()
# zero buffer for accumulation
new_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()}
avg_loss = {} avg_loss = {}
for i, name in enumerate(self.selected_clients): for name in self.selected_clients:
if name not in self.client_state: if name not in self.client_state:
continue continue
weight = self.client_n_data[name] / self.n_data weight = self.client_n_data[name] / self.n_data
for key in model_state.keys(): for k in new_state.keys():
if i == 0: # accumulate in float32 to avoid fp16 issues
model_state[key] = self.client_state[name][key] * weight new_state[k] += self.client_state[name][k].to(torch.float32) * weight
else:
model_state[key] += self.client_state[name][key] * weight
# Weighted average losses # losses
for k, v in self.client_loss[name].items(): for k, v in self.client_loss[name].items():
avg_loss[k] = avg_loss.get(k, 0.0) + v * weight avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
self.model.load_state_dict(model_state, strict=True) # load aggregated params back into global model
self.model.load_state_dict(new_state, strict=True)
self.round += 1 self.round += 1
return model_state, avg_loss, self.n_data return self.model.state_dict(), avg_loss, self.n_data
def rec(self, name, state_dict, n_data, loss_dict): def rec(self, name, state_dict, n_data, loss_dict):
""" """