import os
from collections import Counter

from PIL import Image
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split
from torchvision import datasets, transforms


class CustomImageDataset(Dataset):
    """Binary image dataset read from root_dir/0 and root_dir/1.

    Expected directory layout:
        root_dir/
        ├── 0/   (class-0 images)
        └── 1/   (class-1 images)
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Walk the subfolders 0 and 1 under root_dir; the folder name is the label.
        for label in [0, 1]:
            folder_path = os.path.join(root_dir, str(label))
            if os.path.isdir(folder_path):
                for img_name in os.listdir(folder_path):
                    # Safeguard: skip non-image files (e.g. .DS_Store) so that
                    # Image.open() does not fail later in __getitem__.
                    if not img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        continue
                    self.image_paths.append(os.path.join(folder_path, img_name))
                    self.labels.append(label)

        # Debug output for the loaded paths and labels:
        # print("Loaded image paths and labels:")
        # for path, label in zip(self.image_paths[:5], self.labels[:5]):
        #     print(f"Path: {path}, Label: {label}")
        # print(f"Total samples: {len(self.image_paths)}\n")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label


def get_test_data(test_image_path, batch_size, nw):
    """Load the test set."""
    data_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # test_dataset = datasets.ImageFolder(root=test_image_path, transform=data_transform)
    test_dataset = CustomImageDataset(root_dir=test_image_path, transform=data_transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=nw)
    return test_loader


def get_Onedata(train_image_path, val_image_path, batch_size, num_workers):
    """Load the full training set and the validation set."""
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        "val": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }

    # Build the training and validation datasets.
    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=data_transform["train"])
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=data_transform["val"])

    # Build the data loaders.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader


def get_data(train_image_path, val_image_path, batch_size, num_workers):
    """Load the validation set and split the training set into three equal subsets."""
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        "val": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }

    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=data_transform["train"])
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=data_transform["val"])

    # Truncate the training set to a multiple of 3, then split it into three equal parts.
    train_len = (len(train_dataset) // 3) * 3
    train_dataset_truncated = Subset(train_dataset, range(train_len))
    subset_len = train_len // 3
    dataset1, dataset2, dataset3 = random_split(train_dataset_truncated, [subset_len] * 3)

    loader1 = DataLoader(dataset1, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader2 = DataLoader(dataset2, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader3 = DataLoader(dataset3, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return loader1, loader2, loader3, subset_len, val_loader
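
# A minimal sketch of a reproducible variant of the 3-way split performed in
# get_data above (the helper name split_reproducibly is an assumption, not part
# of the original code): random_split accepts an optional torch.Generator, so
# seeding it keeps the subset assignment stable across runs.
def split_reproducibly(dataset, seed=0):
    # Truncate to a multiple of 3, exactly as get_data does.
    n = (len(dataset) // 3) * 3
    truncated = Subset(dataset, range(n))
    # A fixed seed makes the random split deterministic.
    gen = torch.Generator().manual_seed(seed)
    return random_split(truncated, [n // 3] * 3, generator=gen)
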
def get_Fourdata(train_path, val_path, batch_size, num_workers):
    """
    Load the training and validation sets for four clients:
    4 client train/val loader pairs (df, f2f, fs, nt) plus 1 global validation loader.

    Args:
        train_path (str): Root path of the training data.
        val_path (str): Root path of the validation data.
        batch_size (int): Batch size.
        num_workers (int): Number of DataLoader worker processes.

    Returns:
        tuple: 4 client train loaders, 4 client val loaders, and the global val loader.
    """
    # Preprocessing pipelines.
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The 4 client sub-dataset paths.
    client_names = ['df', 'f2f', 'fs', 'nt']
    client_train_loaders = []
    client_val_loaders = []
    client_val_datasets = []

    for client in client_names:
        client_train_path = os.path.join(train_path, client)
        client_val_path = os.path.join(val_path, client)

        # Client training data.
        train_dataset = datasets.ImageFolder(root=client_train_path, transform=train_transform)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        # Client validation data.
        val_dataset = datasets.ImageFolder(root=client_val_path, transform=val_transform)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        client_train_loaders.append(train_loader)
        client_val_loaders.append(val_loader)
        client_val_datasets.append(val_dataset)

    # Global validation set: concatenate the per-client validation sets.
    # (ImageFolder(root=val_path) would treat the client folders df/f2f/fs/nt
    # themselves as classes and yield 4 labels instead of the binary 0/1.)
    global_val_dataset = ConcatDataset(client_val_datasets)
    global_val_loader = DataLoader(global_val_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=num_workers)

    return (*client_train_loaders, *client_val_loaders, global_val_loader)


def main():
    # Parameters.
    train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
    val_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/val"
    batch_size = 4
    num_workers = 2

    # Build the data loaders.
    loader1, loader2, loader3, subset_len, val_loader = get_data(
        train_image_path, val_image_path, batch_size, num_workers)

    # Collect the labels seen by each training loader.
    train_labels = []
    for loader in [loader1, loader2, loader3]:
        for _, labels in loader:
            train_labels.extend(labels.tolist())

    val_labels = []
    for _, labels in val_loader:
        val_labels.extend(labels.tolist())

    # Count label occurrences with Counter.
    train_label_counts = Counter(train_labels)
    val_label_counts = Counter(val_labels)

    # Report the statistics.
    print("Training Dataset - Label Counts:", train_label_counts)
    print("Validation Dataset - Label Counts:", val_label_counts)
    print("Label Types in Training:", set(train_labels))
    print("Label Types in Validation:", set(val_labels))


if __name__ == "__main__":
    main()
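
# A minimal usage sketch for get_Fourdata (the "data/train" and "data/val"
# paths below are placeholders, not the actual dataset locations): the return
# tuple unpacks as 4 client train loaders, 4 client val loaders, then the
# global val loader, in the df/f2f/fs/nt order defined above.
#
# (df_train, f2f_train, fs_train, nt_train,
#  df_val, f2f_val, fs_val, nt_val,
#  global_val) = get_Fourdata("data/train", "data/val", batch_size=4, num_workers=2)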