更新配置文件路径,优化文档字符串,增强代码可读性
This commit is contained in:
		@@ -7,6 +7,7 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import Dict, List, Optional, Set, Any
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from nets import nn
 | 
			
		||||
from nets import YOLO
 | 
			
		||||
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
 | 
			
		||||
    Return a set of class_ids found in a YOLO .txt label file.
 | 
			
		||||
    Empty file -> empty set. Missing file -> empty set.
 | 
			
		||||
    Robust to blank lines / trailing spaces.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        label_path: path to the label file
 | 
			
		||||
        
 | 
			
		||||
    Returns:
 | 
			
		||||
        set of class IDs (integers) found in the file
 | 
			
		||||
    """
 | 
			
		||||
@@ -85,7 +88,7 @@ def divide_trainset(
 | 
			
		||||
    Build a federated split from a YOLO dataset list file.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        trainset_path: path to a .txt file containing one image path per line
 | 
			
		||||
        trainset_path (str): path to a .txt file containing one image path per line
 | 
			
		||||
                       e.g. /COCO/images/train2017/1111.jpg
 | 
			
		||||
        num_local_class: how many distinct classes to sample for each client
 | 
			
		||||
        num_client: number of clients
 | 
			
		||||
@@ -95,7 +98,9 @@ def divide_trainset(
 | 
			
		||||
              "disjoint" -> each image is used by at most one client
 | 
			
		||||
        seed: optional random seed for reproducibility
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
    Returns::
 | 
			
		||||
 | 
			
		||||
    >>> \\
 | 
			
		||||
        trainset_divided = {
 | 
			
		||||
            "users": ["c_00001", ...],
 | 
			
		||||
            "user_data": {
 | 
			
		||||
@@ -105,7 +110,9 @@ def divide_trainset(
 | 
			
		||||
            "num_samples": [len(list_for_user1), len(list_for_user2), ...]
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
    Example::
 | 
			
		||||
    
 | 
			
		||||
    >>> \\
 | 
			
		||||
        dataset = divide_trainset(
 | 
			
		||||
        trainset_path="/COCO/train2017.txt",
 | 
			
		||||
        num_local_class=3,
 | 
			
		||||
@@ -114,11 +121,11 @@ def divide_trainset(
 | 
			
		||||
        max_data=20,
 | 
			
		||||
        mode="disjoint",   # or "overlap"
 | 
			
		||||
        seed=42
 | 
			
		||||
    )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    print(dataset["users"])            # ['c_00001', ..., 'c_00005']
 | 
			
		||||
    print(dataset["num_samples"])      # e.g. [10, 12, 18, 9, 15]
 | 
			
		||||
    print(dataset["user_data"]["c_00001"]["filename"][:3])
 | 
			
		||||
    >>> print(dataset["users"])            # ['c_00001', ..., 'c_00005']
 | 
			
		||||
    >>> print(dataset["num_samples"])      # e.g. [10, 12, 18, 9, 15]
 | 
			
		||||
    >>> print(dataset["user_data"]["c_00001"]["filename"][:3])
 | 
			
		||||
    """
 | 
			
		||||
    if seed is not None:
 | 
			
		||||
        random.seed(seed)
 | 
			
		||||
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
 | 
			
		||||
    """
 | 
			
		||||
    Initialize the model for a specific learning task
 | 
			
		||||
    Args:
 | 
			
		||||
        :param model_name: Name of the model
 | 
			
		||||
        :param num_classes: Number of classes
 | 
			
		||||
        model_name: Name of the model
 | 
			
		||||
        num_classes: Number of classes
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        model: YOLO model instance
 | 
			
		||||
    """
 | 
			
		||||
    model = None
 | 
			
		||||
    if model_name == "yolo_v11_n":
 | 
			
		||||
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017")
 | 
			
		||||
    - If cfg['val_txt'] exists, use it.
 | 
			
		||||
    - Else if <dataset_path>/val.txt exists, use it.
 | 
			
		||||
    - Else return None (testing will be skipped).
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cfg: config dict
 | 
			
		||||
        params: params dict for Dataset
 | 
			
		||||
        args: optional args object (for input_size)
 | 
			
		||||
        val_name: name of the validation set folder with no prefix (default: "val2017")
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Dataset or None
 | 
			
		||||
    """
 | 
			
		||||
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
 | 
			
		||||
    out_png = os.path.join(save_dir, savename)
 | 
			
		||||
    plt.savefig(out_png, dpi=150, bbox_inches="tight")
 | 
			
		||||
    print(f"[plot] saved: {out_png}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_result_dir(base_root: str = "results"):
 | 
			
		||||
    """
 | 
			
		||||
    Prepare result directories for saving outputs.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        base_root (str): base directory for results.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        (res_dir, weights_dir) (str,str): Path to result directory and weights directory.
 | 
			
		||||
    """
 | 
			
		||||
    os.makedirs(base_root, exist_ok=True)
 | 
			
		||||
    timestamp = time.strftime("%Y%m%d_%H%M%S")
 | 
			
		||||
    res_dir = os.path.join(base_root, f"result_{timestamp}")
 | 
			
		||||
    weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
 | 
			
		||||
    os.makedirs(res_dir, exist_ok=True)
 | 
			
		||||
    os.makedirs(weights_dir, exist_ok=True)
 | 
			
		||||
    print(f"[INFO] Saving results to: {res_dir}")
 | 
			
		||||
    return res_dir, weights_dir
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user