import os
import re
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any

from nets import nn


def _image_to_label_path(img_path: str) -> str:
    """
    Convert an image path like ".../images/train2017/xxx.jpg" to the
    corresponding label path ".../labels/train2017/xxx.txt".
    Works for POSIX/Windows separators.
    """
    # swap "/images/" (or "\images\") to "/labels/"
    label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
    # swap extension to .txt
    root, _ = os.path.splitext(label_path)
    return root + ".txt"


def _parse_yolo_label_file(label_path: str) -> Set[int]:
    """
    Return a set of class_ids found in a YOLO .txt label file.
    Empty file -> empty set. Missing file -> empty set.
    Robust to blank lines / trailing spaces.
    """
    class_ids: Set[int] = set()
    if not os.path.exists(label_path):
        return class_ids
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # YOLO format: cls cx cy w h
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls = int(parts[0])
                except ValueError:
                    # handle weird case like '23.0'
                    try:
                        cls = int(float(parts[0]))
                    except ValueError:
                        # skip malformed line
                        continue
                class_ids.add(cls)
    except Exception:
        # If the file can't be read for some reason, treat as no labels
        return set()
    return class_ids


def divide_trainset(
    trainset_path: str,
    num_local_class: int,
    num_client: int,
    min_data: int,
    max_data: int,
    mode: str = "overlap",  # "overlap" or "disjoint"
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a federated split from a YOLO dataset list file.

    Args:
        trainset_path: path to a .txt file containing one image path per line,
            e.g. /COCO/images/train2017/1111.jpg
        num_local_class: how many distinct classes to sample for each client
        num_client: number of clients
        min_data: minimum number of images per client
        max_data: maximum number of images per client
        mode: "overlap"  -> images may be shared across clients
              "disjoint" -> each image is used by at most one client
        seed: optional random seed for reproducibility

    Returns:
        trainset_divided = {
            "users": ["c_00001", ...],
            "user_data": {
                "c_00001": {"filename": [img_path, ...]},
                ...
            },
            "num_samples": [len(list_for_user1), len(list_for_user2), ...]
        }

    Example:
        dataset = divide_trainset(
            trainset_path="/COCO/train2017.txt",
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",  # or "overlap"
            seed=42,
        )
        print(dataset["users"])        # ['c_00001', ..., 'c_00005']
        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
        print(dataset["user_data"]["c_00001"]["filename"][:3])
    """
    if seed is not None:
        random.seed(seed)

    # ---- Basic validations (defensive programming) ----
    if num_client <= 0:
        raise ValueError("num_client must be > 0")
    if num_local_class <= 0:
        raise ValueError("num_local_class must be > 0")
    if min_data < 0 or max_data < 0:
        raise ValueError("min_data/max_data must be >= 0")
    if max_data < min_data:
        raise ValueError("max_data must be >= min_data")
    if mode not in {"overlap", "disjoint"}:
        raise ValueError('mode must be "overlap" or "disjoint"')

    # ---- 1) Read image list ----
    with open(trainset_path, "r", encoding="utf-8") as f:
        all_images_raw = [ln.strip() for ln in f if ln.strip()]

    # Normalize and deduplicate image paths (safe)
    all_images: List[str] = []
    seen = set()
    for p in all_images_raw:
        # keep exact string (don't join with cwd), just normalize slashes
        norm = os.path.normpath(p)
        if norm not in seen:
            seen.add(norm)
            all_images.append(norm)

    # ---- 2) Build mappings from labels ----
    class_to_images: Dict[int, Set[str]] = defaultdict(set)
    image_to_classes: Dict[str, Set[int]] = {}
    missing_label_files = 0
    empty_label_files = 0
    parsed_images = 0

    for img in all_images:
        lbl = _image_to_label_path(img)
        if not os.path.exists(lbl):
            # Missing labels: skip image (no class info)
            missing_label_files += 1
            continue
        classes = _parse_yolo_label_file(lbl)
        if not classes:
            # No objects in this image -> skip (no class bucket)
            empty_label_files += 1
            continue
        image_to_classes[img] = classes
        for c in classes:
            class_to_images[c].add(img)
        parsed_images += 1

    if not class_to_images:
        # No usable images found
        return {
            "users": [f"c_{i + 1:05d}" for i in range(num_client)],
            "user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)},
            "num_samples": [0 for _ in range(num_client)],
        }

    all_classes: List[int] = sorted(class_to_images.keys())
    # Available pool for disjoint mode (only images with labels)
    available_images: Set[str] = set(image_to_classes.keys())

    # ---- 3) Allocate to clients ----
    result = {"users": [], "user_data": {}, "num_samples": []}

    for cid in range(num_client):
        user_id = f"c_{cid + 1:05d}"
        result["users"].append(user_id)

        # Pick the classes for this client (sample without replacement from global class set)
        k = min(num_local_class, len(all_classes))
        chosen_classes = random.sample(all_classes, k) if k > 0 else []

        # Decide how many samples for this client
        need = min_data if min_data == max_data else random.randint(min_data, max_data)

        # Build the candidate pool for this client
        if mode == "overlap":
            pool_set: Set[str] = set()
            for c in chosen_classes:
                pool_set.update(class_to_images[c])
        else:
            # "disjoint": restrict to currently available images
            pool_set = set()
            for c in chosen_classes:
                # intersect with available images
                pool_set.update(class_to_images[c] & available_images)

        # Deduplicate and sample
        pool_list = list(pool_set)
        if len(pool_list) <= need:
            chosen_imgs = pool_list[:]  # take all (can be fewer than need)
        else:
            chosen_imgs = random.sample(pool_list, need)

        # Record for the user
        result["user_data"][user_id] = {"filename": chosen_imgs}
        result["num_samples"].append(len(chosen_imgs))

        # If disjoint, remove selected images from availability everywhere
        if mode == "disjoint" and chosen_imgs:
            for img in chosen_imgs:
                if img in available_images:
                    available_images.remove(img)
                # remove from every class bucket this image belongs to
                for c in image_to_classes.get(img, []):
                    if img in class_to_images[c]:
                        class_to_images[c].remove(img)
            # Optional: prune empty classes from all_classes to speed up later loops
            # (keep list stable; just skip empties naturally)

    # (Optional) You can print some quick diagnostics if helpful:
    # print(f"[INFO] Parsed images with labels: {parsed_images}")
    # print(f"[INFO] Missing label files: {missing_label_files}")
    # print(f"[INFO] Empty label files: {empty_label_files}")

    return result


def init_model(model_name, num_classes):
    """
    Initialize the model for a specific learning task.

    Args:
        model_name: Name of the model
        num_classes: Number of classes
    """
    model = None
    if model_name == "yolo_v11_n":
        model = nn.yolo_v11_n(num_classes=num_classes)
    elif model_name == "yolo_v11_s":
        model = nn.yolo_v11_s(num_classes=num_classes)
    elif model_name == "yolo_v11_m":
        model = nn.yolo_v11_m(num_classes=num_classes)
    elif model_name == "yolo_v11_l":
        model = nn.yolo_v11_l(num_classes=num_classes)
    elif model_name == "yolo_v11_x":
        model = nn.yolo_v11_x(num_classes=num_classes)
    else:
        raise ValueError("Model {} is not supported.".format(model_name))
    return model
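

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, under assumptions): the list-file
# path and the class count below are placeholders, not values from this repo.
# Point trainset_path at your own YOLO list file and make sure `nets.nn`
# provides the yolo_v11_* constructors that init_model expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    split = divide_trainset(
        trainset_path="/COCO/train2017.txt",  # placeholder list file
        num_local_class=3,
        num_client=5,
        min_data=10,
        max_data=20,
        mode="disjoint",  # or "overlap"
        seed=42,
    )
    print(split["users"])        # client ids, e.g. ['c_00001', ..., 'c_00005']
    print(split["num_samples"])  # number of images assigned to each client

    model = init_model("yolo_v11_n", num_classes=80)  # 80 = assumed COCO class count
    print(type(model))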