联邦学习示例项目:更改结构
This commit is contained in:
		| @@ -3,9 +3,9 @@ import torch | ||||
| import os | ||||
| from torch import optim | ||||
| from torch.optim import lr_scheduler | ||||
| from util.data_utils import get_data | ||||
| from util.model_utils import get_model | ||||
| from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights | ||||
| from fed_example.utils.data_utils import get_data | ||||
| from fed_example.utils.model_utils import get_model | ||||
| from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights | ||||
| 
 | ||||
| 
 | ||||
| def main(args): | ||||
| @@ -1,13 +1,12 @@ | ||||
| import os | ||||
| from PIL import Image | ||||
| import torch | ||||
| from torchvision import transforms | ||||
| from torch.utils.data import DataLoader, Dataset, random_split | ||||
| from collections import Counter | ||||
| from torch.utils.data import DataLoader, Subset | ||||
| from torchvision import transforms, datasets | ||||
| import os | ||||
| 
 | ||||
| import torch | ||||
| from PIL import Image | ||||
| from sklearn.model_selection import train_test_split | ||||
| from torch.utils.data import DataLoader | ||||
| from torch.utils.data import Dataset, random_split | ||||
| from torchvision import transforms, datasets | ||||
| 
 | ||||
| 
 | ||||
| class CustomImageDataset(Dataset): | ||||
| @@ -181,6 +180,29 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers): | ||||
|     return (*client_train_loaders, *client_val_loaders, global_val_loader) | ||||
| 
 | ||||
| 
 | ||||
| def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8): | ||||
|     """ | ||||
|     将数据集划分为多个客户端,每个客户端拥有独立的训练和验证数据。 | ||||
|     """ | ||||
|     # 加载完整数据集 | ||||
|     full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train")) | ||||
|     full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val")) | ||||
|      | ||||
|     # 划分客户端训练集 | ||||
|     client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients) | ||||
|      | ||||
|     # 创建客户端数据加载器 | ||||
|     client_train_loaders = [ | ||||
|         DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers) | ||||
|         for ds in client_train_datasets | ||||
|     ] | ||||
|      | ||||
|     # 全局验证集 | ||||
|     global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) | ||||
|      | ||||
|     return client_train_loaders, global_val_loader | ||||
| 
 | ||||
| 
 | ||||
| def main(): | ||||
|     # 设置参数 | ||||
|     train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train" | ||||
| @@ -55,3 +55,11 @@ def get_model(name, number_class, device, backbone): | ||||
|     else: | ||||
|         raise ValueError(f"Model {name} is not supported.") | ||||
|     return model | ||||
| 
 | ||||
| def get_federated_model(device): | ||||
|     """初始化客户端模型和全局模型""" | ||||
|     client_models = [ | ||||
|         get_model("resnet18_psa", 1, device, "*") for _ in range(3) | ||||
|     ] | ||||
|     global_model = get_model("resnet18_psa", 1, device, "*") | ||||
|     return client_models, global_model | ||||
| @@ -116,6 +116,7 @@ def test_deepmodel(device, model, loader): | ||||
| #     avg_loss = running_loss / len(loader) | ||||
| #     print(f'{model_name} Training Loss: {avg_loss:.4f}') | ||||
| #     return avg_loss | ||||
| 
 | ||||
| def train_model(device, model, loader, optimizer, criterion, epoch, model_name): | ||||
|     model.train() | ||||
|     running_loss = 0.0 | ||||
| @@ -331,6 +332,7 @@ def f_update_model_weights( | ||||
|         updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。 | ||||
|     """ | ||||
|     # 每隔指定的 epoch 更新一次模型权重 | ||||
|      | ||||
|     if (epoch + 1) % update_frequency == 0: | ||||
|         print(f"\n[Epoch {epoch + 1}] Updating global model weights...") | ||||
|          | ||||
		Reference in New Issue
	
	Block a user