联邦学习示例项目:更改结构

This commit is contained in:
myh 2025-04-20 15:19:24 +08:00
parent 1930e1b96b
commit 34a5247dd2
5 changed files with 42 additions and 10 deletions

View File

@ -3,9 +3,9 @@ import torch
import os
from torch import optim
from torch.optim import lr_scheduler
from util.data_utils import get_data
from util.model_utils import get_model
from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights
from fed_example.utils.data_utils import get_data
from fed_example.utils.model_utils import get_model
from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights
def main(args):

View File

@ -1,13 +1,12 @@
import os
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from collections import Counter
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
import os
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
from torchvision import transforms, datasets
class CustomImageDataset(Dataset):
@ -181,6 +180,29 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers):
return (*client_train_loaders, *client_val_loaders, global_val_loader)
def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8):
"""
将数据集划分为多个客户端每个客户端拥有独立的训练和验证数据
"""
# 加载完整数据集
full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train"))
full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val"))
# 划分客户端训练集
client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients)
# 创建客户端数据加载器
client_train_loaders = [
DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
for ds in client_train_datasets
]
# 全局验证集
global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return client_train_loaders, global_val_loader
def main():
# 设置参数
train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"

View File

@ -55,3 +55,11 @@ def get_model(name, number_class, device, backbone):
else:
raise ValueError(f"Model {name} is not supported.")
return model
def get_federated_model(device):
"""初始化客户端模型和全局模型"""
client_models = [
get_model("resnet18_psa", 1, device, "*") for _ in range(3)
]
global_model = get_model("resnet18_psa", 1, device, "*")
return client_models, global_model

View File

@ -116,6 +116,7 @@ def test_deepmodel(device, model, loader):
# avg_loss = running_loss / len(loader)
# print(f'{model_name} Training Loss: {avg_loss:.4f}')
# return avg_loss
def train_model(device, model, loader, optimizer, criterion, epoch, model_name):
model.train()
running_loss = 0.0
@ -331,6 +332,7 @@ def f_update_model_weights(
updated_val_auc_threshold (float): 更新后的验证 AUC 阈值
"""
# 每隔指定的 epoch 更新一次模型权重
if (epoch + 1) % update_frequency == 0:
print(f"\n[Epoch {epoch + 1}] Updating global model weights...")