import argparse
import os

import torch
from torch import optim
from torch.optim import lr_scheduler

from fed_example.utils.data_utils import get_data
from fed_example.utils.model_utils import get_model
from fed_example.utils.train_utils import train_model, validate_model, v3_update_model_weights


def main(args):
    device = torch.device(args.device)

    # Data loaders
    loader1, loader2, loader3, subset_len, val_loader = get_data(
        args.train_path, args.val_path, args.batch_size, args.number_workers
    )
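    # NOTE (assumption): get_data is expected to split the training data into three
    # disjoint subsets, one DataLoader per simulated client (loader1/loader2/loader3),
    # and to return the per-subset length plus a shared validation DataLoader.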

    # Models; signature reference: get_model(name='ResNet', number_class=2, device=device, resnet_type='resnet18')
    model_a = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
    model_b = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
    model_c = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
    # Add the global model
    global_model = get_model(args.model_name, args.number_class, device, args.deep_backbone).to(device)
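    # All four models are built from the same get_model call, so their state_dicts share
    # identical keys and shapes, which any weight averaging between them would rely on.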

    if args.resume_training:
        # map_location keeps the load working even if the checkpoints were saved on another device
        model_a.load_state_dict(torch.load(os.path.join(args.save_dir, 'best_model_a.pth'), map_location=device))
        model_b.load_state_dict(torch.load(os.path.join(args.save_dir, 'best_model_b.pth'), map_location=device))
        model_c.load_state_dict(torch.load(os.path.join(args.save_dir, 'best_model_c.pth'), map_location=device))
        print("Loaded previously saved model weights; resuming training")

    # Optimizers and loss function
    criterion = torch.nn.BCEWithLogitsLoss().to(device)
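    # BCEWithLogitsLoss applies the sigmoid internally, so the models are expected to
    # output raw logits from a single unit (number_class defaults to 1 for binary classification)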

    optimizer_a = optim.Adam(model_a.parameters(), lr=args.lr, weight_decay=5e-4)
    optimizer_b = optim.Adam(model_b.parameters(), lr=args.lr, weight_decay=5e-4)
    optimizer_c = optim.Adam(model_c.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler_a = lr_scheduler.ReduceLROnPlateau(optimizer_a, mode='min', factor=0.5, patience=2, verbose=True)
    scheduler_b = lr_scheduler.ReduceLROnPlateau(optimizer_b, mode='min', factor=0.5, patience=2, verbose=True)
    scheduler_c = lr_scheduler.ReduceLROnPlateau(optimizer_c, mode='min', factor=0.5, patience=2, verbose=True)
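    # Note: ReduceLROnPlateau halves the learning rate (factor=0.5) after `patience` epochs
    # without improvement in the monitored metric; it only takes effect when stepped each
    # epoch with the validation loss (see the scheduler_*.step calls in the loop below)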

    # Initialize the best validation losses and the model save paths
    best_val_loss_a = float('inf')
    best_val_loss_b = float('inf')
    best_val_loss_c = float('inf')

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # Training and validation
    for epoch in range(args.epochs):
        print(f'Epoch {epoch + 1}/{args.epochs}')

        # Train each model on its own loader
        loss_a = train_model(device, model_a, loader1, optimizer_a, criterion, epoch, 'model_a')
        loss_b = train_model(device, model_b, loader2, optimizer_b, criterion, epoch, 'model_b')
        loss_c = train_model(device, model_c, loader3, optimizer_c, criterion, epoch, 'model_c')

        # Validate each model on the shared validation set
        val_loss_a, val_acc_a, val_auc_a = validate_model(device, model_a, val_loader, criterion, epoch, 'model_a')
        val_loss_b, val_acc_b, val_auc_b = validate_model(device, model_b, val_loader, criterion, epoch, 'model_b')
        val_loss_c, val_acc_c, val_auc_c = validate_model(device, model_c, val_loader, criterion, epoch, 'model_c')
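
        # Step the plateau schedulers on the per-model validation losses so the LR decay takes effect
        scheduler_a.step(val_loss_a)
        scheduler_b.step(val_loss_b)
        scheduler_c.step(val_loss_c)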

        if args.save_model and val_loss_a < best_val_loss_a:
            best_val_loss_a = val_loss_a
            torch.save(model_a.state_dict(), os.path.join(save_dir, 'best_model_a.pth'))
            print(f"Best model_a saved with val_loss: {best_val_loss_a:.4f}")

        if args.save_model and val_loss_b < best_val_loss_b:
            best_val_loss_b = val_loss_b
            torch.save(model_b.state_dict(), os.path.join(save_dir, 'best_model_b.pth'))
            print(f"Best model_b saved with val_loss: {best_val_loss_b:.4f}")

        if args.save_model and val_loss_c < best_val_loss_c:
            best_val_loss_c = val_loss_c
            torch.save(model_c.state_dict(), os.path.join(save_dir, 'best_model_c.pth'))
            print(f"Best model_c saved with val_loss: {best_val_loss_c:.4f}")

        print(
            f'Model A - Loss: {loss_a:.4f}, Val Loss: {val_loss_a:.4f}, Val Acc: {val_acc_a:.4f}, AUC: {val_auc_a:.4f}')
        print(
            f'Model B - Loss: {loss_b:.4f}, Val Loss: {val_loss_b:.4f}, Val Acc: {val_acc_b:.4f}, AUC: {val_auc_b:.4f}')
        print(
            f'Model C - Loss: {loss_c:.4f}, Val Loss: {val_loss_c:.4f}, Val Acc: {val_acc_c:.4f}, AUC: {val_auc_c:.4f}')
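
        # NOTE (assumption): v3_update_model_weights lives in fed_example.utils.train_utils and
        # is expected to aggregate the three local models into global_model every
        # `update_frequency` epochs (e.g. a FedAvg-style average weighted by the training
        # losses), copy the result back into `model_to_update` when it does not degrade the
        # current validation AUC (`val_auc_threshold`), and return refreshed validation metrics.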

        # Update model A's weights (update_frequency=1: aggregate every epoch)
        val_acc_a, val_auc_a, val_acc_a_threshold = v3_update_model_weights(
            epoch=epoch,
            model_to_update=model_a,
            other_models=[model_a, model_b, model_c],
            global_model=global_model,
            losses=[loss_a, loss_b, loss_c],
            val_loader=val_loader,
            device=device,
            val_auc_threshold=val_auc_a,
            validate_model=validate_model,
            criterion=criterion,
            update_frequency=1
        )

        # Update model B's weights (update_frequency=1: aggregate every epoch)
        val_acc_b, val_auc_b, val_acc_b_threshold = v3_update_model_weights(
            epoch=epoch,
            model_to_update=model_b,
            other_models=[model_a, model_b, model_c],
            global_model=global_model,
            losses=[loss_a, loss_b, loss_c],
            val_loader=val_loader,
            device=device,
            val_auc_threshold=val_auc_b,
            validate_model=validate_model,
            criterion=criterion,
            update_frequency=1
        )

        # Update model C's weights (update_frequency=1: aggregate every epoch)
        val_acc_c, val_auc_c, val_acc_c_threshold = v3_update_model_weights(
            epoch=epoch,
            model_to_update=model_c,
            other_models=[model_a, model_b, model_c],
            global_model=global_model,
            losses=[loss_a, loss_b, loss_c],
            val_loader=val_loader,
            device=device,
            val_auc_threshold=val_auc_c,
            validate_model=validate_model,
            criterion=criterion,
            update_frequency=1
        )
print("Training complete! Best models saved.")
|
|
|
|
|
|
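# Example invocation (illustrative; the script filename and dataset paths are placeholders):
#   python fed_train.py --model_name resnet18_psa --epochs 10 --batch_size 16 \
#       --train_path /path/to/train --val_path /path/to/val \
#       --save_dir /path/to/output --device cuda:0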
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='resnet18_psa', help='Model name')
    parser.add_argument('--deep_backbone', type=str, default='*', help='deeplab backbone')
    parser.add_argument('--train_path', type=str, default='/media/terminator/实验&代码/yhs/FF++/c40/total/train')
    parser.add_argument('--val_path', type=str, default='/media/terminator/实验&代码/yhs/FF++/c40/total/val')
    # parser.add_argument('--train_path', type=str, default='/media/terminator/实验&代码/yhs/FF++_mask_sample/c23/df/train')
    # parser.add_argument('--val_path', type=str, default='/media/terminator/实验&代码/yhs/FF++_mask_sample/c23/df/val')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--number_workers', type=int, default=8)
    parser.add_argument('--number_class', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--lr', type=float, default=0.00005)
    parser.add_argument('--save_dir', type=str,
                        default='/media/terminator/实验&代码/yhs/output/work2/resnet18_psa/c40/total/e10',
                        help='Directory to save best models')
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so use BooleanOptionalAction (Python 3.9+) for proper --flag/--no-flag switches
    parser.add_argument('--save_model', action=argparse.BooleanOptionalAction, default=True,
                        help='Whether to save the best models')
    parser.add_argument('--resume_training', action=argparse.BooleanOptionalAction, default=False,
                        help='Whether to resume training from saved model weights')
    args = parser.parse_args()

    main(args)