# fed-yolo/fed_algo_cs/server_base.py
import numpy as np
import torch
from torch.utils.data import DataLoader
from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util
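# Server side of the client-server federated YOLO pipeline. One round:
#   select_clients() samples the participants for the round,
#   rec() stores each participant's local update,
#   agg() runs FedAvg weighted by the reported sample counts,
#   test() scores the new global model on the validation set,
#   flush() clears the per-round client state.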
class FedYoloServer(object):
    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO Server.

        Args:
            client_list: list of connected client IDs
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names')
        """
        # Client updates received in the current round
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []
        self._batch_size = params.get("val_batch_size", 4)
        self.client_list = client_list
        self.valset = None

        # Federated bookkeeping
        self.round = 0
        # Total number of training samples reported by the selected clients
        self.n_data = 0

        # Device (first GPU if available, otherwise CPU)
        gpu = 0
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

        # Global model
        self._num_classes = len(params["names"])
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params
    def load_valset(self, valset):
        """Server loads the validation dataset."""
        self.valset = valset

    def state_dict(self):
        """Return global model weights."""
        return self.model.state_dict()
    @torch.no_grad()
    def test(self, args):
        """
        Evaluate the global model on the validation set using YOLO metrics (mAP, precision, recall).

        Returns:
            dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...}
        """
        if self.valset is None:
            return {}
        loader = DataLoader(
            self.valset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
        )
        self.model.to(self._device).eval().half()
        iou_v = torch.linspace(0.5, 0.95, 10).to(self._device)  # IoU thresholds for mAP@0.5:0.95
        n_iou = iou_v.numel()
        metrics = []
        for samples, targets in loader:
            samples = samples.to(self._device).half() / 255.0
            _, _, h, w = samples.shape
            scale = torch.tensor((w, h, w, h)).to(self._device)
            outputs = self.model(samples)
            outputs = util.non_max_suppression(outputs)
            for i, output in enumerate(outputs):
                idx = targets["idx"] == i
                cls = targets["cls"][idx].to(self._device)
                box = targets["box"][idx].to(self._device)
                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device)
                if output.shape[0] == 0:
                    # No detections: record only the ground-truth classes (counted as misses)
                    if cls.shape[0]:
                        metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1)))
                    continue
                if cls.shape[0]:
                    # Build targets as (class, x1, y1, x2, y2) in pixel coordinates
                    cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device)
                    if cls_tensor.dim() == 1:
                        cls_tensor = cls_tensor.unsqueeze(1)
                    box_xy = util.wh2xy(box)
                    if not isinstance(box_xy, torch.Tensor):
                        box_xy = torch.tensor(box_xy, device=self._device)
                    target = torch.cat((cls_tensor, box_xy * scale), dim=1)
                    metric = util.compute_metric(output[:, :6], target, iou_v)
                # (tp matrix, confidence, predicted class, target class)
                metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

        # Compute metrics
        if not metrics:
            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
        if len(metrics) and metrics[0].any():
            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
        else:
            prec, rec, map50, mean_ap = 0, 0, 0, 0
        # Back to float32 for further training
        self.model.float()
        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
    def select_clients(self, connection_ratio=1.0):
        """Randomly select a fraction of clients."""
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            if np.random.rand() <= connection_ratio:
                self.selected_clients.append(client_id)
                self.n_data += self.client_n_data.get(client_id, 0)
    def agg(self):
        """Aggregate the received client updates with FedAvg, weighted by local sample counts."""
        if len(self.selected_clients) == 0 or self.n_data == 0:
            return self.model.state_dict(), {}, 0
        model = init_model(self.model_name, self._num_classes)
        model_state = model.state_dict()
        avg_loss = {}
        first = True
        for name in self.selected_clients:
            # Skip selected clients that never reported an update this round
            if name not in self.client_state:
                continue
            weight = self.client_n_data[name] / self.n_data
            for key in model_state.keys():
                if first:
                    model_state[key] = self.client_state[name][key] * weight
                else:
                    model_state[key] += self.client_state[name][key] * weight
            # Weighted average of the reported training losses
            for k, v in self.client_loss[name].items():
                avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
            first = False
        self.model.load_state_dict(model_state, strict=True)
        self.round += 1
        return model_state, avg_loss, self.n_data
    def rec(self, name, state_dict, n_data, loss_dict):
        """
        Receive a local update from a client.

        Args:
            name: client ID
            state_dict: state dictionary of the local model
            n_data: number of data samples used in local training
            loss_dict: dict of losses from local training
        """
        self.n_data += n_data
        self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()}
        self.client_n_data[name] = n_data
        self.client_loss[name] = loss_dict
    def flush(self):
        """Clear stored client updates."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()
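# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training pipeline): it wires
# one FedAvg round using only this module. The model name "yolo_v8_n", the
# class names in `params`, and the stand-in client updates are assumptions;
# a real deployment would call into the client-side trainer of fed_algo_cs
# and pass genuine locally trained weights to rec().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    params = {"names": {0: "object"}, "val_batch_size": 4}  # assumed minimal hyperparameters
    server = FedYoloServer(client_list=["client_0", "client_1"],
                           model_name="yolo_v8_n",  # assumed to be accepted by init_model
                           params=params)

    server.select_clients(connection_ratio=1.0)
    for cid in server.selected_clients:
        # Stand-in for a client's local result: (state_dict, n_data, loss_dict).
        # Here the current global weights are echoed back just to exercise agg().
        server.rec(cid, server.state_dict(), n_data=100, loss_dict={"box": 0.0, "cls": 0.0})

    global_state, avg_loss, n_data = server.agg()
    print(f"round={server.round} n_data={n_data} avg_loss={avg_loss}")
    server.flush()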