优化文档字符串,明确参数说明

This commit is contained in:
2025-10-31 13:14:21 +08:00
parent c7a3d34d88
commit 194ca8ee31
2 changed files with 28 additions and 10 deletions

View File

@@ -64,7 +64,7 @@ class FedYoloClient(object):
""" """
Load the local training dataset Load the local training dataset
Args: Args:
:param train_dataset: Training dataset train_dataset: Training dataset
""" """
self.train_dataset = train_dataset self.train_dataset = train_dataset
self.n_data = len(self.train_dataset) self.n_data = len(self.train_dataset)
@@ -72,8 +72,9 @@ class FedYoloClient(object):
def update(self, Global_model_state_dict): def update(self, Global_model_state_dict):
""" """
Update the local model with the global model parameters Update the local model with the global model parameters
Args: Args:
:param Global_model_state_dict: State dictionary of the global model Global_model_state_dict: State dictionary of the global model
""" """
if not hasattr(self, "model") or self.model is None: if not hasattr(self, "model") or self.model is None:
@@ -85,7 +86,15 @@ class FedYoloClient(object):
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]: def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
""" """
Train the local model. Train the local model.
Returns: (state_dict, n_data, avg_loss_per_image)
Args:
args: training arguments including
Returns:
(state_dict, n_data, avg_loss_per_image): A tuple including:
- state_dict: State dictionary of the trained local model
- n_data: Number of training data samples
- avg_loss_per_image: Average training loss per image over all epochs
""" """
# ---- Dist init (if any) ---- # ---- Dist init (if any) ----

View File

@@ -11,13 +11,13 @@ class FedYoloServer(object):
def __init__(self, client_list, model_name, params): def __init__(self, client_list, model_name, params):
""" """
Federated YOLO Server Federated YOLO Server
Args: Attributes:
client_list: list of connected clients client_list: list of connected clients
model_name: YOLO model architecture name model_name: YOLO model architecture name
params: dict of hyperparameters (must include 'names') params: dict of hyperparameters (must include 'names')
""" """
# Track client updates # Track client updates
self.client_state = {} self.client_state: dict[str, dict[str, torch.Tensor]] = {}
self.client_loss = {} self.client_loss = {}
self.client_n_data = {} self.client_n_data = {}
self.selected_clients = [] self.selected_clients = []
@@ -64,14 +64,19 @@ class FedYoloServer(object):
self.selected_clients.append(client_id) self.selected_clients.append(client_id)
self.n_data += self.client_n_data[client_id] self.n_data += self.client_n_data[client_id]
# TODO: skip the layer which can not be learnted locally
@torch.no_grad() @torch.no_grad()
def agg(self): def agg(self, skip_bn_layer: bool = False):
""" """
Server aggregates the local updates from selected clients using FedAvg. Server aggregates the local updates from selected clients using FedAvg.
:return: model_state: aggregated model weights Args:
:return: avg_loss: weighted average training loss across selected clients skip_bn_layer: whether to skip batch normalization layers during aggregation
:return: n_data: total number of data points across selected clients
Returns:
:model_state: aggregated model weights
:avg_loss: weighted average training loss across selected clients
:n_data: total number of data points across selected clients
""" """
if len(self.selected_clients) == 0 or self.n_data == 0: if len(self.selected_clients) == 0 or self.n_data == 0:
import warnings import warnings
@@ -144,11 +149,13 @@ class FedYoloServer(object):
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]: def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
""" """
Evaluate the model on the validation dataset. Evaluate the model on the validation dataset.
Args: Args:
valset: validation dataset valset: validation dataset
params: dict of parameters (must include 'names') params: dict of parameters (must include 'names')
model: YOLO model to evaluate model: YOLO model to evaluate
batch_size: batch size for evaluation batch_size: batch size for evaluation
Returns: Returns:
dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap) dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
""" """
@@ -214,7 +221,9 @@ def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[f
# Compute metrics # Compute metrics
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy
if len(metrics) and metrics[0].any(): if len(metrics) and metrics[0].any():
tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"]) tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(
*metrics, plot=False, names=params["names"]
) # set plot=True to plot metric curve
# Print results # Print results
# print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap)) # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
# Return results # Return results