import os
import random
import re
import time
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set

import matplotlib.pyplot as plt
import numpy as np
import torch

from nets import YOLO, nn
from utils.dataset import Dataset


def _image_to_label_path(img_path: str) -> str:
    """
    Convert an image path like ".../images/train2017/xxx.jpg"
    to the corresponding label path ".../labels/train2017/xxx.txt".
    Works with both POSIX and Windows separators.
    """
    # Swap "/images/" (or "\images\") for "/labels/", keeping the separator style.
    label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
    # Swap the extension for .txt.
    root, _ = os.path.splitext(label_path)
    return root + ".txt"
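# Illustrative example of the mapping above (paths are hypothetical):
#   _image_to_label_path("/COCO/images/train2017/000000001111.jpg")
#   -> "/COCO/labels/train2017/000000001111.txt"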

def _parse_yolo_label_file(label_path: str) -> Set[int]:
    """
    Return the set of class IDs found in a YOLO .txt label file.
    An empty or missing file yields an empty set; blank lines and
    trailing spaces are tolerated.

    Args:
        label_path: path to the label file

    Returns:
        set of class IDs (integers) found in the file
    """
    class_ids: Set[int] = set()
    if not os.path.exists(label_path):
        return class_ids
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # YOLO format: cls cx cy w h
                parts = line.split()
                try:
                    cls = int(parts[0])
                except ValueError:
                    # Tolerate float-like class IDs such as "23.0".
                    try:
                        cls = int(float(parts[0]))
                    except ValueError:
                        # Skip malformed lines.
                        continue
                class_ids.add(cls)
    except (OSError, UnicodeDecodeError):
        # If the file cannot be read, treat it as having no labels.
        return set()
    return class_ids
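# Illustrative label file (YOLO format: "cls cx cy w h", one object per line)
# and the resulting parse, for a hypothetical "0001.txt":
#   23 0.770 0.490 0.335 0.462
#   45 0.339 0.193 0.068 0.404
#   45 0.546 0.622 0.118 0.258
# _parse_yolo_label_file(".../labels/train2017/0001.txt") -> {23, 45}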

def _read_list_file(txt_path: str):
    """Read one path per line; keep each entry as-is (absolute or relative)."""
    if not txt_path or not os.path.exists(txt_path):
        return []
    with open(txt_path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]

def divide_trainset(
    trainset_path: str,
    num_local_class: int,
    num_client: int,
    min_data: int,
    max_data: int,
    mode: str = "overlap",  # "overlap" or "disjoint"
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a federated split from a YOLO dataset list file.

    Args:
        trainset_path: path to a .txt file containing one image path per line,
            e.g. /COCO/images/train2017/1111.jpg
        num_local_class: how many distinct classes to sample for each client
        num_client: number of clients
        min_data: minimum number of images per client
        max_data: maximum number of images per client
        mode: "overlap"  -> images may be shared across clients
              "disjoint" -> each image is used by at most one client
        seed: optional random seed for reproducibility

    Returns:
        A dict of the form::

            {
                "users": ["c_00001", ...],
                "user_data": {
                    "c_00001": {"filename": [img_path, ...]},
                    ...
                },
                "num_samples": [len(list_for_user1), len(list_for_user2), ...],
            }

    Example::

        dataset = divide_trainset(
            trainset_path="/COCO/train2017.txt",
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",  # or "overlap"
            seed=42,
        )
        print(dataset["users"])        # ['c_00001', ..., 'c_00005']
        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
        print(dataset["user_data"]["c_00001"]["filename"][:3])
    """
    if seed is not None:
        random.seed(seed)

    # ---- Basic validations (defensive programming) ----
    if num_client <= 0:
        raise ValueError("num_client must be > 0")
    if num_local_class <= 0:
        raise ValueError("num_local_class must be > 0")
    if min_data < 0 or max_data < 0:
        raise ValueError("min_data/max_data must be >= 0")
    if max_data < min_data:
        raise ValueError("max_data must be >= min_data")
    if mode not in {"overlap", "disjoint"}:
        raise ValueError('mode must be "overlap" or "disjoint"')

    # ---- 1) Read image list ----
    with open(trainset_path, "r", encoding="utf-8") as f:
        all_images_raw = [ln.strip() for ln in f if ln.strip()]

    # Normalize and deduplicate image paths.
    all_images: List[str] = []
    seen = set()
    for p in all_images_raw:
        # Keep the exact string (do not join with cwd); just normalize slashes.
        norm = os.path.normpath(p)
        if norm not in seen:
            seen.add(norm)
            all_images.append(norm)

    # ---- 2) Build mappings from labels ----
    class_to_images: Dict[int, Set[str]] = defaultdict(set)
    image_to_classes: Dict[str, Set[int]] = {}

    missing_label_files = 0
    empty_label_files = 0
    parsed_images = 0

    for img in all_images:
        lbl = _image_to_label_path(img)
        if not os.path.exists(lbl):
            # Missing label file: skip the image (no class info).
            missing_label_files += 1
            continue

        classes = _parse_yolo_label_file(lbl)
        if not classes:
            # No objects in this image -> skip (no class bucket).
            empty_label_files += 1
            continue

        image_to_classes[img] = classes
        for c in classes:
            class_to_images[c].add(img)
        parsed_images += 1

    if not class_to_images:
        # No usable images found: return an empty split.
        return {
            "users": [f"c_{i + 1:05d}" for i in range(num_client)],
            "user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)},
            "num_samples": [0 for _ in range(num_client)],
        }

    all_classes: List[int] = sorted(class_to_images.keys())
    # Available pool for disjoint mode (only images with labels).
    available_images: Set[str] = set(image_to_classes.keys())

    # ---- 3) Allocate to clients ----
    result = {"users": [], "user_data": {}, "num_samples": []}

    for cid in range(num_client):
        user_id = f"c_{cid + 1:05d}"
        result["users"].append(user_id)

        # Pick this client's classes (sample without replacement from the global class set).
        k = min(num_local_class, len(all_classes))
        chosen_classes = random.sample(all_classes, k) if k > 0 else []

        # Decide how many samples this client gets.
        need = min_data if min_data == max_data else random.randint(min_data, max_data)

        # Build the candidate pool for this client.
        if mode == "overlap":
            pool_set: Set[str] = set()
            for c in chosen_classes:
                pool_set.update(class_to_images[c])
        else:  # "disjoint": restrict to currently available images
            pool_set = set()
            for c in chosen_classes:
                pool_set.update(class_to_images[c] & available_images)

        # Sample from the (already deduplicated) pool.
        pool_list = list(pool_set)
        if len(pool_list) <= need:
            chosen_imgs = pool_list[:]  # take all (may be fewer than `need`)
        else:
            chosen_imgs = random.sample(pool_list, need)

        # Record the selection for this user.
        result["user_data"][user_id] = {"filename": chosen_imgs}
        result["num_samples"].append(len(chosen_imgs))

        # In disjoint mode, remove the selected images from availability everywhere.
        if mode == "disjoint" and chosen_imgs:
            for img in chosen_imgs:
                available_images.discard(img)
                # Remove the image from every class bucket it belongs to.
                for c in image_to_classes.get(img, ()):
                    class_to_images[c].discard(img)
            # Note: classes whose buckets become empty are intentionally kept in
            # all_classes; clients that draw them simply get a smaller pool.

    # Optional quick diagnostics:
    # print(f"[INFO] Parsed images with labels: {parsed_images}")
    # print(f"[INFO] Missing label files: {missing_label_files}")
    # print(f"[INFO] Empty label files: {empty_label_files}")

    return result
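# Sanity-check sketch for "disjoint" mode (illustrative; the list-file path is
# hypothetical): no image path should appear under more than one client.
#
#   split = divide_trainset("/COCO/train2017.txt", num_local_class=3,
#                           num_client=5, min_data=10, max_data=20,
#                           mode="disjoint", seed=42)
#   seen = set()
#   for uid in split["users"]:
#       files = split["user_data"][uid]["filename"]
#       assert seen.isdisjoint(files), f"{uid} reuses an image"
#       seen.update(files)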

def init_model(model_name: str, num_classes: int) -> YOLO:
    """
    Initialize the model for a specific learning task.

    Args:
        model_name: name of the model variant
        num_classes: number of classes

    Returns:
        model: YOLO model instance
    """
    if model_name == "yolo_v11_n":
        model = nn.yolo_v11_n(num_classes=num_classes)
    elif model_name == "yolo_v11_s":
        model = nn.yolo_v11_s(num_classes=num_classes)
    elif model_name == "yolo_v11_m":
        model = nn.yolo_v11_m(num_classes=num_classes)
    elif model_name == "yolo_v11_l":
        model = nn.yolo_v11_l(num_classes=num_classes)
    elif model_name == "yolo_v11_x":
        model = nn.yolo_v11_x(num_classes=num_classes)
    else:
        raise ValueError(f"Model {model_name} is not supported.")

    return model
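# Usage sketch (80 classes matches COCO; any other count works too):
#   model = init_model("yolo_v11_n", num_classes=80)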

def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017") -> Optional[Dataset]:
    """
    Try to build a validation Dataset.
    - If cfg['val_txt'] exists, use it.
    - Else, if <dataset_path>/<val_name>.txt exists, use it.
    - Else, return None (testing will be skipped).

    Args:
        cfg: config dict
        params: params dict for Dataset
        args: optional args object (for input_size)
        val_name: name of the validation split, without any prefix (default: "val2017")

    Returns:
        Dataset or None
    """
    input_size = args.input_size if args and hasattr(args, "input_size") else 640
    ds_root = cfg.get("dataset_path", "")
    val_txt = cfg.get("val_txt", "")
    if not val_txt:
        guess = os.path.join(ds_root, f"{val_name}.txt") if ds_root else ""
        val_txt = guess if os.path.exists(guess) else ""
    if not val_txt:
        warnings.warn("No validation dataset found.")
        return None

    filenames = []
    with open(val_txt, "r", encoding="utf-8") as f:
        for filename in f:
            filename = os.path.basename(filename.rstrip())
            filenames.append(f"{ds_root}/images/{val_name}/" + filename)
    if not filenames:
        warnings.warn("No validation dataset found.")
        return None

    return Dataset(
        filenames=filenames,
        input_size=input_size,
        params=params,
        augment=False,  # evaluation should run on unaugmented images
    )
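# Call-site sketch (cfg keys as read above; the path is hypothetical):
#   cfg = {"dataset_path": "/COCO"}  # resolves to /COCO/val2017.txt if present
#   valset = build_valset_if_available(cfg, params)
#   if valset is None:
#       pass  # testing is skipped for this run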

def seed_everything(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
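# Note: full run-to-run determinism on GPU additionally requires
# torch.backends.cudnn.deterministic = True and
# torch.backends.cudnn.benchmark = False; both are left to the caller here
# because they can slow down training.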

def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
    """
    Plot mAP50-95, mAP50, precision, recall, and (optionally) the summed
    train loss per global round.

    Args:
        save_dir: directory to save the plot
        hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss"
        savename: output filename
    """
    os.makedirs(save_dir, exist_ok=True)

    plt.figure()
    # Each series gets its own x axis, so series of differing lengths
    # (e.g. evaluation run less often than training) cannot break the plot.
    for key, label in [
        ("mAP", "mAP50-95"),
        ("mAP50", "mAP50"),
        ("precision", "precision"),
        ("recall", "recall"),
        ("train_loss", "train_loss (sum of components)"),
    ]:
        values = hist.get(key, [])
        if values:
            plt.plot(np.arange(1, len(values) + 1), values, label=label)
    plt.xlabel("Global Round")
    plt.ylabel("Metric")
    plt.title("Federated YOLO - Server Metrics")
    plt.legend()
    out_png = os.path.join(save_dir, savename)
    plt.savefig(out_png, dpi=150, bbox_inches="tight")
    plt.close()  # free the figure so repeated calls don't accumulate open figures
    print(f"[plot] saved: {out_png}")
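# Usage sketch (values are illustrative): one entry per global round.
#   hist = {"mAP": [0.21, 0.27], "mAP50": [0.35, 0.42],
#           "precision": [0.48, 0.52], "recall": [0.40, 0.45],
#           "train_loss": [5.1, 4.3]}
#   plot_curves("results/run1", hist)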

def prepare_result_dir(base_root: str = "results"):
    """
    Prepare result directories for saving outputs.

    Args:
        base_root (str): base directory for results.

    Returns:
        (res_dir, weights_dir) (str, str): paths to the result directory and
        the weights directory inside it.
    """
    os.makedirs(base_root, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    res_dir = os.path.join(base_root, f"result_{timestamp}")
    weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
    os.makedirs(res_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)
    print(f"[INFO] Saving results to: {res_dir}")
    return res_dir, weights_dir
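# Resulting layout (timestamp is illustrative):
#   results/
#     result_20250101_120000/
#       weight_20250101_120000/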