58 lines
2.5 KiB
Python
58 lines
2.5 KiB
Python
import torch
|
|
from torch import nn
|
|
from torchvision import models
|
|
|
|
from Deeplab.deeplab import DeepLab_F
|
|
from Deeplab.resnet_psa import BasicBlockWithPSA
|
|
from Deeplab.resnet_psa_v2 import ResNet
|
|
from model_base.efNet_base_model import DeepLab
|
|
from model_base.efficientnet import EfficientNet
|
|
from model_base.resnet_more import CustomResNet
|
|
from model_base.xcption import Xception
|
|
|
|
|
|
def get_model(name, number_class, device, backbone, *,
              xception_weights_path="/home/terminator/1325/yhs/fedLeaning/pre_model/xception-43020ad28.pth"):
    """Build the model selected by ``name``, sized for ``number_class`` outputs, on ``device``.

    Args:
        name (str): Model identifier. One of ``'Vgg'``, ``'ResNet18'``,
            ``'ResNet34'``, ``'ResNet50'``, ``'ResNet101'``, ``'ResNet152'``,
            ``'EfficientNet'``, ``'Xception'``, ``'DeepLab'``, ``'resnet18_psa'``.
        number_class (int): Number of output classes for the final layer.
            Not used by the ``'DeepLab'`` branch (see the note there).
        device (torch.device | str): Target device passed to ``Module.to``
            (e.g. ``'cuda'`` or ``'cpu'``).
        backbone: Backbone identifier forwarded to ``DeepLab_F``; only used
            when ``name == 'DeepLab'``.
        xception_weights_path (str): Path of the pretrained Xception
            checkpoint; only used when ``name == 'Xception'``. Defaults to the
            path previously hard-coded here.

    Returns:
        nn.Module: The constructed model, moved to ``device``.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # Map the public ResNet names onto the variant strings CustomResNet expects.
    resnet_variants = {
        'ResNet18': 'resnet18',
        'ResNet34': 'resnet34',
        'ResNet50': 'resnet50',
        'ResNet101': 'resnet101',
        'ResNet152': 'resnet152',
    }

    if name == 'Vgg':
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
        # ``weights=...``; kept as-is for compatibility with older versions.
        model = models.vgg16_bn(pretrained=True)
        # Replace the last classifier layer *before* moving to ``device`` so the
        # freshly created Linear layer is moved together with the rest of the model.
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, number_class)
    elif name in resnet_variants:
        model = CustomResNet(resnet_type=resnet_variants[name], num_classes=number_class, pretrained=True)
    elif name == 'EfficientNet':
        # Custom DeepLab class with an EfficientNet backbone.
        model = DeepLab(backbone='efficientnet', num_classes=number_class)
    elif name == 'Xception':
        model = Xception(
            in_planes=3,
            num_classes=number_class,
            pretrained=True,
            pretrained_path=xception_weights_path,
        )
    elif name == 'DeepLab':
        # NOTE(review): num_classes is hard-coded to 1 and ``number_class`` is
        # ignored here — presumably single-channel (binary) segmentation output;
        # confirm this is intended.
        model = DeepLab_F(num_classes=1, backbone=backbone)
    elif name == 'resnet18_psa':
        # ResNet-18 layout ([2, 2, 2, 2] blocks) built from PSA basic blocks.
        model = ResNet(BasicBlockWithPSA, [2, 2, 2, 2], number_class)
    else:
        raise ValueError(f"Model {name} is not supported.")

    # Move once, after all layers exist, so every branch (including
    # 'resnet18_psa', which previously skipped this) ends up fully on ``device``.
    return model.to(device)
|