Graduation-Project/federated_learning/utils/data_utils.py
2025-04-18 22:15:25 +08:00

218 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from collections import Counter
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
import os
from sklearn.model_selection import train_test_split
class CustomImageDataset(Dataset):
    """Two-class image dataset read from ``root_dir/0`` and ``root_dir/1``.

    Every regular file found directly inside the sub-folder named ``0`` is
    assigned label 0, and every regular file inside ``1`` is assigned
    label 1. A class folder that does not exist is skipped silently, so a
    wrong ``root_dir`` yields an empty dataset rather than an error.

    Args:
        root_dir: Directory that holds the ``0`` and ``1`` class folders.
        transform: Optional callable applied to each RGB PIL image before
            it is returned from ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        # Walk the class sub-folders "0" and "1" under root_dir.
        for label in (0, 1):
            folder_path = os.path.join(root_dir, str(label))
            if not os.path.isdir(folder_path):
                continue
            # os.listdir returns entries in arbitrary order; sorting keeps
            # sample indices stable across runs, which matters for any
            # index-based truncation or splitting done by callers.
            for img_name in sorted(os.listdir(folder_path)):
                img_path = os.path.join(folder_path, img_name)
                # A nested directory cannot be opened as an image and would
                # only fail later in __getitem__, so leave it out.
                if not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # convert() produces an independent, fully loaded image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
def get_test_data(test_image_path, batch_size, nw):
    """Build the DataLoader used for evaluation on the test images.

    Images are resized to 256x256, converted to tensors and normalized
    with fixed per-channel mean and std values.

    Args:
        test_image_path: Root directory handed to ``CustomImageDataset``.
        batch_size: Number of samples per batch.
        nw: Value forwarded to the DataLoader as ``num_workers``.

    Returns:
        A DataLoader over the test dataset that does not shuffle.
    """
    preprocessing_steps = [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    eval_dataset = CustomImageDataset(
        root_dir=test_image_path,
        transform=transforms.Compose(preprocessing_steps),
    )
    return DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=nw,
    )
def get_Onedata(train_image_path, val_image_path, batch_size, num_workers):
    """Load the complete training set and the complete validation set.

    Training images get a random horizontal flip on top of the shared
    resize / to-tensor / normalize steps; validation images do not.

    Args:
        train_image_path: Root directory of the training images.
        val_image_path: Root directory of the validation images.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Value forwarded to both DataLoaders as ``num_workers``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    augmenting_pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    plain_pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    def _make_loader(image_root, pipeline, shuffle):
        # Dataset first, then its loader, using the caller's batch settings.
        dataset = CustomImageDataset(root_dir=image_root, transform=pipeline)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = _make_loader(train_image_path, augmenting_pipeline, True)
    val_loader = _make_loader(val_image_path, plain_pipeline, False)
    return train_loader, val_loader
