Compare commits
8 Commits
b19f11125d
...
main
Author | SHA1 | Date | |
---|---|---|---|
101ffa51eb | |||
beaa290c19 | |||
86c7579b42 | |||
33586e0c0c | |||
9a5e6b5b71 | |||
964a8024c0 | |||
0b52cfc4f5 | |||
c2e538898c |
21
README.md
21
README.md
@@ -1,3 +1,24 @@
|
|||||||
# fed-yolo
|
# fed-yolo
|
||||||
|
|
||||||
Combine Federated Learning with YOLOv11.
|
Combine Federated Learning with YOLOv11.
|
||||||
|
|
||||||
|
## requirements
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## how to run
|
||||||
|
```bash
|
||||||
|
nohup python fed_run.py > train.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
## results
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
- Add more FL algorithms (e.g., FedProx, FedAvgM)
|
||||||
|
- Implement FedProx
|
||||||
|
- Implement SCAFFOLD
|
||||||
|
- Implement FedNova
|
||||||
|
- Add more YOLO versions (e.g., YOLOv8, YOLOv5)
|
||||||
|
- Implement YOLOv8
|
||||||
|
- Implement YOLOv5
|
126
config/coco128_cfg.yaml
Normal file
126
config/coco128_cfg.yaml
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
# global system:
|
||||||
|
fed_algo: "FedAvg" # federated learning algorithm
|
||||||
|
model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
|
||||||
|
i_seed: 202509 # initial random seed
|
||||||
|
|
||||||
|
num_client: 5 # total number of clients
|
||||||
|
num_round: 5 # total number of communication rounds
|
||||||
|
num_local_class: 80 # number of classes per client
|
||||||
|
|
||||||
|
res_root: "results" # root directory for results
|
||||||
|
dataset_path: "/mnt/DATA/COCO128/"
|
||||||
|
# train_txt: "train.txt" # path to training set txt file
|
||||||
|
# val_txt: "val.txt" # path to validation set txt file
|
||||||
|
# test_txt: "test.txt" # path to test set txt file
|
||||||
|
|
||||||
|
local_batch_size: 32 # local training batch size
|
||||||
|
val_batch_size: 4 # validation batch size
|
||||||
|
|
||||||
|
num_workers: 4 # number of data loader workers
|
||||||
|
min_data: 128 # minimum number of images per client
|
||||||
|
max_data: 128 # maximum number of images per client
|
||||||
|
partition_mode: "overlap" # "overlap" or "disjoint"
|
||||||
|
connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
|
||||||
|
|
||||||
|
# local training:
|
||||||
|
min_lr: 0.000100000000 # initial learning rate
|
||||||
|
max_lr: 0.010000000000 # maximum learning rate
|
||||||
|
momentum: 0.9370000000 # SGD momentum/Adam beta1
|
||||||
|
weight_decay: 0.000500 # optimizer weight decay
|
||||||
|
|
||||||
|
warmup_epochs: 3.00000 # warmup epochs
|
||||||
|
box: 7.500000000000000 # box loss gain
|
||||||
|
cls: 0.500000000000000 # cls loss gain
|
||||||
|
dfl: 1.500000000000000 # dfl loss gain
|
||||||
|
hsv_h: 0.0150000000000 # image HSV-Hue augmentation (fraction)
|
||||||
|
hsv_s: 0.7000000000000 # image HSV-Saturation augmentation (fraction)
|
||||||
|
hsv_v: 0.4000000000000 # image HSV-Value augmentation (fraction)
|
||||||
|
degrees: 0.00000000000 # image rotation (+/- deg)
|
||||||
|
translate: 0.100000000 # image translation (+/- fraction)
|
||||||
|
scale: 0.5000000000000 # image scale (+/- gain)
|
||||||
|
shear: 0.0000000000000 # image shear (+/- deg)
|
||||||
|
flip_ud: 0.00000000000 # image flip up-down (probability)
|
||||||
|
flip_lr: 0.50000000000 # image flip left-right (probability)
|
||||||
|
mosaic: 1.000000000000 # image mosaic (probability)
|
||||||
|
mix_up: 0.000000000000 # image mix-up (probability)
|
||||||
|
names:
|
||||||
|
0: person
|
||||||
|
1: bicycle
|
||||||
|
2: car
|
||||||
|
3: motorcycle
|
||||||
|
4: airplane
|
||||||
|
5: bus
|
||||||
|
6: train
|
||||||
|
7: truck
|
||||||
|
8: boat
|
||||||
|
9: traffic light
|
||||||
|
10: fire hydrant
|
||||||
|
11: stop sign
|
||||||
|
12: parking meter
|
||||||
|
13: bench
|
||||||
|
14: bird
|
||||||
|
15: cat
|
||||||
|
16: dog
|
||||||
|
17: horse
|
||||||
|
18: sheep
|
||||||
|
19: cow
|
||||||
|
20: elephant
|
||||||
|
21: bear
|
||||||
|
22: zebra
|
||||||
|
23: giraffe
|
||||||
|
24: backpack
|
||||||
|
25: umbrella
|
||||||
|
26: handbag
|
||||||
|
27: tie
|
||||||
|
28: suitcase
|
||||||
|
29: frisbee
|
||||||
|
30: skis
|
||||||
|
31: snowboard
|
||||||
|
32: sports ball
|
||||||
|
33: kite
|
||||||
|
34: baseball bat
|
||||||
|
35: baseball glove
|
||||||
|
36: skateboard
|
||||||
|
37: surfboard
|
||||||
|
38: tennis racket
|
||||||
|
39: bottle
|
||||||
|
40: wine glass
|
||||||
|
41: cup
|
||||||
|
42: fork
|
||||||
|
43: knife
|
||||||
|
44: spoon
|
||||||
|
45: bowl
|
||||||
|
46: banana
|
||||||
|
47: apple
|
||||||
|
48: sandwich
|
||||||
|
49: orange
|
||||||
|
50: broccoli
|
||||||
|
51: carrot
|
||||||
|
52: hot dog
|
||||||
|
53: pizza
|
||||||
|
54: donut
|
||||||
|
55: cake
|
||||||
|
56: chair
|
||||||
|
57: couch
|
||||||
|
58: potted plant
|
||||||
|
59: bed
|
||||||
|
60: dining table
|
||||||
|
61: toilet
|
||||||
|
62: tv
|
||||||
|
63: laptop
|
||||||
|
64: mouse
|
||||||
|
65: remote
|
||||||
|
66: keyboard
|
||||||
|
67: cell phone
|
||||||
|
68: microwave
|
||||||
|
69: oven
|
||||||
|
70: toaster
|
||||||
|
71: sink
|
||||||
|
72: refrigerator
|
||||||
|
73: book
|
||||||
|
74: clock
|
||||||
|
75: vase
|
||||||
|
76: scissors
|
||||||
|
77: teddy bear
|
||||||
|
78: hair drier
|
||||||
|
79: toothbrush
|
@@ -3,6 +3,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
from torch.amp.autocast_mode import autocast
|
from torch.amp.autocast_mode import autocast
|
||||||
|
from tqdm import tqdm
|
||||||
from utils.fed_util import init_model
|
from utils.fed_util import init_model
|
||||||
from utils import util
|
from utils import util
|
||||||
from utils.dataset import Dataset
|
from utils.dataset import Dataset
|
||||||
@@ -152,7 +153,6 @@ class FedYoloClient(object):
|
|||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_steps = max(1, len(loader))
|
num_steps = max(1, len(loader))
|
||||||
# print(len(loader))
|
|
||||||
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
||||||
# DDP mode
|
# DDP mode
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
@@ -167,7 +167,12 @@ class FedYoloClient(object):
|
|||||||
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
||||||
criterion = util.ComputeLoss(self.model, self.params)
|
criterion = util.ComputeLoss(self.model, self.params)
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
# log
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
||||||
|
# print("\n" + header)
|
||||||
|
# p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
|
||||||
|
# p_bar.set_description(f"{self.name:>10}")
|
||||||
|
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
@@ -180,10 +185,20 @@ class FedYoloClient(object):
|
|||||||
ds = cast(Dataset, loader.dataset)
|
ds = cast(Dataset, loader.dataset)
|
||||||
ds.mosaic = False
|
ds.mosaic = False
|
||||||
|
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
avg_box_loss = util.AverageMeter()
|
avg_box_loss = util.AverageMeter()
|
||||||
avg_cls_loss = util.AverageMeter()
|
avg_cls_loss = util.AverageMeter()
|
||||||
avg_dfl_loss = util.AverageMeter()
|
avg_dfl_loss = util.AverageMeter()
|
||||||
|
|
||||||
|
# # --- header (once per epoch, YOLO-style) ---
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
||||||
|
# print("\n" + header)
|
||||||
|
|
||||||
|
# p_bar = enumerate(loader)
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# p_bar = tqdm(p_bar, total=num_steps, ncols=120)
|
||||||
|
|
||||||
for i, (samples, targets) in enumerate(loader):
|
for i, (samples, targets) in enumerate(loader):
|
||||||
global_step = i + num_steps * epoch
|
global_step = i + num_steps * epoch
|
||||||
scheduler.step(step=global_step, optimizer=optimizer)
|
scheduler.step(step=global_step, optimizer=optimizer)
|
||||||
@@ -195,24 +210,26 @@ class FedYoloClient(object):
|
|||||||
outputs = self.model(samples)
|
outputs = self.model(samples)
|
||||||
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
|
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
|
||||||
|
|
||||||
# meters (use the *unscaled* values)
|
# meters (use the *unscaled* values)
|
||||||
bs = samples.size(0)
|
bs = samples.size(0)
|
||||||
avg_box_loss.update(box_loss.item(), bs)
|
avg_box_loss.update(box_loss.item(), bs)
|
||||||
avg_cls_loss.update(cls_loss.item(), bs)
|
avg_cls_loss.update(cls_loss.item(), bs)
|
||||||
avg_dfl_loss.update(dfl_loss.item(), bs)
|
avg_dfl_loss.update(dfl_loss.item(), bs)
|
||||||
|
|
||||||
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
||||||
box_loss = box_loss * self._batch_size * args.world_size
|
# box_loss = box_loss * self._batch_size * args.world_size
|
||||||
cls_loss = cls_loss * self._batch_size * args.world_size
|
# cls_loss = cls_loss * self._batch_size * args.world_size
|
||||||
dfl_loss = dfl_loss * self._batch_size * args.world_size
|
# dfl_loss = dfl_loss * self._batch_size * args.world_size
|
||||||
|
|
||||||
total_loss = box_loss + cls_loss + dfl_loss
|
total_loss = box_loss + cls_loss + dfl_loss
|
||||||
|
|
||||||
# Backward
|
# Backward
|
||||||
amp_scale.scale(total_loss).backward()
|
amp_scale.scale(total_loss).backward()
|
||||||
|
|
||||||
# Optimize
|
# Optimize
|
||||||
if (i + 1) % accumulate == 0:
|
if (i + 1) % accumulate == 0:
|
||||||
|
amp_scale.unscale_(optimizer) # unscale gradients
|
||||||
|
util.clip_gradients(model=self.model, max_norm=10.0) # clip gradients
|
||||||
amp_scale.step(optimizer)
|
amp_scale.step(optimizer)
|
||||||
amp_scale.update()
|
amp_scale.update()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
@@ -221,13 +238,28 @@ class FedYoloClient(object):
|
|||||||
|
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# tqdm update
|
||||||
|
# if args.local_rank == 0:
|
||||||
|
# mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||||
|
# desc = ("%10s" * 2 + "%10.4g" * 3) % (
|
||||||
|
# self.name,
|
||||||
|
# mem,
|
||||||
|
# avg_box_loss.avg,
|
||||||
|
# avg_cls_loss.avg,
|
||||||
|
# avg_dfl_loss.avg,
|
||||||
|
# )
|
||||||
|
# cast(tqdm, p_bar).set_description(desc)
|
||||||
|
# p_bar.update(1)
|
||||||
|
|
||||||
|
# p_bar.close()
|
||||||
|
|
||||||
# clean
|
# clean
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.model.state_dict(),
|
self.model.state_dict() if not ema else ema.ema.state_dict(),
|
||||||
self.n_data,
|
self.n_data,
|
||||||
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
||||||
)
|
)
|
||||||
|
@@ -49,11 +49,11 @@ class FedYoloServer(object):
|
|||||||
return self.model.state_dict()
|
return self.model.state_dict()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test(self, args):
|
def test(self, args) -> dict:
|
||||||
"""
|
"""
|
||||||
Evaluate global model on validation set using YOLO metrics (mAP, precision, recall).
|
Test the global model on the server's validation set.
|
||||||
Returns:
|
Returns:
|
||||||
dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...}
|
dict with keys: mAP, mAP50, precision, recall
|
||||||
"""
|
"""
|
||||||
if self.valset is None:
|
if self.valset is None:
|
||||||
return {}
|
return {}
|
||||||
@@ -67,46 +67,47 @@ class FedYoloServer(object):
|
|||||||
collate_fn=Dataset.collate_fn,
|
collate_fn=Dataset.collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model.to(self._device).eval().half()
|
dev = self._device
|
||||||
|
# move to device for eval; keep in float32 for stability
|
||||||
|
self.model.eval().to(dev).float()
|
||||||
|
|
||||||
iou_v = torch.linspace(0.5, 0.95, 10).to(self._device) # IoU thresholds
|
iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
|
||||||
n_iou = iou_v.numel()
|
n_iou = iou_v.numel()
|
||||||
metrics = []
|
metrics = []
|
||||||
|
|
||||||
for samples, targets in loader:
|
for samples, targets in loader:
|
||||||
samples = samples.to(self._device).half() / 255.0
|
samples = samples.to(dev, non_blocking=True).float() / 255.0
|
||||||
_, _, h, w = samples.shape
|
_, _, h, w = samples.shape
|
||||||
scale = torch.tensor((w, h, w, h)).to(self._device)
|
scale = torch.tensor((w, h, w, h), device=dev)
|
||||||
|
|
||||||
outputs = self.model(samples)
|
outputs = self.model(samples)
|
||||||
outputs = util.non_max_suppression(outputs)
|
outputs = util.non_max_suppression(outputs)
|
||||||
|
|
||||||
for i, output in enumerate(outputs):
|
for i, output in enumerate(outputs):
|
||||||
idx = targets["idx"] == i
|
idx = targets["idx"] == i
|
||||||
cls = targets["cls"][idx].to(self._device)
|
cls = targets["cls"][idx].to(dev)
|
||||||
box = targets["box"][idx].to(self._device)
|
box = targets["box"][idx].to(dev)
|
||||||
|
|
||||||
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device)
|
|
||||||
|
|
||||||
|
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
|
||||||
if output.shape[0] == 0:
|
if output.shape[0] == 0:
|
||||||
if cls.shape[0]:
|
if cls.shape[0]:
|
||||||
metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1)))
|
metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if cls.shape[0]:
|
if cls.shape[0]:
|
||||||
cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device)
|
if cls.dim() == 1:
|
||||||
if cls_tensor.dim() == 1:
|
cls = cls.unsqueeze(1)
|
||||||
cls_tensor = cls_tensor.unsqueeze(1)
|
|
||||||
box_xy = util.wh2xy(box)
|
box_xy = util.wh2xy(box)
|
||||||
if not isinstance(box_xy, torch.Tensor):
|
if not isinstance(box_xy, torch.Tensor):
|
||||||
box_xy = torch.tensor(box_xy, device=self._device)
|
box_xy = torch.tensor(box_xy, device=dev)
|
||||||
target = torch.cat((cls_tensor, box_xy * scale), dim=1)
|
target = torch.cat((cls, box_xy * scale), dim=1)
|
||||||
metric = util.compute_metric(output[:, :6], target, iou_v)
|
metric = util.compute_metric(output[:, :6], target, iou_v)
|
||||||
|
|
||||||
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
|
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
|
||||||
|
|
||||||
# Compute metrics
|
|
||||||
if not metrics:
|
if not metrics:
|
||||||
|
# move back to CPU before returning
|
||||||
|
self.model.to("cpu").float()
|
||||||
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
|
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
|
||||||
|
|
||||||
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
|
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
|
||||||
@@ -115,60 +116,94 @@ class FedYoloServer(object):
|
|||||||
else:
|
else:
|
||||||
prec, rec, map50, mean_ap = 0, 0, 0, 0
|
prec, rec, map50, mean_ap = 0, 0, 0, 0
|
||||||
|
|
||||||
# Back to float32 for further training
|
# return model to CPU so next agg() stays device-consistent
|
||||||
self.model.float()
|
self.model.to("cpu").float()
|
||||||
|
|
||||||
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
|
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
|
||||||
|
|
||||||
def select_clients(self, connection_ratio=1.0):
|
def select_clients(self, connection_ratio=1.0):
|
||||||
"""Randomly select a fraction of clients."""
|
"""
|
||||||
|
Randomly select a fraction of clients.
|
||||||
|
Args:
|
||||||
|
connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
|
||||||
|
"""
|
||||||
self.selected_clients = []
|
self.selected_clients = []
|
||||||
self.n_data = 0
|
self.n_data = 0
|
||||||
for client_id in self.client_list:
|
for client_id in self.client_list:
|
||||||
|
# Random selection based on connection ratio
|
||||||
if np.random.rand() <= connection_ratio:
|
if np.random.rand() <= connection_ratio:
|
||||||
self.selected_clients.append(client_id)
|
self.selected_clients.append(client_id)
|
||||||
self.n_data += self.client_n_data.get(client_id, 0)
|
self.n_data += self.client_n_data.get(client_id, 0)
|
||||||
|
|
||||||
def agg(self):
|
def agg(self):
|
||||||
"""Aggregate client updates (FedAvg)."""
|
"""Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
|
||||||
if len(self.selected_clients) == 0 or self.n_data == 0:
|
if len(self.selected_clients) == 0 or self.n_data == 0:
|
||||||
return self.model.state_dict(), {}, 0
|
return self.model.state_dict(), {}, 0
|
||||||
|
|
||||||
model = init_model(self.model_name, self._num_classes)
|
# Ensure global model is on CPU for safe load later
|
||||||
model_state = model.state_dict()
|
self.model.to("cpu")
|
||||||
|
global_state = self.model.state_dict() # may hold CPU or CUDA refs; we’re on CPU now
|
||||||
|
|
||||||
avg_loss = {}
|
avg_loss = {}
|
||||||
for i, name in enumerate(self.selected_clients):
|
total_n = float(self.n_data)
|
||||||
if name not in self.client_state:
|
|
||||||
|
# Prepare accumulators on CPU. For floating tensors, use float32 zeros.
|
||||||
|
# For non-floating tensors (e.g., BN num_batches_tracked int64), we’ll copy from the first client.
|
||||||
|
new_state = {}
|
||||||
|
first_client = None
|
||||||
|
for cid in self.selected_clients:
|
||||||
|
if cid in self.client_state:
|
||||||
|
first_client = cid
|
||||||
|
break
|
||||||
|
|
||||||
|
assert first_client is not None, "No client states available to aggregate."
|
||||||
|
|
||||||
|
for k, v in global_state.items():
|
||||||
|
if v.is_floating_point():
|
||||||
|
new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
# For non-float buffers, just copy from the first client (or keep global)
|
||||||
|
new_state[k] = self.client_state[first_client][k].clone()
|
||||||
|
|
||||||
|
# Accumulate floating tensors with weights; keep non-floats as assigned above
|
||||||
|
for cid in self.selected_clients:
|
||||||
|
if cid not in self.client_state:
|
||||||
continue
|
continue
|
||||||
weight = self.client_n_data[name] / self.n_data
|
weight = self.client_n_data[cid] / total_n
|
||||||
for key in model_state.keys():
|
cst = self.client_state[cid]
|
||||||
if i == 0:
|
for k in new_state.keys():
|
||||||
model_state[key] = self.client_state[name][key] * weight
|
if new_state[k].is_floating_point():
|
||||||
else:
|
# cst[k] is CPU; ensure float32 for accumulation
|
||||||
model_state[key] += self.client_state[name][key] * weight
|
new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
|
||||||
|
|
||||||
# Weighted average losses
|
# weighted average losses
|
||||||
for k, v in self.client_loss[name].items():
|
for lk, lv in self.client_loss[cid].items():
|
||||||
avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
|
avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
|
||||||
|
|
||||||
|
# Load aggregated state back into the global model (model is on CPU)
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model.load_state_dict(new_state, strict=True)
|
||||||
|
|
||||||
self.model.load_state_dict(model_state, strict=True)
|
|
||||||
self.round += 1
|
self.round += 1
|
||||||
return model_state, avg_loss, self.n_data
|
# Return CPU state_dict (good for broadcasting to clients)
|
||||||
|
return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
|
||||||
|
|
||||||
def rec(self, name, state_dict, n_data, loss_dict):
|
def rec(self, name, state_dict, n_data, loss_dict):
|
||||||
"""
|
"""
|
||||||
Receive local update from a client.
|
Receive local update from a client.
|
||||||
Args:
|
- Store all floating tensors as CPU float32
|
||||||
name: client ID
|
- Store non-floating tensors (e.g., BN counters) as CPU in original dtype
|
||||||
state_dict: state dictionary of the local model
|
|
||||||
n_data: number of data samples used in local training
|
|
||||||
loss_dict: dict of losses from local training
|
|
||||||
"""
|
"""
|
||||||
self.n_data += n_data
|
self.n_data += n_data
|
||||||
self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()}
|
safe_state = {}
|
||||||
self.client_n_data[name] = n_data
|
with torch.no_grad():
|
||||||
self.client_loss[name] = loss_dict
|
for k, v in state_dict.items():
|
||||||
|
t = v.detach().cpu()
|
||||||
|
if t.is_floating_point():
|
||||||
|
t = t.to(torch.float32)
|
||||||
|
safe_state[k] = t
|
||||||
|
self.client_state[name] = safe_state
|
||||||
|
self.client_n_data[name] = int(n_data)
|
||||||
|
self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
"""Clear stored client updates."""
|
"""Clear stored client updates."""
|
||||||
|
39
fed_run.py
39
fed_run.py
@@ -13,8 +13,8 @@ import matplotlib.pyplot as plt
|
|||||||
from utils.dataset import Dataset
|
from utils.dataset import Dataset
|
||||||
from fed_algo_cs.client_base import FedYoloClient
|
from fed_algo_cs.client_base import FedYoloClient
|
||||||
from fed_algo_cs.server_base import FedYoloServer
|
from fed_algo_cs.server_base import FedYoloServer
|
||||||
from utils.args import args_parser # your args parser
|
from utils.args import args_parser # args parser
|
||||||
from utils.fed_util import divide_trainset # divide_trainset is yours
|
from utils.fed_util import divide_trainset # divide_trainset
|
||||||
|
|
||||||
|
|
||||||
def _read_list_file(txt_path: str):
|
def _read_list_file(txt_path: str):
|
||||||
@@ -132,7 +132,7 @@ def fed_run():
|
|||||||
num_client=int(cfg.get("num_client", 64)),
|
num_client=int(cfg.get("num_client", 64)),
|
||||||
min_data=int(cfg.get("min_data", 100)),
|
min_data=int(cfg.get("min_data", 100)),
|
||||||
max_data=int(cfg.get("max_data", 100)),
|
max_data=int(cfg.get("max_data", 100)),
|
||||||
mode=str(cfg.get("partition_mode", "disjoint")), # "overlap" or "disjoint"
|
mode=str(cfg.get("partition_mode", "overlap")), # "overlap" or "disjoint"
|
||||||
seed=int(cfg.get("i_seed", 0)),
|
seed=int(cfg.get("i_seed", 0)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -142,6 +142,7 @@ def fed_run():
|
|||||||
# --- build clients ---
|
# --- build clients ---
|
||||||
model_name = cfg.get("model_name", "yolo_v11_n")
|
model_name = cfg.get("model_name", "yolo_v11_n")
|
||||||
clients = {}
|
clients = {}
|
||||||
|
|
||||||
for uid in users:
|
for uid in users:
|
||||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
||||||
c.load_trainset(user_data[uid]["filename"])
|
c.load_trainset(user_data[uid]["filename"])
|
||||||
@@ -176,11 +177,16 @@ def fed_run():
|
|||||||
res_root = cfg.get("res_root", "results")
|
res_root = cfg.get("res_root", "results")
|
||||||
os.makedirs(res_root, exist_ok=True)
|
os.makedirs(res_root, exist_ok=True)
|
||||||
|
|
||||||
for rnd in tqdm(range(num_round), desc="main federal loop round"):
|
# tqdm logging
|
||||||
t0 = time.time()
|
header = ("%10s" * 2) % ("Round", "client")
|
||||||
|
tqdm.write("\n" + header)
|
||||||
|
p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
|
||||||
|
|
||||||
|
for rnd in range(num_round):
|
||||||
|
t0 = time.time()
|
||||||
# Local training (sequential over all users)
|
# Local training (sequential over all users)
|
||||||
for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False):
|
for uid in users:
|
||||||
|
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
||||||
client = clients[uid] # FedYoloClient instance
|
client = clients[uid] # FedYoloClient instance
|
||||||
client.update(global_state) # load global weights
|
client.update(global_state) # load global weights
|
||||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
||||||
@@ -213,11 +219,18 @@ def fed_run():
|
|||||||
history["train_loss"].append(scalar_train_loss)
|
history["train_loss"].append(scalar_train_loss)
|
||||||
history["round_time_sec"].append(time.time() - t0)
|
history["round_time_sec"].append(time.time() - t0)
|
||||||
|
|
||||||
print(
|
# Log GPU memory usage
|
||||||
f"[round {rnd + 1:04d}] "
|
# gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||||
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
# tqdm update
|
||||||
f"P={precision:.4f} R={recall:.4f}"
|
desc = {
|
||||||
)
|
"loss": f"{scalar_train_loss:.6g}",
|
||||||
|
"mAP50": f"{mAP50:.6g}",
|
||||||
|
"mAP": f"{mAP:.6g}",
|
||||||
|
"precision": f"{precision:.6g}",
|
||||||
|
"recall": f"{recall:.6g}",
|
||||||
|
# "gpu_mem": gpu_mem,
|
||||||
|
}
|
||||||
|
p_bar.set_postfix(desc)
|
||||||
|
|
||||||
# Save running JSON (resumable logs)
|
# Save running JSON (resumable logs)
|
||||||
save_name = (
|
save_name = (
|
||||||
@@ -230,6 +243,10 @@ def fed_run():
|
|||||||
with open(out_json, "w", encoding="utf-8") as f:
|
with open(out_json, "w", encoding="utf-8") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=2)
|
||||||
|
|
||||||
|
p_bar.update(1)
|
||||||
|
|
||||||
|
p_bar.close()
|
||||||
|
|
||||||
# --- final plot ---
|
# --- final plot ---
|
||||||
_plot_curves(res_root, history)
|
_plot_curves(res_root, history)
|
||||||
print("[done] training complete.")
|
print("[done] training complete.")
|
||||||
|
@@ -25,6 +25,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
|
|||||||
Return a set of class_ids found in a YOLO .txt label file.
|
Return a set of class_ids found in a YOLO .txt label file.
|
||||||
Empty file -> empty set. Missing file -> empty set.
|
Empty file -> empty set. Missing file -> empty set.
|
||||||
Robust to blank lines / trailing spaces.
|
Robust to blank lines / trailing spaces.
|
||||||
|
Args:
|
||||||
|
label_path: path to the label file
|
||||||
|
Returns:
|
||||||
|
set of class IDs (integers) found in the file
|
||||||
"""
|
"""
|
||||||
class_ids: Set[int] = set()
|
class_ids: Set[int] = set()
|
||||||
if not os.path.exists(label_path):
|
if not os.path.exists(label_path):
|
||||||
|
@@ -151,7 +151,7 @@ def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65)
|
|||||||
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
|
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
|
||||||
if nc > 1:
|
if nc > 1:
|
||||||
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
|
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
|
||||||
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), dim=1)
|
x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
|
||||||
else: # best class only
|
else: # best class only
|
||||||
conf, j = cls.max(1, keepdim=True)
|
conf, j = cls.max(1, keepdim=True)
|
||||||
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
|
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
|
||||||
@@ -296,7 +296,8 @@ def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
|
|||||||
|
|
||||||
# Integrate area under curve
|
# Integrate area under curve
|
||||||
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
|
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
|
||||||
ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x) # integrate
|
# numpy.trapz is deprecated as of NumPy 2.0; use numpy.trapezoid instead
|
||||||
|
ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x) # integrate
|
||||||
if plot and j == 0:
|
if plot and j == 0:
|
||||||
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5
|
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user