更新配置文件路径,优化文档字符串,增强代码可读性
This commit is contained in:
@@ -7,7 +7,7 @@ def args_parser():
|
||||
|
||||
parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
|
||||
parser.add_argument("--input_size", type=int, default=640, help="image input size")
|
||||
parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
|
||||
parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
import time
|
||||
|
||||
from nets import nn
|
||||
from nets import YOLO
|
||||
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
|
||||
Return a set of class_ids found in a YOLO .txt label file.
|
||||
Empty file -> empty set. Missing file -> empty set.
|
||||
Robust to blank lines / trailing spaces.
|
||||
|
||||
Args:
|
||||
label_path: path to the label file
|
||||
|
||||
Returns:
|
||||
set of class IDs (integers) found in the file
|
||||
"""
|
||||
@@ -85,7 +88,7 @@ def divide_trainset(
|
||||
Build a federated split from a YOLO dataset list file.
|
||||
|
||||
Args:
|
||||
trainset_path: path to a .txt file containing one image path per line
|
||||
trainset_path (str): path to a .txt file containing one image path per line
|
||||
e.g. /COCO/images/train2017/1111.jpg
|
||||
num_local_class: how many distinct classes to sample for each client
|
||||
num_client: number of clients
|
||||
@@ -95,7 +98,9 @@ def divide_trainset(
|
||||
"disjoint" -> each image is used by at most one client
|
||||
seed: optional random seed for reproducibility
|
||||
|
||||
Returns:
|
||||
Returns::
|
||||
|
||||
>>> \\
|
||||
trainset_divided = {
|
||||
"users": ["c_00001", ...],
|
||||
"user_data": {
|
||||
@@ -105,7 +110,9 @@ def divide_trainset(
|
||||
"num_samples": [len(list_for_user1), len(list_for_user2), ...]
|
||||
}
|
||||
|
||||
Example:
|
||||
Example::
|
||||
|
||||
>>> \\
|
||||
dataset = divide_trainset(
|
||||
trainset_path="/COCO/train2017.txt",
|
||||
num_local_class=3,
|
||||
@@ -114,11 +121,11 @@ def divide_trainset(
|
||||
max_data=20,
|
||||
mode="disjoint", # or "overlap"
|
||||
seed=42
|
||||
)
|
||||
)
|
||||
|
||||
print(dataset["users"]) # ['c_00001', ..., 'c_00005']
|
||||
print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
|
||||
print(dataset["user_data"]["c_00001"]["filename"][:3])
|
||||
>>> print(dataset["users"]) # ['c_00001', ..., 'c_00005']
|
||||
>>> print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
|
||||
>>> print(dataset["user_data"]["c_00001"]["filename"][:3])
|
||||
"""
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
|
||||
"""
|
||||
Initialize the model for a specific learning task
|
||||
Args:
|
||||
:param model_name: Name of the model
|
||||
:param num_classes: Number of classes
|
||||
model_name: Name of the model
|
||||
num_classes: Number of classes
|
||||
|
||||
Returns:
|
||||
model: YOLO model instance
|
||||
"""
|
||||
model = None
|
||||
if model_name == "yolo_v11_n":
|
||||
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017")
|
||||
- If cfg['val_txt'] exists, use it.
|
||||
- Else if <dataset_path>/val.txt exists, use it.
|
||||
- Else return None (testing will be skipped).
|
||||
|
||||
Args:
|
||||
cfg: config dict
|
||||
params: params dict for Dataset
|
||||
args: optional args object (for input_size)
|
||||
val_name: name of the validation set folder with no prefix (default: "val2017")
|
||||
|
||||
Returns:
|
||||
Dataset or None
|
||||
"""
|
||||
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
|
||||
out_png = os.path.join(save_dir, savename)
|
||||
plt.savefig(out_png, dpi=150, bbox_inches="tight")
|
||||
print(f"[plot] saved: {out_png}")
|
||||
|
||||
|
||||
def prepare_result_dir(base_root: str = "results"):
|
||||
"""
|
||||
Prepare result directories for saving outputs.
|
||||
|
||||
Args:
|
||||
base_root (str): base directory for results.
|
||||
|
||||
Returns:
|
||||
(res_dir, weights_dir) (str,str): Path to result directory and weights directory.
|
||||
"""
|
||||
os.makedirs(base_root, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
res_dir = os.path.join(base_root, f"result_{timestamp}")
|
||||
weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
|
||||
os.makedirs(res_dir, exist_ok=True)
|
||||
os.makedirs(weights_dir, exist_ok=True)
|
||||
print(f"[INFO] Saving results to: {res_dir}")
|
||||
return res_dir, weights_dir
|
||||
|
||||
Reference in New Issue
Block a user