def get_data(train_image_path, val_image_path, batch_size, num_workers):
    """Split the training set into three equal random shards plus a validation loader.

    The training dataset is first cut down to the largest multiple of three
    (the trailing one or two samples are dropped) and then divided at random
    into three shards of identical size. The validation dataset is used
    whole.

    Args:
        train_image_path: Root directory of the training images.
        val_image_path: Root directory of the validation images.
        batch_size: Number of samples per batch for every loader.
        num_workers: Value forwarded to every DataLoader as ``num_workers``.

    Returns:
        ``(loader1, loader2, loader3, subset_len, val_loader)`` where
        ``subset_len`` is the number of samples in each training shard.

    Raises:
        ValueError: If fewer than three training samples are found, since
            every shard would then be empty.
    """
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_dataset = CustomImageDataset(root_dir=train_image_path, transform=train_transform)
    val_dataset = CustomImageDataset(root_dir=val_image_path, transform=val_transform)

    # Split the training data into three equal parts.
    subset_len = len(train_dataset) // 3
    if subset_len == 0:
        # A missing or misspelled directory produces an empty dataset; fail
        # here with a clear message instead of deep inside the sampler.
        raise ValueError(
            f"Need at least 3 training samples under {train_image_path!r}, "
            f"found {len(train_dataset)}"
        )
    train_len = subset_len * 3
    # random_split requires the lengths to add up to the dataset size, so
    # drop the remainder before splitting.
    train_dataset_truncated = torch.utils.data.Subset(train_dataset, range(train_len))
    dataset1, dataset2, dataset3 = random_split(train_dataset_truncated, [subset_len] * 3)

    loader1 = DataLoader(dataset1, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader2 = DataLoader(dataset2, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    loader3 = DataLoader(dataset3, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return loader1, loader2, loader3, subset_len, val_loader
def get_Fourdata(train_path, val_path, batch_size, num_workers):
    """Load per-client training/validation loaders and a global validation loader.

    Four clients are expected, each with its own sub-directory named
    ``df``, ``f2f``, ``fs`` and ``nt`` under both ``train_path`` and
    ``val_path``. Each client directory is read with ``ImageFolder``, so its
    immediate sub-folders define that client's class labels.

    The global validation set is the concatenation of the four client
    validation sets, which keeps its labels identical to the per-client
    ones.

    Args:
        train_path (str): Directory containing the client training folders.
        val_path (str): Directory containing the client validation folders.
        batch_size (int): Number of samples per batch for every loader.
        num_workers (int): Value forwarded to every DataLoader as
            ``num_workers``.

    Returns:
        tuple: Nine loaders in this order: the four client training loaders
        (df, f2f, fs, nt), the four client validation loaders (same order),
        and the global validation loader.
    """
    # Preprocessing; only the training pipeline applies a random flip.
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Sub-directory names of the four clients.
    client_names = ['df', 'f2f', 'fs', 'nt']
    client_train_loaders = []
    client_val_loaders = []
    client_val_datasets = []
    for client in client_names:
        client_train_path = os.path.join(train_path, client)
        client_val_path = os.path.join(val_path, client)

        # Client training data.
        train_dataset = datasets.ImageFolder(root=client_train_path, transform=train_transform)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        # Client validation data.
        val_dataset = datasets.ImageFolder(root=client_val_path, transform=val_transform)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        client_train_loaders.append(train_loader)
        client_val_loaders.append(val_loader)
        client_val_datasets.append(val_dataset)

    # Global validation set. Pointing ImageFolder at val_path itself would
    # turn the client folder names (df/f2f/fs/nt) into the class labels,
    # which disagrees with the labels used for training. Concatenating the
    # client datasets keeps the labels consistent.
    global_val_dataset = torch.utils.data.ConcatDataset(client_val_datasets)
    global_val_loader = DataLoader(global_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return (*client_train_loaders, *client_val_loaders, global_val_loader)
def main():
    """Load the three training shards and the validation set, then print label statistics."""
    # Settings for this manual check.
    batch_size = 4
    num_workers = 2
    train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
    val_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/val"

    shard_a, shard_b, shard_c, _shard_size, val_loader = get_data(
        train_image_path, val_image_path, batch_size, num_workers
    )

    def _gather_labels(loaders):
        # Iterate every batch of every loader and flatten the label tensors.
        collected = []
        for loader in loaders:
            for _, batch_labels in loader:
                collected += batch_labels.tolist()
        return collected

    train_labels = _gather_labels((shard_a, shard_b, shard_c))
    val_labels = _gather_labels((val_loader,))

    # Per-label frequencies.
    train_label_counts = Counter(train_labels)
    val_label_counts = Counter(val_labels)

    # Report the counts and the distinct label values.
    print("Training Dataset - Label Counts:", train_label_counts)
    print("Validation Dataset - Label Counts:", val_label_counts)
    print("Label Types in Training:", set(train_labels))
    print("Label Types in Validation:", set(val_labels))


if __name__ == "__main__":
    main()