更新配置文件路径,优化文档字符串,增强代码可读性

This commit is contained in:
2025-10-31 13:14:29 +08:00
parent 194ca8ee31
commit c7afef2dc2
2 changed files with 42 additions and 10 deletions

View File

@@ -7,7 +7,7 @@ def args_parser():
parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
parser.add_argument("--input_size", type=int, default=640, help="image input size")
parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
args = parser.parse_args()

View File

@@ -7,6 +7,7 @@ import numpy as np
import torch
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any
import time
from nets import nn
from nets import YOLO
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
Return a set of class_ids found in a YOLO .txt label file.
Empty file -> empty set. Missing file -> empty set.
Robust to blank lines / trailing spaces.
Args:
label_path: path to the label file
Returns:
set of class IDs (integers) found in the file
"""
@@ -85,7 +88,7 @@ def divide_trainset(
Build a federated split from a YOLO dataset list file.
Args:
trainset_path: path to a .txt file containing one image path per line
trainset_path (str): path to a .txt file containing one image path per line
e.g. /COCO/images/train2017/1111.jpg
num_local_class: how many distinct classes to sample for each client
num_client: number of clients
@@ -95,7 +98,9 @@ def divide_trainset(
"disjoint" -> each image is used by at most one client
seed: optional random seed for reproducibility
Returns:
Returns::
>>> \\
trainset_divided = {
"users": ["c_00001", ...],
"user_data": {
@@ -105,7 +110,9 @@ def divide_trainset(
"num_samples": [len(list_for_user1), len(list_for_user2), ...]
}
Example:
Example::
>>> \\
dataset = divide_trainset(
trainset_path="/COCO/train2017.txt",
num_local_class=3,
@@ -114,11 +121,11 @@ def divide_trainset(
max_data=20,
mode="disjoint", # or "overlap"
seed=42
)
)
print(dataset["users"]) # ['c_00001', ..., 'c_00005']
print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
print(dataset["user_data"]["c_00001"]["filename"][:3])
>>> print(dataset["users"]) # ['c_00001', ..., 'c_00005']
>>> print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
>>> print(dataset["user_data"]["c_00001"]["filename"][:3])
"""
if seed is not None:
random.seed(seed)
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
"""
Initialize the model for a specific learning task
Args:
:param model_name: Name of the model
:param num_classes: Number of classes
model_name: Name of the model
num_classes: Number of classes
Returns:
model: YOLO model instance
"""
model = None
if model_name == "yolo_v11_n":
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017")
- If cfg['val_txt'] exists, use it.
- Else if <dataset_path>/val.txt exists, use it.
- Else return None (testing will be skipped).
Args:
cfg: config dict
params: params dict for Dataset
args: optional args object (for input_size)
val_name: name of the validation set folder with no prefix (default: "val2017")
Returns:
Dataset or None
"""
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
out_png = os.path.join(save_dir, savename)
plt.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"[plot] saved: {out_png}")
def prepare_result_dir(base_root: str = "results"):
"""
Prepare result directories for saving outputs.
Args:
base_root (str): base directory for results.
Returns:
(res_dir, weights_dir) (str,str): Path to result directory and weights directory.
"""
os.makedirs(base_root, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
res_dir = os.path.join(base_root, f"result_{timestamp}")
weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
os.makedirs(res_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)
print(f"[INFO] Saving results to: {res_dir}")
return res_dir, weights_dir