import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util


class FedYoloServer(object):
    """Federated-learning server that owns the global YOLO model.

    Client updates are received with :meth:`rec`, combined with a
    sample-weighted FedAvg in :meth:`agg`, and the global model can be
    evaluated on a server-side validation set with :meth:`test`.
    """

    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO Server

        Args:
            client_list: list of connected clients
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names')
        """
        # Per-client bookkeeping, keyed by the client name passed to rec()
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []

        self._batch_size = params.get("val_batch_size", 4)
        self.client_list = client_list
        self.valset = None

        # Federated bookkeeping
        self.round = 0
        # Running sample count; reset by select_clients()/flush(), increased by rec()
        self.n_data = 0

        # Device
        gpu = 0
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

        # Global model
        self._num_classes = len(params["names"])
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params

    def load_valset(self, valset):
        """Server loads the validation dataset."""
        self.valset = valset

    def state_dict(self):
        """Return global model weights."""
        return self.model.state_dict()

    @torch.no_grad()
    def test(self, args) -> dict:
        """
        Test the global model on the server's validation set.

        Args:
            args: unused; kept for interface compatibility.

        Returns:
            dict with keys: mAP, mAP50, precision, recall
            (an empty dict if no validation set has been loaded).
        """
        if self.valset is None:
            return {}

        dev = self._device
        loader = DataLoader(
            self.valset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self.params.get("val_num_workers", 4),
            # Pinned host memory only speeds up host->GPU copies; on CPU it is useless.
            pin_memory=dev.type == "cuda",
            collate_fn=Dataset.collate_fn,
        )

        # move to device for eval; keep in float32 for stability
        self.model.eval().to(dev).float()
        try:
            metrics = self._collect_metrics(loader, dev)
        finally:
            # Always return the model to CPU/FP32 (even if evaluation raised)
            # so the next agg() stays device-consistent.
            self.model.to("cpu").float()

        if not metrics:
            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}

        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
        if len(metrics) and metrics[0].any():
            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
        else:
            prec, rec, map50, mean_ap = 0, 0, 0, 0

        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}

    def _collect_metrics(self, loader, dev):
        """Run the model over ``loader`` and return per-image metric tuples.

        Each tuple is ``(metric, conf, pred_cls, target_cls)``; the caller
        concatenates them field-wise before passing them to ``util.compute_ap``.
        """
        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
        n_iou = iou_v.numel()
        metrics = []

        for samples, targets in loader:
            samples = samples.to(dev, non_blocking=True).float() / 255.0
            _, _, h, w = samples.shape
            scale = torch.tensor((w, h, w, h), device=dev)

            outputs = util.non_max_suppression(self.model(samples))

            for i, output in enumerate(outputs):
                idx = targets["idx"] == i
                cls = targets["cls"][idx].to(dev)
                box = targets["box"][idx].to(dev)
                # reshape(-1) instead of squeeze(-1): a single 1-D target must stay
                # 1-D, otherwise it becomes a 0-d tensor that torch.cat rejects.
                target_cls = cls.reshape(-1)

                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)

                if output.shape[0] == 0:
                    # No predictions: only images with targets contribute (as misses).
                    if cls.shape[0]:
                        metrics.append((metric, *torch.zeros((2, 0), device=dev), target_cls))
                    continue

                if cls.shape[0]:
                    box_xy = util.wh2xy(box)
                    if not isinstance(box_xy, torch.Tensor):
                        box_xy = torch.tensor(box_xy, device=dev)
                    target = torch.cat((target_cls.unsqueeze(1), box_xy * scale), dim=1)
                    metric = util.compute_metric(output[:, :6], target, iou_v)

                metrics.append((metric, output[:, 4], output[:, 5], target_cls))

        return metrics

    def select_clients(self, connection_ratio=1.0):
        """
        Randomly select a fraction of clients.

        Each client in ``client_list`` is kept independently with probability
        ``connection_ratio``. ``self.n_data`` is reset to the sample count
        currently recorded for the selected clients.

        Args:
            connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
        """
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            # np.random.rand() lies in [0, 1), so "<" selects with probability
            # exactly connection_ratio (1.0 always selects, 0.0 never does).
            if np.random.rand() < connection_ratio:
                self.selected_clients.append(client_id)
                self.n_data += self.client_n_data.get(client_id, 0)

    def agg(self):
        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers.

        Returns:
            (state_dict, avg_loss, n_data): a CPU copy of the aggregated global
            weights, the sample-weighted average of the client losses, and the
            number of samples that contributed to the aggregate. When there is
            nothing to aggregate, the current state_dict, ``{}`` and ``0``.
        """
        if len(self.selected_clients) == 0 or self.n_data == 0:
            return self.model.state_dict(), {}, 0

        # Ensure global model is on CPU for safe load later
        self.model.to("cpu")
        global_state = self.model.state_dict()

        # Only selected clients that actually reported an update take part, and
        # weights are normalised over *their* sample counts so they sum to 1.
        # self.n_data is not used here: it can include selected clients that
        # never reported, or be double-counted when both select_clients() and
        # rec() add to it within one round.
        participants = [cid for cid in self.selected_clients if cid in self.client_state]
        assert participants, "No client states available to aggregate."
        total_n = float(sum(self.client_n_data[cid] for cid in participants))
        if total_n <= 0:
            return self.model.state_dict(), {}, 0

        # Prepare accumulators on CPU. For floating tensors, use float32 zeros.
        # For non-floating tensors (e.g. integer BN counters), copy from the first participant.
        first_client = participants[0]
        new_state = {}
        for k, v in global_state.items():
            if v.is_floating_point():
                new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
            else:
                new_state[k] = self.client_state[first_client][k].clone()

        # Accumulate floating tensors with weights; keep non-floats as assigned above
        avg_loss = {}
        for cid in participants:
            weight = self.client_n_data[cid] / total_n
            cst = self.client_state[cid]
            for k, acc in new_state.items():
                if acc.is_floating_point():
                    # rec() stores client tensors on CPU; cast to float32 for accumulation
                    acc.add_(cst[k].to(torch.float32), alpha=weight)

            # weighted average losses
            for lk, lv in self.client_loss[cid].items():
                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight

        # Load aggregated state back into the global model (model is on CPU)
        with torch.no_grad():
            self.model.load_state_dict(new_state, strict=True)

        self.round += 1

        # Return CPU state_dict (good for broadcasting to clients)
        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(total_n)

    def rec(self, name, state_dict, n_data, loss_dict):
        """
        Receive local update from a client.

        - Store all floating tensors as CPU float32
        - Store non-floating tensors (e.g., BN counters) as CPU in original dtype
        """
        self.n_data += n_data
        safe_state = {}
        with torch.no_grad():
            for k, v in state_dict.items():
                t = v.detach().cpu()
                if t.is_floating_point():
                    t = t.to(torch.float32)
                safe_state[k] = t
        self.client_state[name] = safe_state
        self.client_n_data[name] = int(n_data)
        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}

    def flush(self):
        """Clear stored client updates."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()