Graduation-Project/federated_learning/utils/model_utils.py
2025-04-18 22:15:25 +08:00

58 lines
2.5 KiB
Python

import torch
from torch import nn
from torchvision import models
from Deeplab.deeplab import DeepLab_F
from Deeplab.resnet_psa import BasicBlockWithPSA
from Deeplab.resnet_psa_v2 import ResNet
from model_base.efNet_base_model import DeepLab
from model_base.efficientnet import EfficientNet
from model_base.resnet_more import CustomResNet
from model_base.xcption import Xception
def get_model(name, number_class, device, backbone=None, *,
              xception_pretrained_path=(
                  "/home/terminator/1325/yhs/fedLeaning/pre_model/xception-43020ad28.pth"
              )):
    """
    Build the model selected by ``name``, adapt its output size, and place it on ``device``.

    Args:
        name (str): Model identifier (case-sensitive). Supported values:
            'Vgg', 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152',
            'EfficientNet', 'Xception', 'DeepLab', 'resnet18_psa'.
        number_class (int): Number of output classes. Note that 'DeepLab'
            ignores this value and is always built with ``num_classes=1``.
        device (torch.device | str): Device the returned model is moved to,
            e.g. 'cuda' or 'cpu'.
        backbone (str, optional): Backbone name forwarded to ``DeepLab_F``;
            only used when ``name == 'DeepLab'``.
        xception_pretrained_path (str, keyword-only): Path of the pretrained
            Xception checkpoint; only used when ``name == 'Xception'``.

    Returns:
        nn.Module: The constructed model, with all parameters on ``device``.

    Raises:
        ValueError: If ``name`` is not a supported identifier.
    """
    # Each ResNet depth is handled by the same CustomResNet wrapper.
    resnet_types = {
        'ResNet18': 'resnet18',
        'ResNet34': 'resnet34',
        'ResNet50': 'resnet50',
        'ResNet101': 'resnet101',
        'ResNet152': 'resnet152',
    }

    if name == 'Vgg':
        model = models.vgg16_bn(pretrained=True)
        # Swap the last classifier layer for one sized to this task. This is done
        # BEFORE moving to the device so the new layer ends up on `device` too.
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, number_class)
    elif name in resnet_types:
        model = CustomResNet(resnet_type=resnet_types[name], num_classes=number_class, pretrained=True)
    elif name == 'EfficientNet':
        # Custom DeepLab class with an EfficientNet backbone.
        model = DeepLab(backbone='efficientnet', num_classes=number_class)
    elif name == 'Xception':
        model = Xception(
            in_planes=3,
            num_classes=number_class,
            pretrained=True,
            pretrained_path=xception_pretrained_path,
        )
    elif name == 'DeepLab':
        # Custom DeepLab_F with a configurable backbone.
        # NOTE(review): num_classes is fixed to 1 and `number_class` is ignored;
        # presumably a single-logit binary-segmentation head — confirm intent.
        model = DeepLab_F(num_classes=1, backbone=backbone)
    elif name == 'resnet18_psa':
        # ResNet-18 layout ([2, 2, 2, 2] blocks) built from PSA-augmented basic blocks.
        model = ResNet(BasicBlockWithPSA, [2, 2, 2, 2], number_class)
    else:
        raise ValueError(f"Model {name} is not supported.")

    # Move once, after any layer replacement, so every branch (including
    # resnet18_psa, which previously skipped this) lands fully on `device`.
    return model.to(device)