2025-10-02 16:26:27 +08:00
|
|
|
|
import os
|
|
|
|
|
|
import re
|
|
|
|
|
|
import random
|
2025-10-19 21:29:58 +08:00
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
from utils.dataset import Dataset
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import torch
|
2025-10-02 16:26:27 +08:00
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
from typing import Dict, List, Optional, Set, Any
|
|
|
|
|
|
|
|
|
|
|
|
from nets import nn
|
2025-10-19 21:29:58 +08:00
|
|
|
|
from nets import YOLO
|
2025-10-02 16:26:27 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _image_to_label_path(img_path: str) -> str:
    """
    Map an image path to the path of its YOLO label file.

    The last directory component named "images" is renamed to "labels" and
    the file extension is replaced by ".txt", e.g.
    ".../images/train2017/xxx.jpg" -> ".../labels/train2017/xxx.txt".
    Forward and backward slashes are both recognised as separators and are
    kept exactly as they appear in the input.

    Only the *last* "images" component is swapped, so a dataset that lives
    under a parent directory which is itself called "images" keeps that
    parent untouched. A path without an "images" component only gets its
    extension changed.

    Args:
        img_path: path of the image file

    Returns:
        path of the matching label file
    """
    # The trailing separator is matched with a lookahead so that adjacent
    # components ("/images/images/") may share a separator and both be found.
    hits = list(re.finditer(r"[/\\]images(?=[/\\])", img_path))
    if hits:
        last = hits[-1]
        # last.start() is the leading separator: keep it, replace only the word
        img_path = img_path[: last.start() + 1] + "labels" + img_path[last.end():]

    # swap extension to .txt
    root, _ = os.path.splitext(img_path)
    return root + ".txt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_yolo_label_file(label_path: str) -> Set[int]:
    """
    Return the set of class ids found in a YOLO .txt label file.

    Each non-blank line is expected to look like "cls cx cy w h"; only the
    first token is read. Blank lines and surrounding whitespace are ignored.
    A class id written as a float ("23.0") is truncated to an int. Lines
    whose first token is not a finite number are skipped individually, so
    one malformed line does not discard the rest of the file.

    Args:
        label_path: path to the label file

    Returns:
        set of class IDs (integers) found in the file; an empty set if the
        file is empty, missing, unreadable or not valid UTF-8
    """
    class_ids: Set[int] = set()
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                # split() without arguments already drops leading/trailing
                # whitespace and yields [] for a blank line
                parts = line.split()
                if not parts:
                    continue
                token = parts[0]
                try:
                    cls = int(token)
                except ValueError:
                    # handle weird case like '23.0'
                    try:
                        cls = int(float(token))
                    except (ValueError, OverflowError):
                        # ValueError: not a number / 'nan';
                        # OverflowError: 'inf' -> skip malformed line
                        continue
                class_ids.add(cls)
    except (OSError, ValueError):
        # Missing/unreadable file (OSError) or undecodable content
        # (UnicodeDecodeError is a ValueError): treat as no labels.
        return set()
    return class_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-10-19 21:29:58 +08:00
