联邦学习示例项目:更改结构
This commit is contained in:
parent
1930e1b96b
commit
34a5247dd2
@ -3,9 +3,9 @@ import torch
|
|||||||
import os
|
import os
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.optim import lr_scheduler
|
from torch.optim import lr_scheduler
|
||||||
from util.data_utils import get_data
|
from fed_example.utils.data_utils import get_data
|
||||||
from util.model_utils import get_model
|
from fed_example.utils.model_utils import get_model
|
||||||
from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights
|
from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
@ -1,13 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
from torchvision import transforms
|
|
||||||
from torch.utils.data import DataLoader, Dataset, random_split
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from torch.utils.data import DataLoader, Subset
|
|
||||||
from torchvision import transforms, datasets
|
import torch
|
||||||
import os
|
from PIL import Image
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data import Dataset, random_split
|
||||||
|
from torchvision import transforms, datasets
|
||||||
|
|
||||||
|
|
||||||
class CustomImageDataset(Dataset):
|
class CustomImageDataset(Dataset):
|
||||||
@ -181,6 +180,29 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers):
|
|||||||
return (*client_train_loaders, *client_val_loaders, global_val_loader)
|
return (*client_train_loaders, *client_val_loaders, global_val_loader)
|
||||||
|
|
||||||
|
|
||||||
|
def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8):
|
||||||
|
"""
|
||||||
|
将数据集划分为多个客户端,每个客户端拥有独立的训练和验证数据。
|
||||||
|
"""
|
||||||
|
# 加载完整数据集
|
||||||
|
full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train"))
|
||||||
|
full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val"))
|
||||||
|
|
||||||
|
# 划分客户端训练集
|
||||||
|
client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients)
|
||||||
|
|
||||||
|
# 创建客户端数据加载器
|
||||||
|
client_train_loaders = [
|
||||||
|
DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
|
||||||
|
for ds in client_train_datasets
|
||||||
|
]
|
||||||
|
|
||||||
|
# 全局验证集
|
||||||
|
global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
|
||||||
|
|
||||||
|
return client_train_loaders, global_val_loader
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 设置参数
|
# 设置参数
|
||||||
train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
|
train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train"
|
@ -55,3 +55,11 @@ def get_model(name, number_class, device, backbone):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Model {name} is not supported.")
|
raise ValueError(f"Model {name} is not supported.")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def get_federated_model(device):
|
||||||
|
"""初始化客户端模型和全局模型"""
|
||||||
|
client_models = [
|
||||||
|
get_model("resnet18_psa", 1, device, "*") for _ in range(3)
|
||||||
|
]
|
||||||
|
global_model = get_model("resnet18_psa", 1, device, "*")
|
||||||
|
return client_models, global_model
|
@ -116,6 +116,7 @@ def test_deepmodel(device, model, loader):
|
|||||||
# avg_loss = running_loss / len(loader)
|
# avg_loss = running_loss / len(loader)
|
||||||
# print(f'{model_name} Training Loss: {avg_loss:.4f}')
|
# print(f'{model_name} Training Loss: {avg_loss:.4f}')
|
||||||
# return avg_loss
|
# return avg_loss
|
||||||
|
|
||||||
def train_model(device, model, loader, optimizer, criterion, epoch, model_name):
|
def train_model(device, model, loader, optimizer, criterion, epoch, model_name):
|
||||||
model.train()
|
model.train()
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
@ -331,6 +332,7 @@ def f_update_model_weights(
|
|||||||
updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。
|
updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。
|
||||||
"""
|
"""
|
||||||
# 每隔指定的 epoch 更新一次模型权重
|
# 每隔指定的 epoch 更新一次模型权重
|
||||||
|
|
||||||
if (epoch + 1) % update_frequency == 0:
|
if (epoch + 1) % update_frequency == 0:
|
||||||
print(f"\n[Epoch {epoch + 1}] Updating global model weights...")
|
print(f"\n[Epoch {epoch + 1}] Updating global model weights...")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user