优化文档字符串,明确参数说明
This commit is contained in:
@@ -64,7 +64,7 @@ class FedYoloClient(object):
|
|||||||
"""
|
"""
|
||||||
Load the local training dataset
|
Load the local training dataset
|
||||||
Args:
|
Args:
|
||||||
:param train_dataset: Training dataset
|
train_dataset: Training dataset
|
||||||
"""
|
"""
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
self.n_data = len(self.train_dataset)
|
self.n_data = len(self.train_dataset)
|
||||||
@@ -72,8 +72,9 @@ class FedYoloClient(object):
|
|||||||
def update(self, Global_model_state_dict):
|
def update(self, Global_model_state_dict):
|
||||||
"""
|
"""
|
||||||
Update the local model with the global model parameters
|
Update the local model with the global model parameters
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
:param Global_model_state_dict: State dictionary of the global model
|
Global_model_state_dict: State dictionary of the global model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not hasattr(self, "model") or self.model is None:
|
if not hasattr(self, "model") or self.model is None:
|
||||||
@@ -85,7 +86,15 @@ class FedYoloClient(object):
|
|||||||
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
|
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
|
||||||
"""
|
"""
|
||||||
Train the local model.
|
Train the local model.
|
||||||
Returns: (state_dict, n_data, avg_loss_per_image)
|
|
||||||
|
Args:
|
||||||
|
args: training arguments including
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(state_dict, n_data, avg_loss_per_image): A tuple including:
|
||||||
|
- state_dict: State dictionary of the trained local model
|
||||||
|
- n_data: Number of training data samples
|
||||||
|
- avg_loss_per_image: Average training loss per image over all epochs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ---- Dist init (if any) ----
|
# ---- Dist init (if any) ----
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ class FedYoloServer(object):
|
|||||||
def __init__(self, client_list, model_name, params):
|
def __init__(self, client_list, model_name, params):
|
||||||
"""
|
"""
|
||||||
Federated YOLO Server
|
Federated YOLO Server
|
||||||
Args:
|
Attributes:
|
||||||
client_list: list of connected clients
|
client_list: list of connected clients
|
||||||
model_name: YOLO model architecture name
|
model_name: YOLO model architecture name
|
||||||
params: dict of hyperparameters (must include 'names')
|
params: dict of hyperparameters (must include 'names')
|
||||||
"""
|
"""
|
||||||
# Track client updates
|
# Track client updates
|
||||||
self.client_state = {}
|
self.client_state: dict[str, dict[str, torch.Tensor]] = {}
|
||||||
self.client_loss = {}
|
self.client_loss = {}
|
||||||
self.client_n_data = {}
|
self.client_n_data = {}
|
||||||
self.selected_clients = []
|
self.selected_clients = []
|
||||||
@@ -64,14 +64,19 @@ class FedYoloServer(object):
|
|||||||
self.selected_clients.append(client_id)
|
self.selected_clients.append(client_id)
|
||||||
self.n_data += self.client_n_data[client_id]
|
self.n_data += self.client_n_data[client_id]
|
||||||
|
|
||||||
|
# TODO: skip the layer which can not be learnted locally
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def agg(self):
|
def agg(self, skip_bn_layer: bool = False):
|
||||||
"""
|
"""
|
||||||
Server aggregates the local updates from selected clients using FedAvg.
|
Server aggregates the local updates from selected clients using FedAvg.
|
||||||
|
|
||||||
:return: model_state: aggregated model weights
|
Args:
|
||||||
:return: avg_loss: weighted average training loss across selected clients
|
skip_bn_layer: whether to skip batch normalization layers during aggregation
|
||||||
:return: n_data: total number of data points across selected clients
|
|
||||||
|
Returns:
|
||||||
|
:model_state: aggregated model weights
|
||||||
|
:avg_loss: weighted average training loss across selected clients
|
||||||
|
:n_data: total number of data points across selected clients
|
||||||
"""
|
"""
|
||||||
if len(self.selected_clients) == 0 or self.n_data == 0:
|
if len(self.selected_clients) == 0 or self.n_data == 0:
|
||||||
import warnings
|
import warnings
|
||||||
@@ -144,11 +149,13 @@ class FedYoloServer(object):
|
|||||||
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
|
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate the model on the validation dataset.
|
Evaluate the model on the validation dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
valset: validation dataset
|
valset: validation dataset
|
||||||
params: dict of parameters (must include 'names')
|
params: dict of parameters (must include 'names')
|
||||||
model: YOLO model to evaluate
|
model: YOLO model to evaluate
|
||||||
batch_size: batch size for evaluation
|
batch_size: batch size for evaluation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
|
dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
|
||||||
"""
|
"""
|
||||||
@@ -214,7 +221,9 @@ def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[f
|
|||||||
# Compute metrics
|
# Compute metrics
|
||||||
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy
|
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy
|
||||||
if len(metrics) and metrics[0].any():
|
if len(metrics) and metrics[0].any():
|
||||||
tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
|
tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(
|
||||||
|
*metrics, plot=False, names=params["names"]
|
||||||
|
) # set plot=True to plot metric curve
|
||||||
# Print results
|
# Print results
|
||||||
# print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
|
# print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
|
||||||
# Return results
|
# Return results
|
||||||
|
|||||||
Reference in New Issue
Block a user