import copy
import csv
import os
import warnings
from argparse import ArgumentParser
from typing import cast

import torch
import tqdm
import yaml
from torch.utils import data
from torch.amp.autocast_mode import autocast

from nets import nn
from utils import util
from utils.dataset import Dataset

warnings.filterwarnings("ignore")

data_dir = "/home/image1325/ssd1/dataset/coco"


def train(args, params):
    # Model
    model = nn.yolo_v11_n(len(params["names"]))
    model.cuda()

    # Optimizer
    accumulate = max(round(64 / (args.batch_size * args.world_size)), 1)
    params["weight_decay"] *= args.batch_size * args.world_size * accumulate / 64

    optimizer = torch.optim.SGD(
        util.set_params(model, params["weight_decay"]), params["min_lr"], params["momentum"], nesterov=True
    )

    # EMA
    ema = util.EMA(model) if args.local_rank == 0 else None
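    # The EMA keeps an exponentially weighted copy of the weights; it is
    # maintained on rank 0 only and is what gets evaluated and checkpointed.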

    filenames = []
    with open(f"{data_dir}/train2017.txt") as f:
        for filename in f.readlines():
            filename = os.path.basename(filename.rstrip())
            filenames.append(f"{data_dir}/images/train2017/" + filename)

    sampler = None
    dataset = Dataset(filenames, args.input_size, params, augment=True)

    if args.distributed:
        sampler = data.DistributedSampler(dataset)

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=8,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )

    # Scheduler
    num_steps = len(loader)
    scheduler = util.LinearLR(args, params, num_steps)

    if args.distributed:
        # DDP mode
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[args.local_rank], output_device=args.local_rank
        )

    best = 0
    amp_scale = torch.amp.grad_scaler.GradScaler()
    criterion = util.ComputeLoss(model, params)
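
    # GradScaler keeps fp16 gradients from underflowing by scaling the loss
    # before backward and unscaling at the optimizer step.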

    with open("weights/step.csv", "w") as log:
        if args.local_rank == 0:
            logger = csv.DictWriter(
                log, fieldnames=["epoch", "box", "cls", "dfl", "Recall", "Precision", "mAP@50", "mAP"]
            )
            logger.writeheader()

        for epoch in range(args.epochs):
            model.train()
            if args.distributed and sampler:
                sampler.set_epoch(epoch)
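            # Mosaic augmentation is switched off for the final 10 epochs so
            # the model finishes training on undistorted images.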
            if args.epochs - epoch == 10:
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False

            p_bar = enumerate(loader)

            if args.local_rank == 0:
                print(("\n" + "%10s" * 5) % ("epoch", "memory", "box", "cls", "dfl"))
                p_bar = tqdm.tqdm(p_bar, total=num_steps, ascii=" >-")

            optimizer.zero_grad()
            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()
            for i, (samples, targets) in p_bar:
                step = i + num_steps * epoch
                scheduler.step(step, optimizer)

                samples = samples.cuda().float() / 255

                # Forward
                with autocast("cuda"):
                    outputs = model(samples)  # forward
                loss_box, loss_cls, loss_dfl = criterion(outputs, targets)

                avg_box_loss.update(loss_box.item(), samples.size(0))
                avg_cls_loss.update(loss_cls.item(), samples.size(0))
                avg_dfl_loss.update(loss_dfl.item(), samples.size(0))
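
                # The criterion returns per-sample means and DDP averages
                # gradients across ranks, so losses are rescaled by batch size
                # and world size before backward.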
                loss_box *= args.batch_size  # loss scaled by batch_size
                loss_cls *= args.batch_size  # loss scaled by batch_size
                loss_dfl *= args.batch_size  # loss scaled by batch_size
                loss_box *= args.world_size  # gradient averaged between devices in DDP mode
                loss_cls *= args.world_size  # gradient averaged between devices in DDP mode
                loss_dfl *= args.world_size  # gradient averaged between devices in DDP mode

                # Backward
                amp_scale.scale(loss_box + loss_cls + loss_dfl).backward()

                # Optimize
                if step % accumulate == 0:
                    # amp_scale.unscale_(optimizer)  # unscale gradients
                    # util.clip_gradients(model)  # clip gradients
                    amp_scale.step(optimizer)  # optimizer.step
                    amp_scale.update()
                    optimizer.zero_grad()
                    if ema:
                        ema.update(model)

                torch.cuda.synchronize()

                # Log
                if args.local_rank == 0:
                    memory = f"{torch.cuda.memory_reserved() / 1e9:.4g}G"  # (GB)
                    s = ("%10s" * 2 + "%10.3g" * 3) % (
                        f"{epoch + 1}/{args.epochs}",
                        memory,
                        avg_box_loss.avg,
                        avg_cls_loss.avg,
                        avg_dfl_loss.avg,
                    )
                    p_bar = cast(tqdm.tqdm, p_bar)
                    p_bar.set_description(s)

            if args.local_rank == 0:
                # mAP
                last = test(args, params, ema.ema if ema else None)

                logger.writerow(
                    {
                        "epoch": str(epoch + 1).zfill(3),
                        "box": f"{avg_box_loss.avg:.3f}",
                        "cls": f"{avg_cls_loss.avg:.3f}",
                        "dfl": f"{avg_dfl_loss.avg:.3f}",
                        "mAP": f"{last[0]:.3f}",
                        "mAP@50": f"{last[1]:.3f}",
                        "Recall": f"{last[2]:.3f}",
                        "Precision": f"{last[3]:.3f}",
                    }
                )
                log.flush()

                # Update best mAP
                if last[0] > best:
                    best = last[0]

                # Save model
                save = {"epoch": epoch + 1, "model": copy.deepcopy(ema.ema if ema else None)}

                # Save last, best and delete
                torch.save(save, f="./weights/last.pt")
                if best == last[0]:
                    torch.save(save, f="./weights/best.pt")
                del save

    if args.local_rank == 0:
        util.strip_optimizer("./weights/best.pt")  # strip optimizers
        util.strip_optimizer("./weights/last.pt")  # strip optimizers


@torch.no_grad()
def test(args, params, model=None):
    filenames = []
    with open(f"{data_dir}/val2017.txt") as f:
        for filename in f.readlines():
            filename = os.path.basename(filename.rstrip())
            filenames.append(f"{data_dir}/images/val2017/" + filename)

    dataset = Dataset(filenames, args.input_size, params, augment=False)
    loader = data.DataLoader(
        dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=True, collate_fn=Dataset.collate_fn
    )

    plot = False
    if not model:
        plot = True
        model = torch.load(f="./weights/best.pt", map_location="cuda", weights_only=False)
        model = model["model"].float().fuse()
    model.half()
    model.eval()

    # Configure
    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # IoU vector for mAP@0.5:0.95
    n_iou = iou_v.numel()

    m_pre = 0
    m_rec = 0
    map50 = 0
    mean_ap = 0
    metrics = []
    p_bar = tqdm.tqdm(loader, desc=("%10s" * 5) % ("", "precision", "recall", "mAP50", "mAP"), ascii=" >-")
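    # Each image contributes a (num_predictions, n_iou) boolean hit matrix,
    # one column per IoU threshold from 0.50 to 0.95.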
    for samples, targets in p_bar:
        samples = samples.cuda()
        samples = samples.half()  # uint8 to fp16
        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
        _, _, h, w = samples.shape  # batch size, channels, height, width
        scale = torch.tensor((w, h, w, h)).cuda()
        # Inference
        outputs = model(samples)
        # NMS
        outputs = util.non_max_suppression(outputs)
        # Metrics
        for i, output in enumerate(outputs):
            # Ensure idx is a 1D boolean mask (squeeze any trailing dimension)
            # to match the cls/box shapes; the original `idx = targets["idx"] == i`
            # caused a shape mismatch when idx carried an extra dimension.
            idx = targets["idx"]
            if idx.dim() > 1:
                idx = idx.squeeze(-1)
            idx = idx == i

            cls = targets["cls"][idx]
            box = targets["box"][idx]

            cls = cls.cuda()
            box = box.cuda()
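
            # Collect (hit matrix, confidence, predicted class, true class) per
            # image so AP can be computed over the whole split afterwards.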
            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()

            if output.shape[0] == 0:
                if cls.shape[0]:
                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
                continue
            # Evaluate
            if cls.shape[0]:
                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
                metric = util.compute_metric(output[:, :6], target, iou_v)
            # Append
            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

    # Compute metrics
    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=plot, names=params["names"])
    # Print results
    print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
    # Return results
    model.float()  # for training
    return mean_ap, map50, m_rec, m_pre


def profile(args, params):
    import thop

    shape = (1, 3, args.input_size, args.input_size)
    model = nn.yolo_v11_n(len(params["names"])).fuse()

    model.eval()
    model(torch.zeros(shape))

    x = torch.empty(shape)
    flops, num_params = thop.profile(model, inputs=[x], verbose=False)
    # thop reports multiply-accumulates, so doubling converts MACs to FLOPs
    flops, num_params = thop.clever_format(nums=[2 * flops, num_params], format="%.3f")

    if args.local_rank == 0:
        print(f"Number of parameters: {num_params}")
        print(f"Number of FLOPs: {flops}")


def main():
    parser = ArgumentParser()
    parser.add_argument("--input-size", default=640, type=int)
    parser.add_argument("--batch-size", default=32, type=int)
    # Both spellings are kept so old and new torch.distributed launchers can
    # pass the rank either way; the LOCAL_RANK environment variable wins below.
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--test", action="store_true")

    args = parser.parse_args()

    args.local_rank = int(os.getenv("LOCAL_RANK", 0))
    args.world_size = int(os.getenv("WORLD_SIZE", 1))
    args.distributed = args.world_size > 1
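    # LOCAL_RANK and WORLD_SIZE are set by the launcher (e.g. torchrun), so
    # the environment values take precedence over the parsed defaults.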

    if args.distributed:
        torch.cuda.set_device(device=args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if args.local_rank == 0:
        if not os.path.exists("weights"):
            os.makedirs("weights")

    with open("utils/args.yaml", errors="ignore") as f:
        params = yaml.safe_load(f)

    util.setup_seed()
    util.setup_multi_processes()

    profile(args, params)

    if args.train:
        train(args, params)
    if args.test:
        test(args, params)

    # Clean
    if args.distributed:
        torch.distributed.destroy_process_group()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()