From c7afef2dc2b364b96d38a4d255243473579b7c0c Mon Sep 17 00:00:00 2001 From: Yunhao Meng Date: Fri, 31 Oct 2025 13:14:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E8=B7=AF=E5=BE=84=EF=BC=8C=E4=BC=98=E5=8C=96=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=AD=97=E7=AC=A6=E4=B8=B2=EF=BC=8C=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/args.py | 2 +- utils/fed_util.py | 50 ++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/utils/args.py b/utils/args.py index 3696ab2..cb796b8 100644 --- a/utils/args.py +++ b/utils/args.py @@ -7,7 +7,7 @@ def args_parser(): parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training") parser.add_argument("--input_size", type=int, default=640, help="image input size") - parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config") + parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config") args = parser.parse_args() diff --git a/utils/fed_util.py b/utils/fed_util.py index 094b47b..1709f57 100644 --- a/utils/fed_util.py +++ b/utils/fed_util.py @@ -7,6 +7,7 @@ import numpy as np import torch from collections import defaultdict from typing import Dict, List, Optional, Set, Any +import time from nets import nn from nets import YOLO @@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]: Return a set of class_ids found in a YOLO .txt label file. Empty file -> empty set. Missing file -> empty set. Robust to blank lines / trailing spaces. + Args: label_path: path to the label file + Returns: set of class IDs (integers) found in the file """ @@ -85,7 +88,7 @@ def divide_trainset( Build a federated split from a YOLO dataset list file. Args: - trainset_path: path to a .txt file containing one image path per line + trainset_path (str): path to a .txt file containing one image path per line e.g. /COCO/images/train2017/1111.jpg num_local_class: how many distinct classes to sample for each client num_client: number of clients @@ -95,7 +98,9 @@ def divide_trainset( "disjoint" -> each image is used by at most one client seed: optional random seed for reproducibility - Returns: + Returns:: + + >>> \\ trainset_divided = { "users": ["c_00001", ...], "user_data": { @@ -105,7 +110,9 @@ def divide_trainset( "num_samples": [len(list_for_user1), len(list_for_user2), ...] } - Example: + Example:: + + >>> \\ dataset = divide_trainset( trainset_path="/COCO/train2017.txt", num_local_class=3, @@ -114,11 +121,11 @@ def divide_trainset( max_data=20, mode="disjoint", # or "overlap" seed=42 - ) + ) - print(dataset["users"]) # ['c_00001', ..., 'c_00005'] - print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15] - print(dataset["user_data"]["c_00001"]["filename"][:3]) + >>> print(dataset["users"]) # ['c_00001', ..., 'c_00005'] + >>> print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15] + >>> print(dataset["user_data"]["c_00001"]["filename"][:3]) """ if seed is not None: random.seed(seed) @@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO: """ Initialize the model for a specific learning task Args: - :param model_name: Name of the model - :param num_classes: Number of classes + model_name: Name of the model + num_classes: Number of classes + + Returns: + model: YOLO model instance """ model = None if model_name == "yolo_v11_n": @@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017") - If cfg['val_txt'] exists, use it. - Else if /val.txt exists, use it. - Else return None (testing will be skipped). + Args: cfg: config dict params: params dict for Dataset args: optional args object (for input_size) val_name: name of the validation set folder with no prefix (default: "val2017") + Returns: Dataset or None """ @@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"): out_png = os.path.join(save_dir, savename) plt.savefig(out_png, dpi=150, bbox_inches="tight") print(f"[plot] saved: {out_png}") + + +def prepare_result_dir(base_root: str = "results"): + """ + Prepare result directories for saving outputs. + + Args: + base_root (str): base directory for results. + + Returns: + (res_dir, weights_dir) (str,str): Path to result directory and weights directory. + """ + os.makedirs(base_root, exist_ok=True) + timestamp = time.strftime("%Y%m%d_%H%M%S") + res_dir = os.path.join(base_root, f"result_{timestamp}") + weights_dir = os.path.join(res_dir, f"weight_{timestamp}") + os.makedirs(res_dir, exist_ok=True) + os.makedirs(weights_dir, exist_ok=True) + print(f"[INFO] Saving results to: {res_dir}") + return res_dir, weights_dir