优化文档字符串,明确参数说明

This commit is contained in:
2025-10-31 13:14:21 +08:00
parent c7a3d34d88
commit 194ca8ee31
2 changed files with 28 additions and 10 deletions

View File

@@ -64,7 +64,7 @@ class FedYoloClient(object):
"""
Load the local training dataset
Args:
:param train_dataset: Training dataset
train_dataset: Training dataset
"""
self.train_dataset = train_dataset
self.n_data = len(self.train_dataset)
@@ -72,8 +72,9 @@ class FedYoloClient(object):
def update(self, Global_model_state_dict):
"""
Update the local model with the global model parameters
Args:
:param Global_model_state_dict: State dictionary of the global model
Global_model_state_dict: State dictionary of the global model
"""
if not hasattr(self, "model") or self.model is None:
@@ -85,7 +86,15 @@ class FedYoloClient(object):
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
"""
Train the local model.
Returns: (state_dict, n_data, avg_loss_per_image)
Args:
args: training arguments including
Returns:
(state_dict, n_data, avg_loss_per_image): A tuple including:
- state_dict: State dictionary of the trained local model
- n_data: Number of training data samples
- avg_loss_per_image: Average training loss per image over all epochs
"""
# ---- Dist init (if any) ----