2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import numpy as np
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import torch
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from torch import nn
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from torch.utils import data
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from torch.amp.autocast_mode import autocast
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils.fed_util import init_model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils import util
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from utils.dataset import Dataset
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from typing import cast
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from tqdm import tqdm
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								class FedYoloClient(object):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def __init__(self, name, model_name, params):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Initialize the client k for federated learning
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Args:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            :param name: Name of the client k
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            :param model_name: Name of the model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            :param params: config file including the hyperparameters for local training
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - batch_size: Local training batch size in the client k
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - num_workers: Number of data loader workers
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - min_lr: Minimum learning rate
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - max_lr: Maximum learning rate
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - momentum: Momentum for local training
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - weight_decay: Weight decay for local training
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.params = params
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # initialize the metadata in local client k
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.target_ip = "127.0.0.3"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.port = 9999
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.name = name
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # initialize the parameters in local client k
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._batch_size = self.params["local_batch_size"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._min_lr = self.params["min_lr"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._max_lr = self.params["max_lr"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._momentum = self.params["momentum"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.num_workers = self.params["num_workers"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.loss_record = []
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # train set length
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.n_data = 0
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # initialize the local training and testing dataset
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.train_dataset = None
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.val_dataset = None
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # initialize the local model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._num_classes = len(self.params["names"])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._weight_decay = self.params["weight_decay"]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.model_name = model_name
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.model = init_model(model_name, self._num_classes)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.parameter_number = sum([np.prod(p.size()) for p in model_parameters])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # GPU
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def load_trainset(self, train_dataset: list[str]):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Load the local training dataset
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Args:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-31 13:14:21 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            train_dataset: Training dataset
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.train_dataset = train_dataset
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.n_data = len(self.train_dataset)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def update(self, Global_model_state_dict):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Update the local model with the global model parameters
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-31 13:14:21 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Args:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-31 13:14:21 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            Global_model_state_dict: State dictionary of the global model
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if not hasattr(self, "model") or self.model is None:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            self.model = init_model(self.model_name, self._num_classes)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # load the global model parameters
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.model.load_state_dict(Global_model_state_dict, strict=True)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        Train the local model.
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-31 13:14:21 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Args:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            args: training arguments including
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        Returns:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            (state_dict, n_data, avg_loss_per_image): A tuple including:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - state_dict: State dictionary of the trained local model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - n_data: Number of training data samples
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                - avg_loss_per_image: Average training loss per image over all epochs
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # ---- Dist init (if any) ----
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if args.distributed:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            torch.cuda.set_device(device=args.local_rank)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            torch.distributed.init_process_group(backend="nccl", init_method="env://")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        util.setup_seed()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        util.setup_multi_processes()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # self.model.to(device)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.model.cuda()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # show model architecture
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # print(self.model)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # ---- Optimizer / WD scaling & LR warmup/schedule ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # accumulate = effective grad-accumulation steps to emulate global batch 64
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        world_size = getattr(args, "world_size", 1)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # scale weight_decay like YOLO recipes
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        optimizer = torch.optim.SGD(
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            util.set_params(self.model, scaled_wd),
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            lr=self._min_lr,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            momentum=self._momentum,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            nesterov=True,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # ---- EMA (track the underlying module if DDP) ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # track_model = self.model.module if is_ddp else self.model
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        ema = util.EMA(self.model) if args.local_rank == 0 else None
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:06:13 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # print(type(self.train_dataset))
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # ---- Data ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        dataset = Dataset(
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            filenames=self.train_dataset,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            input_size=args.input_size,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            params=self.params,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            augment=True,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if args.distributed:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            train_sampler = data.DistributedSampler(
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        else:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            train_sampler = None
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        loader = data.DataLoader(
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            dataset,
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            batch_size=self._batch_size,
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            shuffle=(train_sampler is None),
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            sampler=train_sampler,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            num_workers=self.num_workers,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            pin_memory=True,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            collate_fn=Dataset.collate_fn,
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            drop_last=False,
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_steps = max(1, len(loader))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # ---- SyncBN + DDP (if any) ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        is_ddp = bool(args.distributed)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if is_ddp:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            self.model = nn.parallel.DistributedDataParallel(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                module=self.model,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                device_ids=[args.local_rank],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                output_device=args.local_rank,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                find_unused_parameters=False,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # ---- AMP + loss ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # criterion = util.ComputeLoss(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        #     self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        #     self.params,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # )
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        criterion = util.ComputeLoss(self.model, self.params)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # ---- Training ----
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        for epoch in range(args.epochs):
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            # (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            self.model.train()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            if is_ddp and train_sampler is not None:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                train_sampler.set_epoch(epoch)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            # disable mosaic in the last 10 epochs (if dataset supports it)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                ds = cast(Dataset, loader.dataset)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                ds.mosaic = False
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            optimizer.zero_grad(set_to_none=True)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            loss_box_meter = util.AverageMeter()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            loss_cls_meter = util.AverageMeter()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            loss_dfl_meter = util.AverageMeter()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            for i, (images, targets) in enumerate(loader):
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:06:13 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                # print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                step = i + epoch * num_steps
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                # scheduler per-step (your util.LinearLR expects step)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                scheduler.step(step=step, optimizer=optimizer)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                # images = images.to(device, non_blocking=True).float() / 255.0
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                images = images.cuda().float() / 255.0
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                bs = images.size(0)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                # total_imgs_seen += bs
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                # targets: keep as your ComputeLoss expects (often CPU lists/tensors).
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                # Move to GPU here only if your loss requires it.
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                with autocast(device_type="cuda", enabled=True):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    outputs = self.model(images)  # DDP wraps forward
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                    # total_loss = box_loss + cls_loss + dfl_loss
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # Gradient accumulation: normalize by 'accumulate' so LR stays effective
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # total_loss = total_loss / accumulate
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                # IMPORTANT: assume criterion returns **average per image** in the batch.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                # Keep logging on the true (unscaled) values:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                loss_box_meter.update(box_loss.item(), bs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                loss_cls_meter.update(cls_loss.item(), bs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                loss_dfl_meter.update(dfl_loss.item(), bs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                box_loss *= self._batch_size
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                cls_loss *= self._batch_size
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                dfl_loss *= self._batch_size
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                box_loss *= args.world_size
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                cls_loss *= args.world_size
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                dfl_loss *= args.world_size
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-03 20:23:34 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                total_loss = box_loss + cls_loss + dfl_loss
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                scaler.scale(total_loss).backward()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                # optimize
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                if step % accumulate == 0:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # scaler.unscale_(optimizer)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # util.clip_gradients(self.model)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    scaler.step(optimizer)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    scaler.update()
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    optimizer.zero_grad(set_to_none=True)
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # # Step when we have 'accumulate' micro-batches, or at the end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     scaler.unscale_(optimizer)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     util.clip_gradients(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #         model=(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #             self.model.module
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #             if isinstance(self.model, nn.parallel.DistributedDataParallel)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #             else self.model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #         ),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #         max_norm=10.0,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     scaler.step(optimizer)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     scaler.update()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     optimizer.zero_grad(set_to_none=True)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    if ema:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                        # Update EMA from the underlying module
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                        ema.update(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                            self.model.module
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                            if isinstance(self.model, nn.parallel.DistributedDataParallel)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                            else self.model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # print loss to test
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-23 13:06:13 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                    # print(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    #     f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # )
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                torch.cuda.synchronize()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # ---- Final average loss (per image) over the whole epoch span ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        # ---- Cleanup DDP ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if is_ddp:
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-02 16:26:27 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            torch.distributed.destroy_process_group()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        torch.cuda.empty_cache()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-10-19 21:27:19 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        # ---- Choose which weights to return ----
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        #   - If EMA exists, return EMA weights (common YOLO eval practice)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        #   - Be careful with DDP: grab state_dict from the underlying module / EMA model
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if ema:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            # print("Using EMA weights")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        else:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            # Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            model_obj = getattr(self.model, "module", self.model)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            # If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            # otherwise try to call state_dict() and finally fall back to wrapping the object.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            if isinstance(model_obj, torch.nn.Module):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                model_to_return = model_obj.state_dict()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            elif isinstance(model_obj, dict):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                model_to_return = model_obj
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            else:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                try:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    model_to_return = model_obj.state_dict()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                except Exception:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    # fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    model_to_return = {"state": model_obj}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            return model_to_return, self.n_data, avg_loss_per_image
							 |