import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast


class FedYoloClient(object):
    def __init__(self, name, model_name, params):
        """
        Initialize the client k for federated learning

        Args:
        :param name: Name of the client k
        :param model_name: Name of the model
        :param params: Config dict with the hyperparameters for local training
            - local_batch_size: Local training batch size in the client k
            - num_workers: Number of data loader workers
            - min_lr: Minimum learning rate
            - max_lr: Maximum learning rate
            - momentum: Momentum for local training
            - weight_decay: Weight decay for local training
            - names: Class names (used to derive the number of classes)
        """
        self.params = params
        # initialize the metadata in local client k
        self.target_ip = "127.0.0.3"
        self.port = 9999
        self.name = name

        # initialize the parameters in local client k
        self._batch_size = self.params["local_batch_size"]
        self._min_lr = self.params["min_lr"]
        self._max_lr = self.params["max_lr"]
        self._momentum = self.params["momentum"]
        self.num_workers = self.params["num_workers"]
        self.loss_record = []

        # train set length
        self.n_data = 0

        # initialize the local training and validation datasets
        self.train_dataset = None
        self.val_dataset = None

        # initialize the local model
        self._num_classes = len(self.params["names"])
        self._weight_decay = self.params["weight_decay"]
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.parameter_number = sum(np.prod(p.size()) for p in model_parameters)

        # GPU
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_trainset(self, train_dataset: list[str]):
        """
        Load the local training dataset

        Args:
        :param train_dataset: List of training image filenames
        """
        self.train_dataset = train_dataset
        self.n_data = len(self.train_dataset)

    def update(self, Global_model_state_dict):
        """
        Update the local model with the global model parameters

        Args:
        :param Global_model_state_dict: State dictionary of the global model
        """
        if not hasattr(self, "model") or self.model is None:
            self.model = init_model(self.model_name, self._num_classes)
        # load the global model parameters
        self.model.load_state_dict(Global_model_state_dict, strict=True)

    def train(self, args):
        """
        Train the local model

        Args:
        :param args: Command line arguments
            - local_rank: Local rank for distributed training
            - world_size: World size for distributed training
            - distributed: Whether to use distributed training
            - input_size: Input size for the model
            - epochs: Number of local training epochs

        Returns:
        :return: Local updated model state dict, number of local data points,
            and a dict of average training losses (box/cls/dfl)
        """
        if args.distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
            # self._device = torch.device("cuda", local_rank)

        if args.local_rank == 0:
            pass
            # if not os.path.exists("weights"):
            #     os.makedirs("weights")

        util.setup_seed()
        util.setup_multi_processes()

        # model: initialization has already been done in __init__()
        self.model.to(self._device)

        # Optimizer
        # accumulate gradients so the effective batch size is roughly 64
        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
        # scale the base weight decay by the effective batch size
        self._weight_decay = self.params["weight_decay"] * self._batch_size * args.world_size * accumulate / 64
        optimizer = torch.optim.SGD(
            util.set_params(self.model, self._weight_decay),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

        # EMA
        ema = util.EMA(self.model) if args.local_rank == 0 else None

        data_set = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
            augment=True,
        )
        if args.distributed:
            train_sampler = data.DistributedSampler(
                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        else:
            train_sampler = None
        loader = data.DataLoader(
            data_set,
            batch_size=self._batch_size,
            shuffle=train_sampler is None,
            sampler=train_sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
        )

        # Scheduler
        num_steps = max(1, len(loader))
        # print(len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)

        # DDP mode
        if args.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(
                module=self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=False,
            )

        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
        criterion = util.ComputeLoss(self.model, self.params)

        optimizer.zero_grad(set_to_none=True)

        for epoch in range(args.epochs):
            self.model.train()
            # when distributed, set the epoch so each replica reshuffles consistently
            if args.distributed and train_sampler is not None:
                train_sampler.set_epoch(epoch)
            if args.epochs - epoch == 10:
                # disable mosaic augmentation in the last 10 epochs
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False

            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()

            for i, (samples, targets) in enumerate(loader):
                global_step = i + num_steps * epoch
                scheduler.step(step=global_step, optimizer=optimizer)

                samples = samples.cuda(non_blocking=True).float() / 255.0

                # Forward
                with autocast("cuda", enabled=True):
                    outputs = self.model(samples)
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # meters (use the *unscaled* values)
                bs = samples.size(0)
                avg_box_loss.update(box_loss.item(), bs)
                avg_cls_loss.update(cls_loss.item(), bs)
                avg_dfl_loss.update(dfl_loss.item(), bs)

                # scale losses by batch/world size because the loss is averaged internally per sample/device
                box_loss = box_loss * self._batch_size * args.world_size
                cls_loss = cls_loss * self._batch_size * args.world_size
                dfl_loss = dfl_loss * self._batch_size * args.world_size
                total_loss = box_loss + cls_loss + dfl_loss

                # Backward
                amp_scale.scale(total_loss).backward()

                # Optimize every `accumulate` steps
                if (i + 1) % accumulate == 0:
                    amp_scale.step(optimizer)
                    amp_scale.update()
                    optimizer.zero_grad(set_to_none=True)
                    if ema:
                        ema.update(self.model)
                # torch.cuda.synchronize()

        # clean
        if args.distributed:
            torch.distributed.destroy_process_group()
        torch.cuda.empty_cache()

        return (
            self.model.state_dict(),
            self.n_data,
            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
        )
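

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how one client could run a single local
# round on one GPU. The YAML config layout, the "yolo_v8_n" model name, and
# the train-list file format are assumptions about the surrounding repo, not
# guaranteed by this file; `params` must at least provide the keys read in
# __init__ above (local_batch_size, min_lr, max_lr, momentum, num_workers,
# weight_decay, names) plus whatever utils.dataset.Dataset, util.LinearLR and
# util.ComputeLoss expect.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--params", type=str, required=True,
                        help="path to the hyperparameter YAML (assumed repo config)")
    parser.add_argument("--train-list", dest="train_list", type=str, required=True,
                        help="text file with one training image filename per line")
    parser.add_argument("--input-size", dest="input_size", type=int, default=640)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--world-size", dest="world_size", type=int, default=1)
    parser.add_argument("--distributed", action="store_true")
    args = parser.parse_args()

    with open(args.params) as f:
        params = yaml.safe_load(f)
    with open(args.train_list) as f:
        filenames = [line.strip() for line in f if line.strip()]

    # "yolo_v8_n" is a placeholder; use whatever names utils.fed_util.init_model accepts.
    client = FedYoloClient(name="client_0", model_name="yolo_v8_n", params=params)
    client.load_trainset(filenames)

    # In a real federated round the server would broadcast its global weights here;
    # re-loading the client's own weights just exercises the update() path.
    client.update(client.model.state_dict())

    state_dict, n_data, losses = client.train(args)
    print(f"client_0 trained on {n_data} samples, losses: {losses}")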