import torch
from torch import nn
from torchvision import models

from Deeplab.deeplab import DeepLab_F
from Deeplab.resnet_psa import BasicBlockWithPSA
from Deeplab.resnet_psa_v2 import ResNet
from model_base.efNet_base_model import DeepLab
from model_base.efficientnet import EfficientNet
from model_base.resnet_more import CustomResNet
from model_base.xcption import Xception

# Maps each accepted ResNet model name to the ``resnet_type`` string that is
# forwarded to CustomResNet.
_RESNET_TYPES = {
    'ResNet18': 'resnet18',
    'ResNet34': 'resnet34',
    'ResNet50': 'resnet50',
    'ResNet101': 'resnet101',
    'ResNet152': 'resnet152',
}

# NOTE(review): absolute, machine-specific location of the Xception weights;
# the 'Xception' branch will presumably only work where this file exists.
_XCEPTION_WEIGHTS_PATH = (
    "/home/terminator/1325/yhs/fedLeaning/pre_model/xception-43020ad28.pth"
)


def get_model(name: str, number_class: int, device, backbone) -> nn.Module:
    """Build the model selected by *name* and move it to *device*.

    Supported names: ``'Vgg'``, ``'ResNet18'``, ``'ResNet34'``,
    ``'ResNet50'``, ``'ResNet101'``, ``'ResNet152'``, ``'EfficientNet'``,
    ``'Xception'``, ``'DeepLab'`` and ``'resnet18_psa'``.

    Args:
        name: Which model to build (one of the names listed above).
        number_class: Number of output classes. Used by every branch except
            ``'DeepLab'``, which is built with ``num_classes=1``.
        device: Target device, passed to ``Module.to`` (e.g. a
            ``torch.device``).
        backbone: Backbone selector forwarded to ``DeepLab_F``; only used
            when ``name == 'DeepLab'``.

    Returns:
        The constructed model, with all of its parameters on *device*.

    Raises:
        ValueError: If *name* is not one of the supported model names.
    """
    if name == 'Vgg':
        model = models.vgg16_bn(pretrained=True)
        # Swap the final classifier layer for one sized to this task. This is
        # done before the device move below so the new layer is moved too.
        old_head = model.classifier[6]
        model.classifier[6] = nn.Linear(old_head.in_features, number_class)
    elif name in _RESNET_TYPES:
        model = CustomResNet(
            resnet_type=_RESNET_TYPES[name],
            num_classes=number_class,
            pretrained=True,
        )
    elif name == 'EfficientNet':
        # EfficientNet is loaded through the custom DeepLab wrapper class.
        model = DeepLab(backbone='efficientnet', num_classes=number_class)
    elif name == 'Xception':
        model = Xception(
            in_planes=3,
            num_classes=number_class,
            pretrained=True,
            pretrained_path=_XCEPTION_WEIGHTS_PATH,
        )
    elif name == 'DeepLab':
        # Custom DeepLab_F built on the caller-supplied backbone.
        # NOTE(review): num_classes is fixed at 1 and number_class is ignored
        # here -- confirm this is intentional.
        model = DeepLab_F(num_classes=1, backbone=backbone)
    elif name == 'resnet18_psa':
        model = ResNet(BasicBlockWithPSA, [2, 2, 2, 2], number_class)
    else:
        raise ValueError(f"Model {name} is not supported.")

    # Single device move after all layer surgery, so every branch (including
    # the replaced Vgg head and resnet18_psa) ends up on the requested device.
    return model.to(device)