更新配置文件路径,优化文档字符串,增强代码可读性
This commit is contained in:
@@ -7,7 +7,7 @@ def args_parser():
|
|||||||
|
|
||||||
parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
|
parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
|
||||||
parser.add_argument("--input_size", type=int, default=640, help="image input size")
|
parser.add_argument("--input_size", type=int, default=640, help="image input size")
|
||||||
parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
|
parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Optional, Set, Any
|
from typing import Dict, List, Optional, Set, Any
|
||||||
|
import time
|
||||||
|
|
||||||
from nets import nn
|
from nets import nn
|
||||||
from nets import YOLO
|
from nets import YOLO
|
||||||
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
|
|||||||
Return a set of class_ids found in a YOLO .txt label file.
|
Return a set of class_ids found in a YOLO .txt label file.
|
||||||
Empty file -> empty set. Missing file -> empty set.
|
Empty file -> empty set. Missing file -> empty set.
|
||||||
Robust to blank lines / trailing spaces.
|
Robust to blank lines / trailing spaces.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label_path: path to the label file
|
label_path: path to the label file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
set of class IDs (integers) found in the file
|
set of class IDs (integers) found in the file
|
||||||
"""
|
"""
|
||||||
@@ -85,7 +88,7 @@ def divide_trainset(
|
|||||||
Build a federated split from a YOLO dataset list file.
|
Build a federated split from a YOLO dataset list file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainset_path: path to a .txt file containing one image path per line
|
trainset_path (str): path to a .txt file containing one image path per line
|
||||||
e.g. /COCO/images/train2017/1111.jpg
|
e.g. /COCO/images/train2017/1111.jpg
|
||||||
num_local_class: how many distinct classes to sample for each client
|
num_local_class: how many distinct classes to sample for each client
|
||||||
num_client: number of clients
|
num_client: number of clients
|
||||||
@@ -95,7 +98,9 @@ def divide_trainset(
|
|||||||
"disjoint" -> each image is used by at most one client
|
"disjoint" -> each image is used by at most one client
|
||||||
seed: optional random seed for reproducibility
|
seed: optional random seed for reproducibility
|
||||||
|
|
||||||
Returns:
|
Returns::
|
||||||
|
|
||||||
|
>>> \\
|
||||||
trainset_divided = {
|
trainset_divided = {
|
||||||
"users": ["c_00001", ...],
|
"users": ["c_00001", ...],
|
||||||
"user_data": {
|
"user_data": {
|
||||||
@@ -105,7 +110,9 @@ def divide_trainset(
|
|||||||
"num_samples": [len(list_for_user1), len(list_for_user2), ...]
|
"num_samples": [len(list_for_user1), len(list_for_user2), ...]
|
||||||
}
|
}
|
||||||
|
|
||||||
Example:
|
Example::
|
||||||
|
|
||||||
|
>>> \\
|
||||||
dataset = divide_trainset(
|
dataset = divide_trainset(
|
||||||
trainset_path="/COCO/train2017.txt",
|
trainset_path="/COCO/train2017.txt",
|
||||||
num_local_class=3,
|
num_local_class=3,
|
||||||
@@ -114,11 +121,11 @@ def divide_trainset(
|
|||||||
max_data=20,
|
max_data=20,
|
||||||
mode="disjoint", # or "overlap"
|
mode="disjoint", # or "overlap"
|
||||||
seed=42
|
seed=42
|
||||||
)
|
)
|
||||||
|
|
||||||
print(dataset["users"]) # ['c_00001', ..., 'c_00005']
|
>>> print(dataset["users"]) # ['c_00001', ..., 'c_00005']
|
||||||
print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
|
>>> print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
|
||||||
print(dataset["user_data"]["c_00001"]["filename"][:3])
|
>>> print(dataset["user_data"]["c_00001"]["filename"][:3])
|
||||||
"""
|
"""
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
|
|||||||
"""
|
"""
|
||||||
Initialize the model for a specific learning task
|
Initialize the model for a specific learning task
|
||||||
Args:
|
Args:
|
||||||
:param model_name: Name of the model
|
model_name: Name of the model
|
||||||
:param num_classes: Number of classes
|
num_classes: Number of classes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
model: YOLO model instance
|
||||||
"""
|
"""
|
||||||
model = None
|
model = None
|
||||||
if model_name == "yolo_v11_n":
|
if model_name == "yolo_v11_n":
|
||||||
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017")
|
|||||||
- If cfg['val_txt'] exists, use it.
|
- If cfg['val_txt'] exists, use it.
|
||||||
- Else if <dataset_path>/val.txt exists, use it.
|
- Else if <dataset_path>/val.txt exists, use it.
|
||||||
- Else return None (testing will be skipped).
|
- Else return None (testing will be skipped).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: config dict
|
cfg: config dict
|
||||||
params: params dict for Dataset
|
params: params dict for Dataset
|
||||||
args: optional args object (for input_size)
|
args: optional args object (for input_size)
|
||||||
val_name: name of the validation set folder with no prefix (default: "val2017")
|
val_name: name of the validation set folder with no prefix (default: "val2017")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset or None
|
Dataset or None
|
||||||
"""
|
"""
|
||||||
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
|
|||||||
out_png = os.path.join(save_dir, savename)
|
out_png = os.path.join(save_dir, savename)
|
||||||
plt.savefig(out_png, dpi=150, bbox_inches="tight")
|
plt.savefig(out_png, dpi=150, bbox_inches="tight")
|
||||||
print(f"[plot] saved: {out_png}")
|
print(f"[plot] saved: {out_png}")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_result_dir(base_root: str = "results"):
|
||||||
|
"""
|
||||||
|
Prepare result directories for saving outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_root (str): base directory for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(res_dir, weights_dir) (str,str): Path to result directory and weights directory.
|
||||||
|
"""
|
||||||
|
os.makedirs(base_root, exist_ok=True)
|
||||||
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
res_dir = os.path.join(base_root, f"result_{timestamp}")
|
||||||
|
weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
|
||||||
|
os.makedirs(res_dir, exist_ok=True)
|
||||||
|
os.makedirs(weights_dir, exist_ok=True)
|
||||||
|
print(f"[INFO] Saving results to: {res_dir}")
|
||||||
|
return res_dir, weights_dir
|
||||||
|
|||||||
Reference in New Issue
Block a user