联邦学习示例项目:更改结构
This commit is contained in:
		| @@ -3,9 +3,9 @@ import torch | |||||||
| import os | import os | ||||||
| from torch import optim | from torch import optim | ||||||
| from torch.optim import lr_scheduler | from torch.optim import lr_scheduler | ||||||
| from util.data_utils import get_data | from fed_example.utils.data_utils import get_data | ||||||
| from util.model_utils import get_model | from fed_example.utils.model_utils import get_model | ||||||
| from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights | from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(args): | def main(args): | ||||||
| @@ -1,13 +1,12 @@ | |||||||
| import os | import os | ||||||
| from PIL import Image |  | ||||||
| import torch |  | ||||||
| from torchvision import transforms |  | ||||||
| from torch.utils.data import DataLoader, Dataset, random_split |  | ||||||
| from collections import Counter | from collections import Counter | ||||||
| from torch.utils.data import DataLoader, Subset | 
 | ||||||
| from torchvision import transforms, datasets | import torch | ||||||
| import os | from PIL import Image | ||||||
| from sklearn.model_selection import train_test_split | from sklearn.model_selection import train_test_split | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | from torch.utils.data import Dataset, random_split | ||||||
|  | from torchvision import transforms, datasets | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CustomImageDataset(Dataset): | class CustomImageDataset(Dataset): | ||||||
| @@ -181,6 +180,29 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers): | |||||||
|     return (*client_train_loaders, *client_val_loaders, global_val_loader) |     return (*client_train_loaders, *client_val_loaders, global_val_loader) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8): | ||||||
|  |     """ | ||||||
|  |     将数据集划分为多个客户端,每个客户端拥有独立的训练和验证数据。 | ||||||
|  |     """ | ||||||
|  |     # 加载完整数据集 | ||||||
|  |     full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train")) | ||||||
|  |     full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val")) | ||||||
|  |      | ||||||
|  |     # 划分客户端训练集 | ||||||
|  |     client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients) | ||||||
|  |      | ||||||
|  |     # 创建客户端数据加载器 | ||||||
|  |     client_train_loaders = [ | ||||||
|  |         DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers) | ||||||
|  |         for ds in client_train_datasets | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     # 全局验证集 | ||||||
|  |     global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) | ||||||
|  |      | ||||||
|  |     return client_train_loaders, global_val_loader | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def main(): | def main(): | ||||||
|     # 设置参数 |     # 设置参数 | ||||||
|     train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train" |     train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train" | ||||||
| @@ -55,3 +55,11 @@ def get_model(name, number_class, device, backbone): | |||||||
|     else: |     else: | ||||||
|         raise ValueError(f"Model {name} is not supported.") |         raise ValueError(f"Model {name} is not supported.") | ||||||
|     return model |     return model | ||||||
|  | 
 | ||||||
|  | def get_federated_model(device): | ||||||
|  |     """初始化客户端模型和全局模型""" | ||||||
|  |     client_models = [ | ||||||
|  |         get_model("resnet18_psa", 1, device, "*") for _ in range(3) | ||||||
|  |     ] | ||||||
|  |     global_model = get_model("resnet18_psa", 1, device, "*") | ||||||
|  |     return client_models, global_model | ||||||
| @@ -116,6 +116,7 @@ def test_deepmodel(device, model, loader): | |||||||
| #     avg_loss = running_loss / len(loader) | #     avg_loss = running_loss / len(loader) | ||||||
| #     print(f'{model_name} Training Loss: {avg_loss:.4f}') | #     print(f'{model_name} Training Loss: {avg_loss:.4f}') | ||||||
| #     return avg_loss | #     return avg_loss | ||||||
|  | 
 | ||||||
| def train_model(device, model, loader, optimizer, criterion, epoch, model_name): | def train_model(device, model, loader, optimizer, criterion, epoch, model_name): | ||||||
|     model.train() |     model.train() | ||||||
|     running_loss = 0.0 |     running_loss = 0.0 | ||||||
| @@ -331,6 +332,7 @@ def f_update_model_weights( | |||||||
|         updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。 |         updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。 | ||||||
|     """ |     """ | ||||||
|     # 每隔指定的 epoch 更新一次模型权重 |     # 每隔指定的 epoch 更新一次模型权重 | ||||||
|  |      | ||||||
|     if (epoch + 1) % update_frequency == 0: |     if (epoch + 1) % update_frequency == 0: | ||||||
|         print(f"\n[Epoch {epoch + 1}] Updating global model weights...") |         print(f"\n[Epoch {epoch + 1}] Updating global model weights...") | ||||||
|          |          | ||||||
		Reference in New Issue
	
	Block a user