更新配置文件路径,优化文档字符串,增强代码可读性

This commit is contained in:
2025-10-31 13:14:29 +08:00
parent 194ca8ee31
commit c7afef2dc2
2 changed files with 42 additions and 10 deletions

View File

@@ -7,7 +7,7 @@ def args_parser():
parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training") parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
parser.add_argument("--input_size", type=int, default=640, help="image input size") parser.add_argument("--input_size", type=int, default=640, help="image input size")
parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config") parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
args = parser.parse_args() args = parser.parse_args()

View File

@@ -7,6 +7,7 @@ import numpy as np
import torch import torch
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Set, Any from typing import Dict, List, Optional, Set, Any
import time
from nets import nn from nets import nn
from nets import YOLO from nets import YOLO
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
Return a set of class_ids found in a YOLO .txt label file. Return a set of class_ids found in a YOLO .txt label file.
Empty file -> empty set. Missing file -> empty set. Empty file -> empty set. Missing file -> empty set.
Robust to blank lines / trailing spaces. Robust to blank lines / trailing spaces.
Args: Args:
label_path: path to the label file label_path: path to the label file
Returns: Returns:
set of class IDs (integers) found in the file set of class IDs (integers) found in the file
""" """
@@ -85,7 +88,7 @@ def divide_trainset(
Build a federated split from a YOLO dataset list file. Build a federated split from a YOLO dataset list file.
Args: Args:
trainset_path: path to a .txt file containing one image path per line trainset_path (str): path to a .txt file containing one image path per line
e.g. /COCO/images/train2017/1111.jpg e.g. /COCO/images/train2017/1111.jpg
num_local_class: how many distinct classes to sample for each client num_local_class: how many distinct classes to sample for each client
num_client: number of clients num_client: number of clients
@@ -95,7 +98,9 @@ def divide_trainset(
"disjoint" -> each image is used by at most one client "disjoint" -> each image is used by at most one client
seed: optional random seed for reproducibility seed: optional random seed for reproducibility
Returns: Returns::
>>> \\
trainset_divided = { trainset_divided = {
"users": ["c_00001", ...], "users": ["c_00001", ...],
"user_data": { "user_data": {
@@ -105,7 +110,9 @@ def divide_trainset(
"num_samples": [len(list_for_user1), len(list_for_user2), ...] "num_samples": [len(list_for_user1), len(list_for_user2), ...]
} }
Example: Example::
>>> \\
dataset = divide_trainset( dataset = divide_trainset(
trainset_path="/COCO/train2017.txt", trainset_path="/COCO/train2017.txt",
num_local_class=3, num_local_class=3,
@@ -114,11 +121,11 @@ def divide_trainset(
max_data=20, max_data=20,
mode="disjoint", # or "overlap" mode="disjoint", # or "overlap"
seed=42 seed=42
) )
print(dataset["users"]) # ['c_00001', ..., 'c_00005'] >>> print(dataset["users"]) # ['c_00001', ..., 'c_00005']
print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15] >>> print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
print(dataset["user_data"]["c_00001"]["filename"][:3]) >>> print(dataset["user_data"]["c_00001"]["filename"][:3])
""" """
if seed is not None: if seed is not None:
random.seed(seed) random.seed(seed)
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
""" """
Initialize the model for a specific learning task Initialize the model for a specific learning task
Args: Args:
:param model_name: Name of the model model_name: Name of the model
:param num_classes: Number of classes num_classes: Number of classes
Returns:
model: YOLO model instance
""" """
model = None model = None
if model_name == "yolo_v11_n": if model_name == "yolo_v11_n":
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017")
- If cfg['val_txt'] exists, use it. - If cfg['val_txt'] exists, use it.
- Else if <dataset_path>/val.txt exists, use it. - Else if <dataset_path>/val.txt exists, use it.
- Else return None (testing will be skipped). - Else return None (testing will be skipped).
Args: Args:
cfg: config dict cfg: config dict
params: params dict for Dataset params: params dict for Dataset
args: optional args object (for input_size) args: optional args object (for input_size)
val_name: name of the validation set folder with no prefix (default: "val2017") val_name: name of the validation set folder with no prefix (default: "val2017")
Returns: Returns:
Dataset or None Dataset or None
""" """
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
out_png = os.path.join(save_dir, savename) out_png = os.path.join(save_dir, savename)
plt.savefig(out_png, dpi=150, bbox_inches="tight") plt.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"[plot] saved: {out_png}") print(f"[plot] saved: {out_png}")
def prepare_result_dir(base_root: str = "results"):
"""
Prepare result directories for saving outputs.
Args:
base_root (str): base directory for results.
Returns:
(res_dir, weights_dir) (str,str): Path to result directory and weights directory.
"""
os.makedirs(base_root, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
res_dir = os.path.join(base_root, f"result_{timestamp}")
weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
os.makedirs(res_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)
print(f"[INFO] Saving results to: {res_dir}")
return res_dir, weights_dir