import os
from collections import Counter

import torch
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split
from torchvision import datasets, transforms


class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        # Walk the class subfolders 0 and 1 under root_dir; every file found
        # in them is treated as an image
        for label in [0, 1]:
            folder_path = os.path.join(root_dir, str(label))
            if os.path.isdir(folder_path):
                for img_name in os.listdir(folder_path):
                    img_path = os.path.join(folder_path, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(label)
        
        # Debug: print a few image paths and labels
        # print("Loaded image paths and labels:")
        # for path, label in zip(self.image_paths[:5], self.labels[:5]):
        #     print(f"Path: {path}, Label: {label}")
        # print(f"Total samples: {len(self.image_paths)}\n")
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        return image, label
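
# Layout sketch for CustomImageDataset ("some_root" is a hypothetical path).
# Each binary class sits in a subfolder named after its integer label:
#
#   some_root/
#       0/  img_0001.png, img_0002.png, ...  -> label 0
#       1/  img_0001.png, img_0002.png, ...  -> label 1
#
# Usage sketch, assuming such a folder exists:
#
#   ds = CustomImageDataset(root_dir="some_root", transform=transforms.ToTensor())
#   image, label = ds[0]  # image: CxHxW float tensor, label: int (0 or 1)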


def get_test_data(test_image_path, batch_size, num_workers):
    data_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
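        # ImageNet channel means/stds, the standard choice for ImageNet-pretrained backbones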
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # datasets.ImageFolder(root=test_image_path, transform=data_transform) would
    # also work here; CustomImageDataset keeps the folder-name -> label mapping explicit.
    test_dataset = CustomImageDataset(root_dir=test_image_path, transform=data_transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader


def get_Onedata(train_image_path, val_image_path, batch_size, num_workers):
    """
    加载完整的训练数据集和验证数据集。
    """
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    # Build the training and validation datasets
    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=data_transform["train"])
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=data_transform["val"])
    
    # Build the data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return train_loader, val_loader


def get_data(train_image_path, val_image_path, batch_size, num_workers):
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "test": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=data_transform["train"])
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=data_transform["val"])
    
    # Split the training set into three equal parts; truncate to a multiple of 3
    # first so the lengths match (the trailing len(train_dataset) % 3 samples are dropped)
    train_len = (len(train_dataset) // 3) * 3
    train_dataset_truncated = Subset(train_dataset, range(train_len))
    subset_len = train_len // 3
    dataset1, dataset2, dataset3 = random_split(train_dataset_truncated, [subset_len] * 3)
    
    loader1 = DataLoader(dataset1, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader2 = DataLoader(dataset2, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader3 = DataLoader(dataset3, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return loader1, loader2, loader3, subset_len, val_loader
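
# Usage sketch for get_data (paths are hypothetical). loader1/2/3 draw from
# disjoint, randomly assigned thirds of the training set, e.g. to simulate
# three clients on the same source data:
#
#   l1, l2, l3, n, val_loader = get_data("data/train", "data/val",
#                                        batch_size=4, num_workers=2)
#   # len(l1.dataset) == len(l2.dataset) == len(l3.dataset) == n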


def get_Fourdata(train_path, val_path, batch_size, num_workers):
    """
    加载训练集和验证集。
    包括 4 个客户端验证集(df、f2f、fs、nt)和 1 个全局验证集。

    Args:
        train_path (str): 训练数据路径
        val_path (str): 验证数据路径
        batch_size (int): 批量大小
        num_workers (int): DataLoader 的工作线程数

    Returns:
        tuple: 包含 4 个客户端训练和验证加载器,以及全局验证加载器
    """
    # Data preprocessing
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # The four client dataset subdirectories
    client_names = ['df', 'f2f', 'fs', 'nt']
    client_train_loaders = []
    client_val_loaders = []
    
    for client in client_names:
        client_train_path = os.path.join(train_path, client)
        client_val_path = os.path.join(val_path, client)
        
        # Load the client's training data
        train_dataset = datasets.ImageFolder(root=client_train_path, transform=train_transform)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        
        # Load the client's validation data
        val_dataset = datasets.ImageFolder(root=client_val_path, transform=val_transform)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        
        client_train_loaders.append(train_loader)
        client_val_loaders.append(val_loader)
    
    # Global validation set: pool the four per-client validation sets.
    # (ImageFolder(root=val_path) would treat the client folders df/f2f/fs/nt
    # themselves as the classes, breaking the binary 0/1 labels.)
    global_val_dataset = ConcatDataset([loader.dataset for loader in client_val_loaders])
    global_val_loader = DataLoader(global_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return (*client_train_loaders, *client_val_loaders, global_val_loader)
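
# Usage sketch for get_Fourdata (paths are hypothetical). The returned 9-tuple
# unpacks as 4 client train loaders, 4 client val loaders, then the global
# validation loader, in the client order df, f2f, fs, nt:
#
#   (df_train, f2f_train, fs_train, nt_train,
#    df_val, f2f_val, fs_val, nt_val,
#    global_val) = get_Fourdata("data/train", "data/val",
#                               batch_size=4, num_workers=2)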


def main():
    # Parameters
    train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
    val_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/val"
    batch_size = 4
    num_workers = 2
    
    # Build the data loaders
    loader1, loader2, loader3, subset_len, val_loader = get_data(train_image_path, val_image_path, batch_size,
                                                                 num_workers)
    
    # Collect labels to check their counts and types
    train_labels = []
    for loader in [loader1, loader2, loader3]:
        for _, labels in loader:
            train_labels.extend(labels.tolist())
    
    val_labels = []
    for _, labels in val_loader:
        val_labels.extend(labels.tolist())
    
    # Count label occurrences with Counter
    train_label_counts = Counter(train_labels)
    val_label_counts = Counter(val_labels)
    
    # Print the statistics
    print("Training Dataset - Label Counts:", train_label_counts)
    print("Validation Dataset - Label Counts:", val_label_counts)
    print("Label Types in Training:", set(train_labels))
    print("Label Types in Validation:", set(val_labels))


if __name__ == "__main__":
    main()