优化文档字符串,明确参数说明
This commit is contained in:
		@@ -64,7 +64,7 @@ class FedYoloClient(object):
 | 
			
		||||
        """
 | 
			
		||||
        Load the local training dataset
 | 
			
		||||
        Args:
 | 
			
		||||
            :param train_dataset: Training dataset
 | 
			
		||||
            train_dataset: Training dataset
 | 
			
		||||
        """
 | 
			
		||||
        self.train_dataset = train_dataset
 | 
			
		||||
        self.n_data = len(self.train_dataset)
 | 
			
		||||
@@ -72,8 +72,9 @@ class FedYoloClient(object):
 | 
			
		||||
    def update(self, Global_model_state_dict):
 | 
			
		||||
        """
 | 
			
		||||
        Update the local model with the global model parameters
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            :param Global_model_state_dict: State dictionary of the global model
 | 
			
		||||
            Global_model_state_dict: State dictionary of the global model
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if not hasattr(self, "model") or self.model is None:
 | 
			
		||||
@@ -85,7 +86,15 @@ class FedYoloClient(object):
 | 
			
		||||
    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
 | 
			
		||||
        """
 | 
			
		||||
        Train the local model.
 | 
			
		||||
        Returns: (state_dict, n_data, avg_loss_per_image)
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            args: training arguments including
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            (state_dict, n_data, avg_loss_per_image): A tuple including:
 | 
			
		||||
                - state_dict: State dictionary of the trained local model
 | 
			
		||||
                - n_data: Number of training data samples
 | 
			
		||||
                - avg_loss_per_image: Average training loss per image over all epochs
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # ---- Dist init (if any) ----
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user