# fed-yolo/fed_algo_cs/client_base.py
import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from tqdm import tqdm
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast


class FedYoloClient(object):
    def __init__(self, name, model_name, params):
        """
        Initialize client k for federated learning.

        Args:
            name: Name of client k
            model_name: Name of the model
            params: Config mapping with the hyperparameters for local training
                - local_batch_size: Local training batch size on client k
                - num_workers: Number of data loader workers
                - min_lr: Minimum learning rate
                - max_lr: Maximum learning rate
                - momentum: Momentum for local training
                - weight_decay: Weight decay for local training
                - names: Class names; their count gives the number of detection classes
        """
        self.params = params
        # initialize the metadata in local client k
        self.target_ip = "127.0.0.3"
        self.port = 9999
        self.name = name
        # initialize the parameters in local client k
        self._batch_size = self.params["local_batch_size"]
        self._min_lr = self.params["min_lr"]
        self._max_lr = self.params["max_lr"]
        self._momentum = self.params["momentum"]
        self.num_workers = self.params["num_workers"]
        self.loss_record = []
        # train set length
        self.n_data = 0
        # initialize the local training and testing dataset
        self.train_dataset = None
        self.val_dataset = None
        # initialize the local model
        self._num_classes = len(self.params["names"])
        self._weight_decay = self.params["weight_decay"]
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.parameter_number = sum([np.prod(p.size()) for p in model_parameters])
        # GPU
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_trainset(self, train_dataset: list[str]):
        """
        Load the local training dataset.

        Args:
            train_dataset: List of training image file paths
        """
        self.train_dataset = train_dataset
        self.n_data = len(self.train_dataset)

    def update(self, Global_model_state_dict):
        """
        Update the local model with the global model parameters.

        Args:
            Global_model_state_dict: State dictionary of the global model
        """
        if not hasattr(self, "model") or self.model is None:
            self.model = init_model(self.model_name, self._num_classes)
        # load the global model parameters
        self.model.load_state_dict(Global_model_state_dict, strict=True)
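
    # Typical per-round flow: the server calls update() with the aggregated global
    # weights, then train() to run the local epochs; train() returns the new local
    # state dict together with n_data, e.g. for FedAvg-style weighted aggregation.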

    def train(self, args):
        """
        Train the local model.

        Args:
            args: Command line arguments
                - local_rank: Local rank for distributed training
                - world_size: World size for distributed training
                - distributed: Whether to use distributed training
                - input_size: Input size for the model
                - epochs: Number of local training epochs

        Returns:
            Local updated model state dict, number of local data points,
            and the average training losses (box, cls, dfl).
        """
        if args.distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
            # self._device = torch.device("cuda", local_rank)
        if args.local_rank == 0:
            pass
            # if not os.path.exists("weights"):
            #     os.makedirs("weights")
        util.setup_seed()
        util.setup_multi_processes()

        # model (already initialized in __init__)
        self.model.to(self._device)

        # Optimizer
        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
        self._weight_decay = self.params["weight_decay"] * self._batch_size * args.world_size * accumulate / 64
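        # Worked example (illustrative numbers): with local_batch_size=16 and
        # world_size=1, accumulate = round(64 / 16) = 4, so the optimizer steps once
        # every 4 mini-batches (effective batch 64) and the configured weight decay
        # is scaled by 16 * 1 * 4 / 64 = 1.0, i.e. left unchanged.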
        optimizer = torch.optim.SGD(
            util.set_params(self.model, self._weight_decay),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

        # EMA
        ema = util.EMA(self.model) if args.local_rank == 0 else None

        data_set = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
            augment=True,
        )
        if args.distributed:
            train_sampler = data.DistributedSampler(
                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        else:
            train_sampler = None
        loader = data.DataLoader(
            data_set,
            batch_size=self._batch_size,
            shuffle=train_sampler is None,
            sampler=train_sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
        )
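
        # With a DistributedSampler, shuffling is delegated to the sampler (hence
        # shuffle=train_sampler is None above) and train_sampler.set_epoch() is
        # called once per epoch below so each process sees a different ordering.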

        # Scheduler
        num_steps = max(1, len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)

        # DDP mode
        if args.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(
                module=self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=False,
            )

        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
        criterion = util.ComputeLoss(self.model, self.params)

        # log
        # if args.local_rank == 0:
        #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
        #     print("\n" + header)
        #     p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
        #     p_bar.set_description(f"{self.name:>10}")

        for epoch in range(args.epochs):
            self.model.train()
            # when distributed, set epoch for shuffling
            if args.distributed and train_sampler is not None:
                train_sampler.set_epoch(epoch)
            if args.epochs - epoch == 10:
                # disable mosaic augmentation in the last 10 epochs
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False
            optimizer.zero_grad(set_to_none=True)

            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()

            # # --- header (once per epoch, YOLO-style) ---
            # if args.local_rank == 0:
            #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
            #     print("\n" + header)
            # p_bar = enumerate(loader)
            # if args.local_rank == 0:
            #     p_bar = tqdm(p_bar, total=num_steps, ncols=120)

            for i, (samples, targets) in enumerate(loader):
                global_step = i + num_steps * epoch
                scheduler.step(step=global_step, optimizer=optimizer)
                samples = samples.cuda(non_blocking=True).float() / 255.0

                # Forward
                with autocast("cuda", enabled=True):
                    outputs = self.model(samples)
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # meters (use the *unscaled* values)
                bs = samples.size(0)
                avg_box_loss.update(box_loss.item(), bs)
                avg_cls_loss.update(cls_loss.item(), bs)
                avg_dfl_loss.update(dfl_loss.item(), bs)

                # scale losses by batch/world if your loss is averaged internally per-sample/device
                # box_loss = box_loss * self._batch_size * args.world_size
                # cls_loss = cls_loss * self._batch_size * args.world_size
                # dfl_loss = dfl_loss * self._batch_size * args.world_size
                total_loss = box_loss + cls_loss + dfl_loss

                # Backward
                amp_scale.scale(total_loss).backward()

                # Optimize
                if (i + 1) % accumulate == 0:
                    amp_scale.unscale_(optimizer)  # unscale gradients
                    util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
                    amp_scale.step(optimizer)
                    amp_scale.update()
                    optimizer.zero_grad(set_to_none=True)
                    if ema:
                        ema.update(self.model)
                # torch.cuda.synchronize()

                # tqdm update
                # if args.local_rank == 0:
                #     mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
                #     desc = ("%10s" * 2 + "%10.4g" * 3) % (
                #         self.name,
                #         mem,
                #         avg_box_loss.avg,
                #         avg_cls_loss.avg,
                #         avg_dfl_loss.avg,
                #     )
                #     cast(tqdm, p_bar).set_description(desc)
                #     p_bar.update(1)
            # p_bar.close()

        # clean
        if args.distributed:
            torch.distributed.destroy_process_group()
        torch.cuda.empty_cache()

        return (
            self.model.state_dict() if not ema else ema.ema.state_dict(),
            self.n_data,
            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
        )
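

if __name__ == "__main__":
    # Minimal single-process usage sketch. The params keys match those documented in
    # __init__; the model name, file path, class names, and hyperparameter values
    # below are hypothetical placeholders rather than values from the original config.
    # A CUDA device is assumed, since train() moves batches to the GPU.
    from argparse import Namespace

    params = {
        "local_batch_size": 16,
        "num_workers": 4,
        "min_lr": 1e-4,
        "max_lr": 1e-2,
        "momentum": 0.937,
        "weight_decay": 5e-4,
        "names": {0: "person", 1: "car"},  # len(names) -> number of classes
    }
    args = Namespace(distributed=False, local_rank=0, world_size=1, input_size=640, epochs=1)

    client = FedYoloClient(name="client_0", model_name="yolo_v8_n", params=params)
    client.load_trainset(["data/images/train/000001.jpg"])  # local image paths
    # client.update(global_state_dict)  # load the server's aggregated weights first
    local_state, n_data, losses = client.train(args)
    print(n_data, losses)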