""" Utility functions for yolo. """ import copy import random from time import time import math import numpy import torch import torchvision from torch.nn.functional import cross_entropy def setup_seed(): """ Setup random seed. """ random.seed(0) numpy.random.seed(0) torch.manual_seed(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True def setup_multi_processes(): """ Setup multi-processing environment variables. """ import cv2 from os import environ from platform import system # set multiprocess start method as `fork` to speed up the training if system() != "Windows": torch.multiprocessing.set_start_method("fork", force=True) # disable opencv multithreading to avoid system being overloaded cv2.setNumThreads(0) # setup OMP threads if "OMP_NUM_THREADS" not in environ: environ["OMP_NUM_THREADS"] = "1" # setup MKL threads if "MKL_NUM_THREADS" not in environ: environ["MKL_NUM_THREADS"] = "1" def export_onnx(args): import onnx # noqa inputs = ["images"] outputs = ["outputs"] dynamic = {"outputs": {0: "batch", 1: "anchors"}} m = torch.load("./weights/best.pt", weights_only=False)["model"].float() x = torch.zeros((1, 3, args.input_size, args.input_size)) torch.onnx.export( m.cpu(), (x.cpu(),), f="./weights/best.onnx", verbose=False, opset_version=12, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False do_constant_folding=True, input_names=inputs, output_names=outputs, dynamic_axes=dynamic or None, ) # Checks model_onnx = onnx.load("./weights/best.onnx") # load onnx model onnx.checker.check_model(model_onnx) # check onnx model onnx.save(model_onnx, "./weights/best.onnx") # Inference example # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py def wh2xy(x): y = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x) y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y return y def make_anchors(x, strides, offset=0.5): assert x is not None anchor_tensor, stride_tensor = [], [] dtype, device = x[0].dtype, x[0].device for i, stride in enumerate(strides): _, _, h, w = x[i].shape sx = torch.arange(end=w, device=device, dtype=dtype) + offset # shift x sy = torch.arange(end=h, device=device, dtype=dtype) + offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2)) stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) return torch.cat(anchor_tensor), torch.cat(stride_tensor) def compute_metric(output, target, iou_v): # intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) (a1, a2) = target[:, 1:].unsqueeze(1).chunk(2, 2) (b1, b2) = output[:, :4].unsqueeze(0).chunk(2, 2) intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2) # IoU = intersection / (area1 + area2 - intersection) iou = intersection / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - intersection + 1e-7) correct = numpy.zeros((output.shape[0], iou_v.shape[0])) correct = correct.astype(bool) for i in range(len(iou_v)): # IoU > threshold and classes match x = torch.where((iou >= iou_v[i]) & (target[:, 0:1] == output[:, 5])) if x[0].shape[0]: matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] if x[0].shape[0] > 1: matches = matches[matches[:, 2].argsort()[::-1]] matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]] matches = 


def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65):
    max_wh = 7680
    max_det = 300
    max_nms = 30000

    bs = outputs.shape[0]  # batch size
    nc = outputs.shape[1] - 4  # number of classes
    xc = outputs[:, 4 : 4 + nc].amax(1) > confidence_threshold  # candidates

    # Settings
    start = time()
    limit = 0.5 + 0.05 * bs  # seconds to quit after
    output = [torch.zeros((0, 6), device=outputs.device)] * bs
    for index, x in enumerate(outputs):  # image index, image inference
        x = x.transpose(0, -1)[xc[index]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # matrix nx6 (box, confidence, cls)
        box, cls = x.split((4, nc), 1)
        box = wh2xy(box)  # (cx, cy, w, h) to (x1, y1, x2, y2)
        if nc > 1:
            i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * max_wh  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes, scores
        indices = torchvision.ops.nms(boxes, scores, iou_threshold)  # NMS
        indices = indices[:max_det]  # limit detections

        output[index] = x[indices]
        if (time() - start) > limit:
            break  # time limit exceeded

    return output


def smooth(y, f=0.1):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = numpy.ones(nf // 2)  # ones padding
    yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return numpy.convolve(yp, numpy.ones(nf) / nf, mode="valid")  # y-smoothed


def plot_pr_curve(px, py, ap, names, save_dir):
    from matplotlib import pyplot

    fig, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = numpy.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    ax.plot(
        px,
        py.mean(1),
        linewidth=3,
        color="blue",
        label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
    )
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    pyplot.close(fig)


def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
    from matplotlib import pyplot

    figure, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f"{names[i]}")  # plot(confidence, metric)
    else:
        ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)

    y = smooth(py.mean(0), f=0.05)
    ax.plot(
        px,
        y,
        linewidth=3,
        color="blue",
        label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
    )
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title(f"{y_label}-Confidence Curve")
    figure.savefig(save_dir, dpi=250)
    pyplot.close(figure)
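

# Minimal usage sketch (illustrative only): non_max_suppression expects the raw head output
# of shape (batch, 4 + num_classes, num_anchors) with boxes as (cx, cy, w, h), and returns
# one (num_detections, 6) tensor of (x1, y1, x2, y2, confidence, class) per image.
def _example_nms():
    nc = 80
    outputs = torch.rand(2, 4 + nc, 8400)  # 8400 anchors, e.g. a 640x640 input at strides 8/16/32
    results = non_max_suppression(outputs, confidence_threshold=0.25, iou_threshold=0.65)
    return [r.shape for r in results]  # one torch.Size([n_i, 6]) per image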


def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
    """
    Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (nparray, nx1 or nx10).
        conf: Object-ness value from 0-1 (nparray).
        output: Predicted object classes (nparray).
        target: True object classes (nparray).
    # Returns
        The average precision
    """
    # Sort by object-ness
    i = numpy.argsort(-conf)
    tp, conf, output = tp[i], conf[i], output[i]

    # Find unique classes
    unique_classes, nt = numpy.unique(target, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    p = numpy.zeros((nc, 1000))
    r = numpy.zeros((nc, 1000))
    ap = numpy.zeros((nc, tp.shape[1]))
    px, py = numpy.linspace(start=0, stop=1, num=1000), []  # for plotting
    for ci, c in enumerate(unique_classes):
        i = output == c
        nl = nt[ci]  # number of labels
        no = i.sum()  # number of outputs
        if no == 0 or nl == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (nl + eps)  # recall curve
        # negative x, xp because xp decreases
        r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0)

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
            m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))

            # Compute the precision envelope
            m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))

            # Integrate area under curve
            x = numpy.linspace(start=0, stop=1, num=101)  # 101-point interp (COCO)
            # numpy.trapz is deprecated since numpy 2.0; use numpy.trapezoid instead
            ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x)  # integrate
            if plot and j == 0:
                py.append(numpy.interp(px, m_rec, m_pre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    if plot:
        names = dict(enumerate(names))  # to dict
        names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
        plot_pr_curve(px, py, ap, names, save_dir="./weights/PR_curve.png")
        plot_curve(px, f1, names, save_dir="./weights/F1_curve.png", y_label="F1")
        plot_curve(px, p, names, save_dir="./weights/P_curve.png", y_label="Precision")
        plot_curve(px, r, names, save_dir="./weights/R_curve.png", y_label="Recall")

    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    m_pre, m_rec = p.mean(), r.mean()
    map50, mean_ap = ap50.mean(), ap.mean()
    return tp, fp, m_pre, m_rec, map50, mean_ap


def compute_iou(box1, box2, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * (
        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
    ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
    c2 = cw**2 + ch**2 + eps  # convex diagonal squared
    rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (rho2 / c2 + v * alpha)  # CIoU
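

# Minimal worked example (illustrative only): compute_iou returns the Complete-IoU of one box
# against each of n boxes, all in (x1, y1, x2, y2) format. Identical boxes score ~1.0; the
# offset box scores ~0.03 because the center-distance penalty is subtracted from its ~0.14 IoU.
def _example_ciou():
    box1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    box2 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
    return compute_iou(box1, box2)  # approximately [[1.00], [0.03]]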


def strip_optimizer(filename):
    x = torch.load(filename, map_location="cpu", weights_only=False)
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False
    torch.save(x, f=filename)


def clip_gradients(model, max_norm=10.0):
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)


def load_weight(model, ckpt):
    dst = model.state_dict()
    src = torch.load(ckpt, weights_only=False)["model"].float().cpu()
    ckpt = {}
    for k, v in src.state_dict().items():
        if k in dst and v.shape == dst[k].shape:
            ckpt[k] = v
    model.load_state_dict(state_dict=ckpt, strict=False)
    return model


def set_params(model, decay):
    p1 = []
    p2 = []
    norm = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)
    for m in model.modules():
        for n, p in m.named_parameters(recurse=0):
            if not p.requires_grad:
                continue
            if n == "bias":  # bias (no decay)
                p1.append(p)
            elif n == "weight" and isinstance(m, norm):  # norm-weight (no decay)
                p1.append(p)
            else:
                p2.append(p)  # weight (with decay)
    return [{"params": p1, "weight_decay": 0.00}, {"params": p2, "weight_decay": decay}]


def plot_lr(args, optimizer, scheduler, num_steps):
    from matplotlib import pyplot

    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)
    y = []
    for epoch in range(args.epochs):
        for i in range(num_steps):
            step = i + num_steps * epoch
            scheduler.step(step, optimizer)
            y.append(optimizer.param_groups[0]["lr"])
    pyplot.plot(y, ".-", label="LR")
    pyplot.xlabel("step")
    pyplot.ylabel("LR")
    pyplot.grid()
    pyplot.xlim(0, args.epochs * num_steps)
    pyplot.ylim(0)
    pyplot.savefig("./weights/lr.png", dpi=200)
    pyplot.close()


class CosineLR:
    def __init__(self, args, params, num_steps):
        max_lr = params["max_lr"]
        min_lr = params["min_lr"]

        warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps))

        decay_lr = []
        for step in range(1, decay_steps + 1):
            alpha = math.cos(math.pi * step / decay_steps)
            decay_lr.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + alpha))

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.total_lr[step]


class LinearLR:
    def __init__(self, args, params, num_steps):
        max_lr = params["max_lr"]
        min_lr = params["min_lr"]

        warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
        decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.total_lr[step]


class EMA:
    """
    Updated Exponential Moving Average (EMA) from
    https://github.com/rwightman/pytorch-image-models

    Keeps a moving average of everything in the model state_dict (parameters and buffers).
    For EMA details see
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        if hasattr(model, "module"):
            model = model.module

        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()
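

# Minimal usage sketch (illustrative only): EMA keeps a smoothed copy of the model that is
# updated after every optimizer step; the averaged weights are typically the ones validated
# and exported.
def _example_ema_update():
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    ema = EMA(model, decay=0.9999, tau=2000)

    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = (model(x) - y).pow(2).mean()
    loss.backward()
    optimizer.step()

    ema.update(model)  # blend the EMA weights toward the freshly updated model
    return ema.ema  # the averaged model (in eval mode, gradients disabled)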


class AverageMeter:
    def __init__(self):
        self.num = 0
        self.sum = 0
        self.avg = 0

    def update(self, v, n):
        if not math.isnan(float(v)):
            self.num = self.num + n
            self.sum = self.sum + v * n
            self.avg = self.sum / self.num


class Assigner(torch.nn.Module):
    def __init__(self, nc=80, top_k=13, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.top_k = top_k
        self.nc = nc
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        batch_size = pd_scores.size(0)
        num_max_boxes = gt_bboxes.size(1)

        if num_max_boxes == 0:
            device = gt_bboxes.device
            return (
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )

        num_anchors = anc_points.shape[0]
        shape = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
        mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
        na = pd_bboxes.shape[-2]
        gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros(
            [batch_size, num_max_boxes, na],
            dtype=pd_bboxes.dtype,
            device=pd_bboxes.device,
        )
        bbox_scores = torch.zeros(
            [batch_size, num_max_boxes, na],
            dtype=pd_scores.dtype,
            device=pd_scores.device,
        )

        ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask]  # b, max_num_obj, h*w

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
        overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)

        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        if top_k_mask is None:
            top_k_mask = (top_k_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(top_k_indices)
        top_k_indices.masked_fill_(~top_k_mask, 0)

        mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
        ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
        for k in range(self.top_k):
            mask_top_k.scatter_add_(-1, top_k_indices[:, :, k : k + 1], ones)
        mask_top_k.masked_fill_(mask_top_k > 1, 0)
        mask_top_k = mask_top_k.to(align_metric.dtype)
        mask_pos = mask_top_k * mask_in_gts * mask_gt

        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)

        # Assigned target
        index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_index = target_gt_idx + index * num_max_boxes
        target_labels = gt_labels.long().flatten()[target_index]

        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index]

        # Assigned target scores
        target_labels.clamp_(0)
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.nc),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_bboxes, target_scores, fg_mask.bool()


class QFL(torch.nn.Module):
    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        bce_loss = self.bce_loss(outputs, targets)
        return torch.pow(torch.abs(targets - outputs.sigmoid()), self.beta) * bce_loss


class VFL(torch.nn.Module):
    def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True):
        super().__init__()
        assert alpha >= 0.0
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        assert outputs.size() == targets.size()
        targets = targets.type_as(outputs)

        if self.iou_weighted:
            focal_weight = (
                targets * (targets > 0.0).float()
                + self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * (targets <= 0.0).float()
            )
        else:
            focal_weight = (targets > 0.0).float() + self.alpha * (outputs.sigmoid() - targets).abs().pow(
                self.gamma
            ) * (targets <= 0.0).float()

        return self.bce_loss(outputs, targets) * focal_weight


class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        loss = self.bce_loss(outputs, targets)

        if self.alpha > 0:
            alpha_factor = targets * self.alpha + (1 - targets) * (1 - self.alpha)
            loss *= alpha_factor

        if self.gamma > 0:
            outputs_sigmoid = outputs.sigmoid()
            p_t = targets * outputs_sigmoid + (1 - targets) * (1 - outputs_sigmoid)
            gamma_factor = (1.0 - p_t) ** self.gamma
            loss *= gamma_factor

        return loss
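

# Minimal usage sketch (illustrative only): QFL, VFL and FocalLoss all take raw logits and
# soft targets in [0, 1] and return per-element losses (any reduction is left to the caller).
def _example_focal_losses():
    logits = torch.randn(4, 80)
    targets = torch.rand(4, 80)
    return (
        QFL(beta=2.0)(logits, targets).mean(),
        VFL(alpha=0.75, gamma=2.0)(logits, targets).mean(),
        FocalLoss(alpha=0.25, gamma=1.5)(logits, targets).mean(),
    )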


class BoxLoss(torch.nn.Module):
    def __init__(self, dfl_ch):
        super().__init__()
        self.dfl_ch = dfl_ch

    def forward(
        self,
        pred_dist,
        pred_bboxes,
        anchor_points,
        target_bboxes,
        target_scores,
        target_scores_sum,
        fg_mask,
    ):
        # IoU loss
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_box = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        a, b = target_bboxes.chunk(2, -1)
        target = torch.cat((anchor_points - a, b - anchor_points), -1)
        target = target.clamp(0, self.dfl_ch - 0.01)
        loss_dfl = self.df_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), target[fg_mask])
        loss_dfl = (loss_dfl * weight).sum() / target_scores_sum

        return loss_box, loss_dfl

    @staticmethod
    def df_loss(pred_dist, target):
        # Distribution Focal Loss (DFL)
        # https://ieeexplore.ieee.org/document/9792391
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        left_loss = cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape)
        right_loss = cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape)
        return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True)
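

# Minimal usage sketch (illustrative only): df_loss spreads each continuous target distance
# over its two nearest integer bins and weights the cross-entropy by the distance to each bin
# (Distribution Focal Loss). Shapes: bin logits are (num_boxes * 4, num_bins), targets (num_boxes, 4).
def _example_df_loss():
    num_bins = 16
    pred_dist = torch.randn(2 * 4, num_bins)  # bin logits for 2 boxes x 4 box sides
    target = torch.rand(2, 4) * (num_bins - 1 - 0.01)  # continuous distances in [0, num_bins - 1)
    return BoxLoss.df_loss(pred_dist, target)  # (2, 1) per-box DFL loss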


class ComputeLoss:
    def __init__(self, model, params):
        if hasattr(model, "module"):
            model = model.module

        device = next(model.parameters()).device

        m = model.head  # Head() module

        self.params = params
        self.stride = m.stride
        self.nc = m.nc
        self.no = m.no
        self.reg_max = m.ch
        self.device = device

        self.box_loss = BoxLoss(m.ch - 1).to(device)
        self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self.assigner = Assigner(nc=self.nc, top_k=10, alpha=0.5, beta=6.0)

        self.project = torch.arange(m.ch, dtype=torch.float, device=device)

    def box_decode(self, anchor_points, pred_dist):
        b, a, c = pred_dist.shape
        pred_dist = pred_dist.view(b, a, 4, c // 4)
        pred_dist = pred_dist.softmax(3)
        pred_dist = pred_dist.matmul(self.project.type(pred_dist.dtype))
        lt, rb = pred_dist.chunk(2, -1)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        return torch.cat(tensors=(x1y1, x2y2), dim=-1)

    def __call__(self, outputs, targets):
        x = torch.cat([i.view(outputs[0].shape[0], self.no, -1) for i in outputs], dim=2)
        pred_distri, pred_scores = x.split(split_size=(self.reg_max * 4, self.nc), dim=1)

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        data_type = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0]
        anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5)

        idx = targets["idx"].view(-1, 1)
        cls = targets["cls"].view(-1, 1)
        box = targets["box"]

        targets = torch.cat((idx, cls, box), dim=1).to(self.device)
        if targets.shape[0] == 0:
            gt = torch.zeros(batch_size, 0, 5, device=self.device)
        else:
            i = targets[:, 0]
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            gt = torch.zeros(batch_size, counts.max(), 5, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    gt[j, :n] = targets[matches, 1:]
            x = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]])
            y = torch.empty_like(x)
            dw = x[..., 2] / 2  # half-width
            dh = x[..., 3] / 2  # half-height
            y[..., 0] = x[..., 0] - dw  # top left x
            y[..., 1] = x[..., 1] - dh  # top left y
            y[..., 2] = x[..., 0] + dw  # bottom right x
            y[..., 3] = x[..., 1] + dh  # bottom right y
            gt[..., 1:5] = y
        gt_labels, gt_bboxes = gt.split((1, 4), 2)
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        pred_bboxes = self.box_decode(anchor_points, pred_distri)
        assigned_targets = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
        target_bboxes, target_scores, fg_mask = assigned_targets

        target_scores_sum = max(target_scores.sum(), 1)

        loss_cls = self.cls_loss(pred_scores, target_scores.to(data_type)).sum() / target_scores_sum  # BCE

        # Box loss
        loss_box = torch.zeros(1, device=self.device)
        loss_dfl = torch.zeros(1, device=self.device)
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss_box, loss_dfl = self.box_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes,
                target_scores,
                target_scores_sum,
                fg_mask,
            )

        loss_box *= self.params["box"]  # box gain
        loss_cls *= self.params["cls"]  # cls gain
        loss_dfl *= self.params["dfl"]  # dfl gain

        return loss_box, loss_cls, loss_dfl