304 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			304 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import numpy as np
 | 
						|
import torch
 | 
						|
from torch import nn
 | 
						|
from torch.utils import data
 | 
						|
from torch.amp.autocast_mode import autocast
 | 
						|
from utils.fed_util import init_model
 | 
						|
from utils import util
 | 
						|
from utils.dataset import Dataset
 | 
						|
from typing import cast
 | 
						|
from tqdm import tqdm
 | 
						|
 | 
						|
 | 
						|
class FedYoloClient(object):
    """Federated-learning client that trains a local YOLO model.

    The client owns a model built by ``init_model``, a list of local training
    files (set with :meth:`load_trainset`) and its local-training
    hyperparameters. Typical round: the server pushes global weights with
    :meth:`update`, the client runs :meth:`train`, and the returned
    ``(state_dict, n_data, avg_loss_per_image)`` is sent back for aggregation.
    """

    def __init__(self, name, model_name, params):
        """
        Initialize the client k for federated learning.

        Args:
            name: Name of the client k.
            model_name: Name of the model, forwarded to ``init_model``.
            params: Config dict with the hyperparameters for local training.
                Keys read here:
                - local_batch_size: Local training batch size in the client k
                - num_workers: Number of data loader workers
                - min_lr: Minimum learning rate (initial SGD learning rate)
                - max_lr: Maximum learning rate
                - momentum: Momentum for local training
                - weight_decay: Weight decay for local training
                - names: Class names; its length is the number of classes
                The whole dict is also handed to ``Dataset``,
                ``util.LinearLR`` and ``util.ComputeLoss`` during training.
        """
        self.params = params

        # metadata of local client k
        self.target_ip = "127.0.0.3"
        self.port = 9999
        self.name = name

        # local-training hyperparameters of client k
        self._batch_size = self.params["local_batch_size"]
        self._min_lr = self.params["min_lr"]
        self._max_lr = self.params["max_lr"]
        self._momentum = self.params["momentum"]
        self.num_workers = self.params["num_workers"]

        self.loss_record = []
        # number of local training samples (set by load_trainset)
        self.n_data = 0

        # local training / validation data (file lists)
        self.train_dataset = None
        self.val_dataset = None

        # local model
        self._num_classes = len(self.params["names"])
        self._weight_decay = self.params["weight_decay"]
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)

        # count of trainable parameters
        trainable = (p for p in self.model.parameters() if p.requires_grad)
        self.parameter_number = sum(np.prod(p.size()) for p in trainable)

        # preferred device for this client
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _unwrap_model(model):
        """Return the module inside a DDP wrapper, or ``model`` unchanged."""
        if isinstance(model, nn.parallel.DistributedDataParallel):
            return model.module
        return model

    def load_trainset(self, train_dataset: list[str]):
        """
        Load the local training dataset.

        Args:
            train_dataset: List of training sample paths; its length becomes
                ``self.n_data``.
        """
        self.train_dataset = train_dataset
        self.n_data = len(self.train_dataset)

    def update(self, Global_model_state_dict):
        """
        Update the local model with the global model parameters.

        Args:
            Global_model_state_dict: State dictionary of the global model.
                Loaded with ``strict=True``, so keys must match exactly.
        """
        if getattr(self, "model", None) is None:
            self.model = init_model(self.model_name, self._num_classes)

        # Load into the plain module: a DDP wrapper's keys carry a "module."
        # prefix and would not match the global state dict.
        self._unwrap_model(self.model).load_state_dict(Global_model_state_dict, strict=True)

    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
        """
        Train the local model.

        Args:
            args: Training arguments. Attributes read here: ``distributed``,
                ``epochs``, ``input_size``, ``local_rank`` (defaults to 0 when
                absent), ``world_size`` (defaults to 1 when absent). It is
                also forwarded to ``util.LinearLR``.

        Returns:
            (state_dict, n_data, avg_loss_per_image): A tuple including:
                - state_dict: State dictionary of the trained local model
                  (the EMA weights on rank 0, otherwise the raw model weights)
                - n_data: Number of training data samples
                - avg_loss_per_image: box + cls + dfl per-image loss averaged
                  over the LAST local epoch (0.0 when ``args.epochs == 0``)
        """
        distributed = bool(args.distributed)
        if distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
        try:
            return self._train_local(args, distributed)
        finally:
            # Tear down even when training raised, so the next federated
            # round can call init_process_group() again.
            if distributed:
                torch.distributed.destroy_process_group()
            torch.cuda.empty_cache()

    def _build_loader(self, args, pin_memory):
        """Build the augmented training DataLoader and its (optional) DDP sampler."""
        dataset = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
            augment=True,
        )
        sampler = None
        if args.distributed:
            sampler = data.DistributedSampler(
                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        loader = data.DataLoader(
            dataset,
            batch_size=self._batch_size,
            shuffle=(sampler is None),
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=pin_memory,
            collate_fn=Dataset.collate_fn,
            drop_last=False,
        )
        return loader, sampler

    def _train_local(self, args, distributed):
        """Run the local epochs; return ``(state_dict, n_data, avg_loss_per_image)``."""
        util.setup_seed()
        util.setup_multi_processes()

        use_cuda = torch.cuda.is_available()
        # After set_device() in DDP mode current_device() == args.local_rank,
        # i.e. the same device a bare `.cuda()` call would pick.
        if use_cuda:
            device = torch.device("cuda", torch.cuda.current_device())
        else:
            device = torch.device("cpu")
        self.model.to(device)

        # ---- Optimizer / WD scaling ----
        world_size = max(getattr(args, "world_size", 1), 1)
        # micro-batches to accumulate so one optimizer step covers ~64 images
        accumulate = max(round(64 / (self._batch_size * world_size)), 1)
        # scale weight decay with the effective batch (YOLO recipe)
        scaled_wd = self._weight_decay * self._batch_size * world_size * accumulate / 64
        optimizer = torch.optim.SGD(
            util.set_params(self.model, scaled_wd),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

        # ---- EMA (only on rank 0) ----
        ema = util.EMA(self.model) if getattr(args, "local_rank", 0) == 0 else None

        # ---- Data / LR schedule ----
        loader, sampler = self._build_loader(args, pin_memory=use_cuda)
        num_steps = max(1, len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)

        # ---- SyncBN + DDP ----
        # The DDP wrapper lives only in this local; self.model always stays the
        # plain module so update() and the next round see unprefixed keys.
        model = self.model
        if distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            model = nn.parallel.DistributedDataParallel(
                module=self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=False,
            )

        # ---- AMP + loss ----
        scaler = torch.amp.grad_scaler.GradScaler(enabled=use_cuda)
        criterion = util.ComputeLoss(self.model, self.params)
        # per-image mean loss -> summed loss per device, times world_size to
        # undo DDP's gradient averaging across devices
        loss_scale = self._batch_size * world_size

        # ---- Training ----
        avg_loss_per_image = 0.0
        for epoch in range(args.epochs):
            model.train()
            if sampler is not None:
                sampler.set_epoch(epoch)

            # disable mosaic in the last 10 epochs (if dataset supports it)
            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                cast(Dataset, loader.dataset).mosaic = False

            optimizer.zero_grad(set_to_none=True)
            box_meter = util.AverageMeter()
            cls_meter = util.AverageMeter()
            dfl_meter = util.AverageMeter()

            for i, (images, targets) in enumerate(loader):
                step = i + epoch * num_steps
                scheduler.step(step=step, optimizer=optimizer)

                images = images.to(device).float() / 255.0
                batch_n = images.size(0)

                with autocast(device_type=device.type, enabled=use_cuda):
                    outputs = model(images)
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # NOTE(review): assumes criterion returns the per-image mean
                # loss of the batch — TODO confirm against util.ComputeLoss.
                box_meter.update(box_loss.item(), batch_n)
                cls_meter.update(cls_loss.item(), batch_n)
                dfl_meter.update(dfl_loss.item(), batch_n)

                total_loss = (box_loss + cls_loss + dfl_loss) * loss_scale
                scaler.scale(total_loss).backward()

                if step % accumulate == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                    if ema is not None:
                        ema.update(self.model)

                if use_cuda:
                    torch.cuda.synchronize()

            avg_loss_per_image = box_meter.avg + cls_meter.avg + dfl_meter.avg

        # Prefer EMA weights when available (common YOLO eval practice).
        weights = ema.ema.state_dict() if ema is not None else self.model.state_dict()
        return weights, self.n_data, avg_loss_per_image