|
|
|
|
def _read_list_file(txt_path: str):
    """
    Load a list file: one entry per non-blank line, surrounding whitespace removed.

    Entries are otherwise returned verbatim (a relative path stays relative).
    An empty/None path, or a path that does not exist, yields an empty list.
    """
    if txt_path and os.path.exists(txt_path):
        entries = []
        with open(txt_path, "r", encoding="utf-8") as handle:
            for raw in handle:
                entry = raw.strip()
                if entry:
                    entries.append(entry)
        return entries
    return []
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-10-02 16:26:27 +08:00
|
|
|
|
def divide_trainset(
    trainset_path: str,
    num_local_class: int,
    num_client: int,
    min_data: int,
    max_data: int,
    mode: str = "overlap",  # "overlap" or "disjoint"
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a federated split from a YOLO dataset list file.

    Every listed image is mapped to its label file; images whose label file
    is missing, unreadable or contains no class id are ignored. Each client
    then draws ``num_local_class`` classes (fewer if the dataset has fewer)
    and receives a random subset of the images that contain at least one of
    those classes. A client gets fewer images than requested when its
    candidate pool is smaller than the drawn sample count.

    Args:
        trainset_path: path to a .txt file containing one image path per line
            e.g. /COCO/images/train2017/1111.jpg
        num_local_class: how many distinct classes to sample for each client
        num_client: number of clients
        min_data: minimum number of images per client
        max_data: maximum number of images per client
        mode: "overlap"  -> images may be shared across clients
              "disjoint" -> each image is used by at most one client
        seed: optional random seed for reproducibility. When given, the
            module-level ``random`` generator is re-seeded (a side effect
            visible to other users of ``random``).

    Returns:
        trainset_divided = {
            "users": ["c_00001", ...],
            "user_data": {
                "c_00001": {"filename": [img_path, ...]},
                ...
            },
            "num_samples": [len(list_for_user1), len(list_for_user2), ...]
        }

    Raises:
        ValueError: if a numeric argument is out of range or ``mode`` is unknown
        OSError: if ``trainset_path`` cannot be opened

    Example:
        dataset = divide_trainset(
            trainset_path="/COCO/train2017.txt",
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",  # or "overlap"
            seed=42
        )

        print(dataset["users"])         # ['c_00001', ..., 'c_00005']
        print(dataset["num_samples"])   # e.g. [10, 12, 18, 9, 15]
        print(dataset["user_data"]["c_00001"]["filename"][:3])
    """
    if seed is not None:
        random.seed(seed)

    # ---- Basic validations (defensive programming) ----
    if num_client <= 0:
        raise ValueError("num_client must be > 0")
    if num_local_class <= 0:
        raise ValueError("num_local_class must be > 0")
    if min_data < 0 or max_data < 0:
        raise ValueError("min_data/max_data must be >= 0")
    if max_data < min_data:
        raise ValueError("max_data must be >= min_data")
    if mode not in {"overlap", "disjoint"}:
        raise ValueError('mode must be "overlap" or "disjoint"')

    # ---- 1) Read image list ----
    with open(trainset_path, "r", encoding="utf-8") as f:
        all_images_raw = [ln.strip() for ln in f if ln.strip()]

    # Normalize the path text (no join with cwd) and drop duplicates while
    # keeping first-seen order; dict keys preserve insertion order.
    all_images: List[str] = list(dict.fromkeys(os.path.normpath(p) for p in all_images_raw))

    # ---- 2) Build mappings from labels ----
    class_to_images: Dict[int, Set[str]] = defaultdict(set)
    image_to_classes: Dict[str, Set[int]] = {}

    for img in all_images:
        # An empty set covers both "label file missing" and "no objects":
        # either way the image has no class bucket and is skipped.
        classes = _parse_yolo_label_file(_image_to_label_path(img))
        if not classes:
            continue
        image_to_classes[img] = classes
        for c in classes:
            class_to_images[c].add(img)

    if not class_to_images:
        # No usable images found
        user_ids = [f"c_{i + 1:05d}" for i in range(num_client)]
        return {
            "users": user_ids,
            "user_data": {uid: {"filename": []} for uid in user_ids},
            "num_samples": [0] * num_client,
        }

    all_classes: List[int] = sorted(class_to_images)
    # Available pool for disjoint mode (only images with labels)
    available_images: Set[str] = set(image_to_classes)
    # At least 1: num_local_class > 0 was validated and all_classes is non-empty.
    k = min(num_local_class, len(all_classes))

    # ---- 3) Allocate to clients ----
    result: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}

    for cid in range(num_client):
        user_id = f"c_{cid + 1:05d}"
        result["users"].append(user_id)

        # Pick the classes for this client (sample without replacement from global class set)
        chosen_classes = random.sample(all_classes, k)

        # Decide how many samples for this client
        need = min_data if min_data == max_data else random.randint(min_data, max_data)

        # Build the candidate pool for this client
        pool_set: Set[str] = set()
        for c in chosen_classes:
            pool_set.update(class_to_images[c])
        if mode == "disjoint":
            # restrict to images no earlier client has taken
            pool_set &= available_images

        # Sort before sampling: iteration order of a set of str changes from
        # one interpreter run to the next (string hash randomization), which
        # would make the split irreproducible even with a fixed seed.
        pool_list = sorted(pool_set)
        if len(pool_list) <= need:
            chosen_imgs = pool_list  # take all (can be fewer than need)
        else:
            chosen_imgs = random.sample(pool_list, need)

        # Record for the user
        result["user_data"][user_id] = {"filename": chosen_imgs}
        result["num_samples"].append(len(chosen_imgs))

        # If disjoint, remove selected images from availability everywhere
        if mode == "disjoint":
            available_images.difference_update(chosen_imgs)
            for img in chosen_imgs:
                # remove from every class bucket this image belongs to
                for c in image_to_classes[img]:
                    class_to_images[c].discard(img)
            # Emptied classes stay in all_classes; a client that draws one
            # simply gets no images from it.

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-10-19 21:29:58 +08:00
|
|
|
|
def init_model(model_name, num_classes) -> YOLO:
    """
    Build the network that corresponds to ``model_name``.

    Args:
        model_name: one of "yolo_v11_n", "yolo_v11_s", "yolo_v11_m",
            "yolo_v11_l" or "yolo_v11_x"
        num_classes: forwarded to the factory as ``num_classes``

    Returns:
        the object returned by the matching ``nn.yolo_v11_*`` factory

    Raises:
        ValueError: if ``model_name`` is not one of the supported names
    """
    supported = (
        "yolo_v11_n",
        "yolo_v11_s",
        "yolo_v11_m",
        "yolo_v11_l",
        "yolo_v11_x",
    )
    for name in supported:
        if model_name == name:
            # every supported name is also the name of its factory in ``nn``
            factory = getattr(nn, name)
            return factory(num_classes=num_classes)

    raise ValueError("Model {} is not supported.".format(model_name))
|
2025-10-19 21:29:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-10-23 13:06:38 +08:00
|
|
|
|
def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017") -> Optional[Dataset]:
    """
    Try to build a validation Dataset.

    The list file is looked up as follows:
    - If cfg['val_txt'] is set, use it.
    - Else if <dataset_path>/<val_name>.txt exists, use it.
    - Else return None (testing will be skipped).

    Only the base name of every listed entry is used; when cfg['dataset_path']
    is set the image path is rebuilt as
    "<dataset_path>/images/<val_name>/<basename>". Without a dataset path the
    entries of the list file are used exactly as written.

    Args:
        cfg: config dict (keys read: "val_txt", "dataset_path")
        params: params dict for Dataset
        args: optional args object (for input_size; 640 is used when absent)
        val_name: name of the validation set folder with no prefix (default: "val2017")

    Returns:
        Dataset, or None (after a warning) when no list file is found or the
        list file contains no entries
    """
    import warnings

    input_size = args.input_size if args and hasattr(args, "input_size") else 640

    # Resolved up front: the dataset root is needed to build the image paths
    # below even when the list file comes from cfg['val_txt'].
    ds_root = cfg.get("dataset_path", "")
    val_txt = cfg.get("val_txt", "")
    if not val_txt and ds_root:
        val_txt = os.path.join(ds_root, f"{val_name}.txt")

    # _read_list_file returns [] for an empty or non-existent path and skips
    # blank lines, so "nothing configured" and "file missing" both end up in
    # the warning branch instead of raising.
    entries = _read_list_file(val_txt)
    if ds_root:
        filenames = [f"{ds_root}/images/{val_name}/" + os.path.basename(entry) for entry in entries]
    else:
        filenames = entries

    if not filenames:
        warnings.warn("No validation dataset found.")
        return None

    return Dataset(
        filenames=filenames,
        input_size=input_size,
        params=params,
        # NOTE(review): augment=True on a validation set looks unintended;
        # kept as in the original -- confirm against Dataset before changing.
        augment=True,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seed_everything(seed: int):
    """
    Seed NumPy's global RNG, PyTorch and Python's ``random`` with one value.

    Args:
        seed: value passed to ``np.random.seed``, ``torch.manual_seed`` and
            ``random.seed``, in that order
    """
    for apply_seed in (np.random.seed, torch.manual_seed, random.seed):
        apply_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
    """
    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.

    Every series gets its own x axis (1..len(series)), so series of different
    lengths -- e.g. a train loss recorded while validation was skipped and
    the mAP history stayed empty -- can be drawn together. Missing or empty
    series are left out. The figure is closed after saving so repeated calls
    do not accumulate open figures.

    Args:
        save_dir: directory to save the plot (created if it does not exist)
        hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss"
        savename: output filename
    """
    os.makedirs(save_dir, exist_ok=True)

    # (history key, legend label), drawn in this order
    series = (
        ("mAP", "mAP50-95"),
        ("mAP50", "mAP50"),
        ("precision", "precision"),
        ("recall", "recall"),
        ("train_loss", "train_loss (sum of components)"),
    )

    fig, ax = plt.subplots()
    try:
        drew_any = False
        for key, label in series:
            values = hist.get(key)
            if values is None or len(values) == 0:
                continue
            ax.plot(np.arange(1, len(values) + 1), values, label=label)
            drew_any = True

        ax.set_xlabel("Global Round")
        ax.set_ylabel("Metric")
        ax.set_title("Federated YOLO - Server Metrics")
        if drew_any:
            # a legend without any labelled curve only triggers a warning
            ax.legend()

        out_png = os.path.join(save_dir, savename)
        fig.savefig(out_png, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"[plot] saved: {out_png}")
|