联邦学习示例项目:更改结构
This commit is contained in:
parent
1930e1b96b
commit
34a5247dd2
@ -3,9 +3,9 @@ import torch
|
||||
import os
|
||||
from torch import optim
|
||||
from torch.optim import lr_scheduler
|
||||
from util.data_utils import get_data
|
||||
from util.model_utils import get_model
|
||||
from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights
|
||||
from fed_example.utils.data_utils import get_data
|
||||
from fed_example.utils.model_utils import get_model
|
||||
from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights
|
||||
|
||||
|
||||
def main(args):
|
@ -1,13 +1,12 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from collections import Counter
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from torchvision import transforms, datasets
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset, random_split
|
||||
from torchvision import transforms, datasets
|
||||
|
||||
|
||||
class CustomImageDataset(Dataset):
|
||||
@ -181,6 +180,29 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers):
|
||||
return (*client_train_loaders, *client_val_loaders, global_val_loader)
|
||||
|
||||
|
||||
def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8):
|
||||
"""
|
||||
将数据集划分为多个客户端,每个客户端拥有独立的训练和验证数据。
|
||||
"""
|
||||
# 加载完整数据集
|
||||
full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train"))
|
||||
full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val"))
|
||||
|
||||
# 划分客户端训练集
|
||||
client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients)
|
||||
|
||||
# 创建客户端数据加载器
|
||||
client_train_loaders = [
|
||||
DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
|
||||
for ds in client_train_datasets
|
||||
]
|
||||
|
||||
# 全局验证集
|
||||
global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
|
||||
|
||||
return client_train_loaders, global_val_loader
|
||||
|
||||
|
||||
def main():
|
||||
# 设置参数
|
||||
train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
|
@ -55,3 +55,11 @@ def get_model(name, number_class, device, backbone):
|
||||
else:
|
||||
raise ValueError(f"Model {name} is not supported.")
|
||||
return model
|
||||
|
||||
def get_federated_model(device):
|
||||
"""初始化客户端模型和全局模型"""
|
||||
client_models = [
|
||||
get_model("resnet18_psa", 1, device, "*") for _ in range(3)
|
||||
]
|
||||
global_model = get_model("resnet18_psa", 1, device, "*")
|
||||
return client_models, global_model
|
@ -116,6 +116,7 @@ def test_deepmodel(device, model, loader):
|
||||
# avg_loss = running_loss / len(loader)
|
||||
# print(f'{model_name} Training Loss: {avg_loss:.4f}')
|
||||
# return avg_loss
|
||||
|
||||
def train_model(device, model, loader, optimizer, criterion, epoch, model_name):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
@ -331,6 +332,7 @@ def f_update_model_weights(
|
||||
updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。
|
||||
"""
|
||||
# 每隔指定的 epoch 更新一次模型权重
|
||||
|
||||
if (epoch + 1) % update_frequency == 0:
|
||||
print(f"\n[Epoch {epoch + 1}] Updating global model weights...")
|
||||
|
Loading…
Reference in New Issue
Block a user