import os
from collections import Counter

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, datasets


class CustomImageDataset(Dataset):
    """Binary image dataset read from the `0/` and `1/` subfolders of root_dir."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Walk the subfolders 0 and 1 under root_dir; the folder name is the label
        for label in [0, 1]:
            folder_path = os.path.join(root_dir, str(label))
            if os.path.isdir(folder_path):
                for img_name in os.listdir(folder_path):
                    img_path = os.path.join(folder_path, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(label)

        # Debug helper: print a few image paths and labels
        # print("Loaded image paths and labels:")
        # for path, label in zip(self.image_paths[:5], self.labels[:5]):
        #     print(f"Path: {path}, Label: {label}")
        # print(f"Total samples: {len(self.image_paths)}\n")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


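# Minimal usage sketch for CustomImageDataset; the path is a placeholder and
# root_dir is expected to contain class subfolders named 0 and 1:
#
#     dataset = CustomImageDataset("/path/to/split", transform=transforms.ToTensor())
#     image, label = dataset[0]  # transformed image, label in {0, 1}

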
def get_test_data(test_image_path, batch_size, num_workers):
    data_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Alternative: test_dataset = datasets.ImageFolder(root=test_image_path, transform=data_transform)
    test_dataset = CustomImageDataset(root_dir=test_image_path, transform=data_transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader


def get_Onedata(train_image_path, val_image_path, batch_size, num_workers):
    """
    Load the full training and validation datasets.
    """
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Build the training and validation datasets
    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=data_transform["train"])
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=data_transform["val"])

    # Build the data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader


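# Usage sketch (paths are placeholders):
#
#     train_loader, val_loader = get_Onedata("/data/train", "/data/val",
#                                            batch_size=16, num_workers=4)

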
def get_data(train_image_path, val_image_path, batch_size, num_workers):
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "test": transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=data_transform["train"])
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=data_transform["val"])

    # Split the training set into three equal parts; truncate first so the
    # split lengths sum exactly to the dataset length, as random_split requires
    train_len = (len(train_dataset) // 3) * 3
    train_dataset_truncated = torch.utils.data.Subset(train_dataset, range(train_len))
    subset_len = train_len // 3
    dataset1, dataset2, dataset3 = random_split(train_dataset_truncated, [subset_len] * 3)

    loader1 = DataLoader(dataset1, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader2 = DataLoader(dataset2, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader3 = DataLoader(dataset3, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return loader1, loader2, loader3, subset_len, val_loader


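# Usage sketch for the three-way split (paths are placeholders):
#
#     loader1, loader2, loader3, subset_len, val_loader = get_data(
#         "/data/train", "/data/val", batch_size=16, num_workers=4)
#     # each loader draws from a disjoint third of the training set;
#     # subset_len is the number of samples per third

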
def get_Fourdata(train_path, val_path, batch_size, num_workers):
    """
    Load the training and validation sets:
    4 client datasets (df, f2f, fs, nt) plus 1 global validation set.

    Args:
        train_path (str): path to the training data
        val_path (str): path to the validation data
        batch_size (int): batch size
        num_workers (int): number of DataLoader worker processes

    Returns:
        tuple: the 4 client train loaders, the 4 client val loaders,
            and the global val loader
    """
    # Data preprocessing
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # The four client dataset subdirectories
    client_names = ['df', 'f2f', 'fs', 'nt']
    client_train_loaders = []
    client_val_loaders = []

    for client in client_names:
        client_train_path = os.path.join(train_path, client)
        client_val_path = os.path.join(val_path, client)

        # Client training data
        train_dataset = datasets.ImageFolder(root=client_train_path, transform=train_transform)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        # Client validation data
        val_dataset = datasets.ImageFolder(root=client_val_path, transform=val_transform)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        client_train_loaders.append(train_loader)
        client_val_loaders.append(val_loader)

    # Global validation set. Note: ImageFolder(root=val_path) would treat the
    # client folders df/f2f/fs/nt themselves as classes, so concatenate the
    # per-client datasets instead to keep the binary class labels.
    global_val_dataset = torch.utils.data.ConcatDataset(
        [loader.dataset for loader in client_val_loaders]
    )
    global_val_loader = DataLoader(global_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return (*client_train_loaders, *client_val_loaders, global_val_loader)


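# Usage sketch; the 9-tuple unpacks in client order df, f2f, fs, nt
# (paths are placeholders):
#
#     (df_train, f2f_train, fs_train, nt_train,
#      df_val, f2f_val, fs_val, nt_val,
#      global_val) = get_Fourdata("/data/train", "/data/val",
#                                 batch_size=16, num_workers=4)

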
def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8):
    """
    Partition the dataset across several clients, giving each client its own
    training data, with a shared global validation set.
    """
    # Transforms are defined inline (mirroring the other loaders in this
    # module) because no get_transform() helper exists here
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load the full datasets
    full_train_dataset = CustomImageDataset(root_dir=train_path, transform=train_transform)
    full_val_dataset = CustomImageDataset(root_dir=val_path, transform=val_transform)

    # Partition the training set across clients; truncate so the split lengths
    # sum exactly to the dataset length, as random_split requires
    usable_len = (len(full_train_dataset) // num_clients) * num_clients
    truncated_train = torch.utils.data.Subset(full_train_dataset, range(usable_len))
    client_train_datasets = random_split(truncated_train, [usable_len // num_clients] * num_clients)

    # One data loader per client
    client_train_loaders = [
        DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        for ds in client_train_datasets
    ]

    # Global validation set
    global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return client_train_loaders, global_val_loader


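# Usage sketch (paths are placeholders):
#
#     client_loaders, global_val_loader = get_federated_data(
#         "/data/train", "/data/val", num_clients=3, batch_size=16, num_workers=8)
#     # client_loaders is a list with one DataLoader per client

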
def main():
    # Configuration
    train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
    val_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/val"
    batch_size = 4
    num_workers = 2

    # Build the data loaders
    loader1, loader2, loader3, subset_len, val_loader = get_data(
        train_image_path, val_image_path, batch_size, num_workers)

    # Collect the labels to check their counts and types
    train_labels = []
    for loader in [loader1, loader2, loader3]:
        for _, labels in loader:
            train_labels.extend(labels.tolist())

    val_labels = []
    for _, labels in val_loader:
        val_labels.extend(labels.tolist())

    # Count the labels with Counter
    train_label_counts = Counter(train_labels)
    val_label_counts = Counter(val_labels)

    # Print the statistics
    print("Training Dataset - Label Counts:", train_label_counts)
    print("Validation Dataset - Label Counts:", val_label_counts)
    print("Label Types in Training:", set(train_labels))
    print("Label Types in Validation:", set(val_labels))


if __name__ == "__main__":
    main()