import argparse
import os

import torch
from torch import optim
from torch.optim import lr_scheduler

from util.data_utils import get_data
from util.model_utils import get_model
from util.train_utils import train_model, validate_model, update_model_weights, v3_update_model_weights


def _str2bool(value):
    """Parse a command-line string into a bool.

    argparse's ``type=bool`` is a pitfall: any non-empty string (including
    "False") is truthy, so boolean flags could never be disabled from the
    CLI. This converter accepts the common spellings explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def main(args):
    """Train three peer models (a/b/c) on separate data subsets with
    periodic cross-model weight fusion, validating each epoch and
    checkpointing the best version of each model by validation loss.
    """
    device = torch.device(args.device)

    # Data loaders: three training subsets plus a shared validation set.
    loader1, loader2, loader3, subset_len, val_loader = get_data(
        args.train_path, args.val_path, args.batch_size, args.number_workers
    )

    # Models: three peers plus a global model used by the v3 weight-update
    # scheme. (A discarded get_model('ResNet', ...) call was removed here —
    # its result was never used and it only wasted device memory.)
    model_a = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
    model_b = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
    model_c = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
    global_model = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)

    if args.resume_training:
        # map_location keeps resume working when checkpoints were saved on a
        # different device (e.g. trained on cuda:0, resumed on cuda:1/cpu).
        model_a.load_state_dict(
            torch.load(os.path.join(args.save_dir, 'best_model_a.pth'), map_location=device))
        model_b.load_state_dict(
            torch.load(os.path.join(args.save_dir, 'best_model_b.pth'), map_location=device))
        model_c.load_state_dict(
            torch.load(os.path.join(args.save_dir, 'best_model_c.pth'), map_location=device))
        print("已加载之前保存的模型参数继续训练")

    # Loss and one optimizer/scheduler per model.
    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    optimizer_a = optim.Adam(model_a.parameters(), lr=args.lr, weight_decay=5e-4)
    optimizer_b = optim.Adam(model_b.parameters(), lr=args.lr, weight_decay=5e-4)
    optimizer_c = optim.Adam(model_c.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler_a = lr_scheduler.ReduceLROnPlateau(optimizer_a, mode='min', factor=0.5, patience=2, verbose=True)
    scheduler_b = lr_scheduler.ReduceLROnPlateau(optimizer_b, mode='min', factor=0.5, patience=2, verbose=True)
    scheduler_c = lr_scheduler.ReduceLROnPlateau(optimizer_c, mode='min', factor=0.5, patience=2, verbose=True)

    # Best validation loss per model (for checkpointing) and output dir.
    best_val_loss = {'a': float('inf'), 'b': float('inf'), 'c': float('inf')}
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    def _save_if_best(tag, model, val_loss):
        """Checkpoint *model* when its validation loss improves."""
        if args.save_model and val_loss < best_val_loss[tag]:
            best_val_loss[tag] = val_loss
            torch.save(model.state_dict(), os.path.join(save_dir, f'best_model_{tag}.pth'))
            print(f"Best model_{tag} saved with val_loss: {best_val_loss[tag]:.4f}")

    # Training and validation loop.
    for epoch in range(args.epochs):
        print(f'Epoch {epoch + 1}/{args.epochs}')

        # Train each model on its own subset.
        loss_a = train_model(device, model_a, loader1, optimizer_a, criterion, epoch, 'model_a')
        loss_b = train_model(device, model_b, loader2, optimizer_b, criterion, epoch, 'model_b')
        loss_c = train_model(device, model_c, loader3, optimizer_c, criterion, epoch, 'model_c')

        # Validate each model on the shared validation set.
        val_loss_a, val_acc_a, val_auc_a = validate_model(device, model_a, val_loader, criterion, epoch, 'model_a')
        val_loss_b, val_acc_b, val_auc_b = validate_model(device, model_b, val_loader, criterion, epoch, 'model_b')
        val_loss_c, val_acc_c, val_auc_c = validate_model(device, model_c, val_loader, criterion, epoch, 'model_c')

        # BUG FIX: the schedulers were created but never stepped, so the
        # learning rate was never reduced. ReduceLROnPlateau must be fed the
        # monitored metric each epoch.
        scheduler_a.step(val_loss_a)
        scheduler_b.step(val_loss_b)
        scheduler_c.step(val_loss_c)

        # Checkpoint whichever models improved this epoch.
        _save_if_best('a', model_a, val_loss_a)
        _save_if_best('b', model_b, val_loss_b)
        _save_if_best('c', model_c, val_loss_c)

        print(
            f'Model A - Loss: {loss_a:.4f}, Val Loss: {val_loss_a:.4f}, Val Acc: {val_acc_a:.4f}, AUC: {val_auc_a:.4f}')
        print(
            f'Model B - Loss: {loss_b:.4f}, Val Loss: {val_loss_b:.4f}, Val Acc: {val_acc_b:.4f}, AUC: {val_auc_b:.4f}')
        print(
            f'Model C - Loss: {loss_c:.4f}, Val Loss: {val_loss_c:.4f}, Val Acc: {val_acc_c:.4f}, AUC: {val_auc_c:.4f}')

        # Cross-model weight update for model A (update_frequency=1: every epoch).
        val_acc_a, val_auc_a, val_acc_a_threshold = v3_update_model_weights(
            epoch=epoch,
            model_to_update=model_a,
            other_models=[model_a, model_b, model_c],
            global_model=global_model,
            losses=[loss_a, loss_b, loss_c],
            val_loader=val_loader,
            device=device,
            val_auc_threshold=val_auc_a,
            validate_model=validate_model,
            criterion=criterion,
            update_frequency=1
        )

        # Cross-model weight update for model B.
        val_acc_b, val_auc_b, val_acc_b_threshold = v3_update_model_weights(
            epoch=epoch,
            model_to_update=model_b,
            other_models=[model_a, model_b, model_c],
            global_model=global_model,
            losses=[loss_a, loss_b, loss_c],
            val_loader=val_loader,
            device=device,
            val_auc_threshold=val_auc_b,
            validate_model=validate_model,
            criterion=criterion,
            update_frequency=1
        )

        # Cross-model weight update for model C.
        val_acc_c, val_auc_c, val_acc_c_threshold = v3_update_model_weights(
            epoch=epoch,
            model_to_update=model_c,
            other_models=[model_a, model_b, model_c],
            global_model=global_model,
            losses=[loss_a, loss_b, loss_c],
            val_loader=val_loader,
            device=device,
            val_auc_threshold=val_auc_c,
            validate_model=validate_model,
            criterion=criterion,
            update_frequency=1
        )

    print("Training complete! Best models saved.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='resnet18_psa', help='Model name')
    parser.add_argument('--deep_backbone', type=str, default='*', help='deeplab backbone')
    parser.add_argument('--train_path', type=str, default='/media/terminator/实验&代码/yhs/FF++/c40/total/train')
    parser.add_argument('--val_path', type=str, default='/media/terminator/实验&代码/yhs/FF++/c40/total/val')
    # parser.add_argument('--train_path', type=str, default='/media/terminator/实验&代码/yhs/FF++_mask_sample/c23/df/train')
    # parser.add_argument('--val_path', type=str, default='/media/terminator/实验&代码/yhs/FF++_mask_sample/c23/df/val')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--number_workers', type=int, default=8)
    parser.add_argument('--number_class', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--lr', type=float, default=0.00005)
    parser.add_argument('--save_dir', type=str,
                        default='/media/terminator/实验&代码/yhs/output/work2/resnet18_psa/c40/total/e10',
                        help='Directory to save best models')
    # BUG FIX: was type=bool — bool("False") is True, so these flags could
    # never be changed from the command line. _str2bool parses correctly;
    # defaults are unchanged.
    parser.add_argument('--save_model', type=_str2bool, default=True, help='是否保存最优模型')
    parser.add_argument('--resume_training', type=_str2bool, default=False, help='是否从保存的模型参数继续训练')
    args = parser.parse_args()
    main(args)