# fed-yolo/utils/fed_util.py
import os
import re
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any

from nets import nn


def _image_to_label_path(img_path: str) -> str:
    """
    Convert an image path like ".../images/train2017/xxx.jpg"
    to the corresponding label path ".../labels/train2017/xxx.txt".
    Works for POSIX/Windows separators.
    """
    # swap "/images/" (or "\images\") to "/labels/"
    label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
    # swap extension to .txt
    root, _ = os.path.splitext(label_path)
    return root + ".txt"
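
# Illustrative only (not executed): the regex handles both separator styles
# noted in the docstring, e.g. for hypothetical paths
#   _image_to_label_path("/COCO/images/train2017/0001.jpg")  -> "/COCO/labels/train2017/0001.txt"
#   _image_to_label_path(r"D:\COCO\images\val2017\0001.png") -> r"D:\COCO\labels\val2017\0001.txt"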


def _parse_yolo_label_file(label_path: str) -> Set[int]:
    """
    Return a set of class_ids found in a YOLO .txt label file.
    Empty file -> empty set. Missing file -> empty set.
    Robust to blank lines / trailing spaces.

    Args:
        label_path: path to the label file

    Returns:
        set of class IDs (integers) found in the file
"""
class_ids: Set[int] = set()
if not os.path.exists(label_path):
return class_ids
try:
with open(label_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
# YOLO format: cls cx cy w h
parts = line.split()
if not parts:
continue
try:
cls = int(parts[0])
except ValueError:
# handle weird case like '23.0'
try:
cls = int(float(parts[0]))
except ValueError:
# skip malformed line
continue
class_ids.add(cls)
except Exception:
# If the file can't be read for some reason, treat as no labels
return set()
return class_ids
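
# Illustrative only (not executed): given a hypothetical label file containing
#   23 0.776 0.345 0.120 0.210
#   0 0.500 0.500 0.050 0.080
#   23 0.100 0.900 0.030 0.040
# _parse_yolo_label_file would return {0, 23}; a missing or empty file yields set().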


def divide_trainset(
    trainset_path: str,
    num_local_class: int,
    num_client: int,
    min_data: int,
    max_data: int,
    mode: str = "overlap",  # "overlap" or "disjoint"
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a federated split from a YOLO dataset list file.

    Args:
        trainset_path: path to a .txt file containing one image path per line,
            e.g. /COCO/images/train2017/1111.jpg
        num_local_class: how many distinct classes to sample for each client
        num_client: number of clients
        min_data: minimum number of images per client
        max_data: maximum number of images per client
        mode: "overlap" -> images may be shared across clients
              "disjoint" -> each image is used by at most one client
        seed: optional random seed for reproducibility

    Returns:
        trainset_divided = {
            "users": ["c_00001", ...],
            "user_data": {
                "c_00001": {"filename": [img_path, ...]},
                ...
            },
            "num_samples": [len(list_for_user1), len(list_for_user2), ...]
        }

    Example:
        dataset = divide_trainset(
            trainset_path="/COCO/train2017.txt",
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",  # or "overlap"
            seed=42,
        )
        print(dataset["users"])        # ['c_00001', ..., 'c_00005']
        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
        print(dataset["user_data"]["c_00001"]["filename"][:3])
    """
    if seed is not None:
        random.seed(seed)

    # ---- Basic validations (defensive programming) ----
    if num_client <= 0:
        raise ValueError("num_client must be > 0")
    if num_local_class <= 0:
        raise ValueError("num_local_class must be > 0")
    if min_data < 0 or max_data < 0:
        raise ValueError("min_data/max_data must be >= 0")
    if max_data < min_data:
        raise ValueError("max_data must be >= min_data")
    if mode not in {"overlap", "disjoint"}:
        raise ValueError('mode must be "overlap" or "disjoint"')

    # ---- 1) Read image list ----
    with open(trainset_path, "r", encoding="utf-8") as f:
        all_images_raw = [ln.strip() for ln in f if ln.strip()]

    # Normalize and deduplicate image paths (safe)
    all_images: List[str] = []
    seen = set()
    for p in all_images_raw:
        # keep the exact string (don't join it with the cwd), just normalize separators
        norm = os.path.normpath(p)
        if norm not in seen:
            seen.add(norm)
            all_images.append(norm)

    # ---- 2) Build mappings from labels ----
    class_to_images: Dict[int, Set[str]] = defaultdict(set)
    image_to_classes: Dict[str, Set[int]] = {}
    missing_label_files = 0
    empty_label_files = 0
    parsed_images = 0
    for img in all_images:
        lbl = _image_to_label_path(img)
        if not os.path.exists(lbl):
            # Missing labels: skip image (no class info)
            missing_label_files += 1
            continue
        classes = _parse_yolo_label_file(lbl)
        if not classes:
            # No objects in this image -> skip (no class bucket)
            empty_label_files += 1
            continue
        image_to_classes[img] = classes
        for c in classes:
            class_to_images[c].add(img)
        parsed_images += 1

    if not class_to_images:
        # No usable images found
        return {
            "users": [f"c_{i + 1:05d}" for i in range(num_client)],
            "user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)},
            "num_samples": [0 for _ in range(num_client)],
        }

    all_classes: List[int] = sorted(class_to_images.keys())
    # Available pool for disjoint mode (only images with labels)
    available_images: Set[str] = set(image_to_classes.keys())

    # ---- 3) Allocate to clients ----
    result = {"users": [], "user_data": {}, "num_samples": []}
    for cid in range(num_client):
        user_id = f"c_{cid + 1:05d}"
        result["users"].append(user_id)

        # Pick the classes for this client (sample without replacement from the global class set)
        k = min(num_local_class, len(all_classes))
        chosen_classes = random.sample(all_classes, k) if k > 0 else []

        # Decide how many samples this client gets
        need = min_data if min_data == max_data else random.randint(min_data, max_data)

        # Build the candidate pool for this client
        if mode == "overlap":
            pool_set: Set[str] = set()
            for c in chosen_classes:
                pool_set.update(class_to_images[c])
        else:  # "disjoint": restrict to currently available images
            pool_set = set()
            for c in chosen_classes:
                # intersect with the set of still-available images
                pool_set.update(class_to_images[c] & available_images)

        # Deduplicate and sample
        pool_list = list(pool_set)
        if len(pool_list) <= need:
            chosen_imgs = pool_list[:]  # take all (can be fewer than need)
        else:
            chosen_imgs = random.sample(pool_list, need)

        # Record for this user
        result["user_data"][user_id] = {"filename": chosen_imgs}
        result["num_samples"].append(len(chosen_imgs))

        # If disjoint, remove the selected images from availability everywhere
        if mode == "disjoint" and chosen_imgs:
            for img in chosen_imgs:
                if img in available_images:
                    available_images.remove(img)
                # remove from every class bucket this image belongs to
                for c in image_to_classes.get(img, []):
                    if img in class_to_images[c]:
                        class_to_images[c].remove(img)
            # Optional: prune empty classes from all_classes to speed up later loops
            # (keep the list stable; empty buckets simply yield no candidates)

    # (Optional) Quick diagnostics, if helpful:
    # print(f"[INFO] Parsed images with labels: {parsed_images}")
    # print(f"[INFO] Missing label files: {missing_label_files}")
    # print(f"[INFO] Empty label files: {empty_label_files}")
    return result
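
# A quick sanity check one might run on the returned split (a sketch, using the
# structure documented in the docstring above); only mode="disjoint" guarantees
# that no image path repeats across clients:
#
#   split = divide_trainset("/COCO/train2017.txt", 3, 5, 10, 20, mode="disjoint", seed=42)
#   paths = [p for u in split["users"] for p in split["user_data"][u]["filename"]]
#   assert len(paths) == len(set(paths))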


def init_model(model_name, num_classes):
    """
    Initialize the model for a specific learning task.

    Args:
        model_name: name of the model
        num_classes: number of classes
    """
    model = None
    if model_name == "yolo_v11_n":
        model = nn.yolo_v11_n(num_classes=num_classes)
    elif model_name == "yolo_v11_s":
        model = nn.yolo_v11_s(num_classes=num_classes)
    elif model_name == "yolo_v11_m":
        model = nn.yolo_v11_m(num_classes=num_classes)
    elif model_name == "yolo_v11_l":
        model = nn.yolo_v11_l(num_classes=num_classes)
    elif model_name == "yolo_v11_x":
        model = nn.yolo_v11_x(num_classes=num_classes)
    else:
        raise ValueError("Model {} is not supported.".format(model_name))
    return model
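

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline). The
    # dataset list path below is a placeholder taken from the docstring example;
    # the split step is skipped if it does not exist on this machine.
    demo_list = "/COCO/train2017.txt"
    if os.path.exists(demo_list):
        demo_split = divide_trainset(
            trainset_path=demo_list,
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="overlap",
            seed=42,
        )
        print("clients:", demo_split["users"])
        print("samples per client:", demo_split["num_samples"])

    # "yolo_v11_n" is one of the variants handled by init_model above;
    # 80 is the COCO class count.
    demo_model = init_model("yolo_v11_n", num_classes=80)
    print("model:", type(demo_model).__name__)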