diff --git a/federated_learning/res18Train.py b/fed_example/res18Train.py similarity index 97% rename from federated_learning/res18Train.py rename to fed_example/res18Train.py index 8754446..aef3011 100644 --- a/federated_learning/res18Train.py +++ b/fed_example/res18Train.py @@ -3,9 +3,9 @@ import torch import os from torch import optim from torch.optim import lr_scheduler -from util.data_utils import get_data -from util.model_utils import get_model -from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights +from fed_example.utils.data_utils import get_data +from fed_example.utils.model_utils import get_model +from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights def main(args): diff --git a/federated_learning/utils/__init__.py b/fed_example/utils/__init__.py similarity index 100% rename from federated_learning/utils/__init__.py rename to fed_example/utils/__init__.py diff --git a/federated_learning/utils/data_utils.py b/fed_example/utils/data_utils.py similarity index 88% rename from federated_learning/utils/data_utils.py rename to fed_example/utils/data_utils.py index 6d900db..e7f322f 100644 --- a/federated_learning/utils/data_utils.py +++ b/fed_example/utils/data_utils.py @@ -1,13 +1,12 @@ import os -from PIL import Image -import torch -from torchvision import transforms -from torch.utils.data import DataLoader, Dataset, random_split from collections import Counter -from torch.utils.data import DataLoader, Subset -from torchvision import transforms, datasets -import os + +import torch +from PIL import Image from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader +from torch.utils.data import Dataset, random_split +from torchvision import transforms, datasets class CustomImageDataset(Dataset): @@ -181,6 +180,29 @@ def get_Fourdata(train_path, val_path, batch_size, num_workers): return (*client_train_loaders, *client_val_loaders, global_val_loader) +def get_federated_data(train_path, val_path, num_clients=3, batch_size=16, num_workers=8): + """ + 将数据集划分为多个客户端,每个客户端拥有独立的训练和验证数据。 + """ + # 加载完整数据集 + full_train_dataset = CustomImageDataset(root_dir=train_path, transform=get_transform("train")) + full_val_dataset = CustomImageDataset(root_dir=val_path, transform=get_transform("val")) + + # 划分客户端训练集 + client_train_datasets = random_split(full_train_dataset, [len(full_train_dataset) // num_clients] * num_clients) + + # 创建客户端数据加载器 + client_train_loaders = [ + DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers) + for ds in client_train_datasets + ] + + # 全局验证集 + global_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + return client_train_loaders, global_val_loader + + def main(): # 设置参数 train_image_path = "/media/terminator/实验&代码/yhs/FF++_mask/c23/f2f/train" diff --git a/federated_learning/utils/model_utils.py b/fed_example/utils/model_utils.py similarity index 90% rename from federated_learning/utils/model_utils.py rename to fed_example/utils/model_utils.py index 9749aec..71664ad 100644 --- a/federated_learning/utils/model_utils.py +++ b/fed_example/utils/model_utils.py @@ -55,3 +55,11 @@ def get_model(name, number_class, device, backbone): else: raise ValueError(f"Model {name} is not supported.") return model + +def get_federated_model(device): + """初始化客户端模型和全局模型""" + client_models = [ + get_model("resnet18_psa", 1, device, "*") for _ in range(3) + ] + global_model = get_model("resnet18_psa", 1, device, "*") + return client_models, global_model \ No newline at end of file diff --git a/federated_learning/utils/train_utils.py b/fed_example/utils/train_utils.py similarity index 99% rename from federated_learning/utils/train_utils.py rename to fed_example/utils/train_utils.py index 988d220..6992eb2 100644 --- a/federated_learning/utils/train_utils.py +++ b/fed_example/utils/train_utils.py @@ -116,6 +116,7 @@ def test_deepmodel(device, model, loader): # avg_loss = running_loss / len(loader) # print(f'{model_name} Training Loss: {avg_loss:.4f}') # return avg_loss + def train_model(device, model, loader, optimizer, criterion, epoch, model_name): model.train() running_loss = 0.0 @@ -331,6 +332,7 @@ def f_update_model_weights( updated_val_auc_threshold (float): 更新后的验证 AUC 阈值。 """ # 每隔指定的 epoch 更新一次模型权重 + if (epoch + 1) % update_frequency == 0: print(f"\n[Epoch {epoch + 1}] Updating global model weights...")