363 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			363 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
This file contains the model definition of YOLOv11
 | 
						|
"""
 | 
						|
 | 
						|
import math
 | 
						|
 | 
						|
import torch
 | 
						|
 | 
						|
from utils.util import make_anchors
 | 
						|
 | 
						|
 | 
						|
def fuse_conv(conv, norm):
    """Fold a BatchNorm2d layer into the convolution that precedes it.

    The returned layer reproduces ``norm(conv(x))`` for the case where
    ``norm`` normalises with its stored running statistics (eval mode).

    Args:
        conv: the ``torch.nn.Conv2d`` to fuse; its bias may be ``None``.
        norm: the ``torch.nn.BatchNorm2d`` applied directly after ``conv``
            (must have affine ``weight``/``bias`` and running statistics).

    Returns:
        A new ``torch.nn.Conv2d`` (always with a bias) on the device of
        ``conv.weight`` whose parameters do not require gradients.

    Note:
        ``padding_mode`` is not copied from ``conv``; the fused layer uses
        the ``Conv2d`` default.
    """
    fused_conv = (
        torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,  # keep dilated convolutions equivalent
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # The fused parameters are constants derived from the trained layers.
    # Computing them without autograd keeps them leaf tensors that do not
    # require grad, even when the source weights do.
    with torch.no_grad():
        # Per-output-channel factor gamma / sqrt(running_var + eps).
        scale = norm.weight / torch.sqrt(norm.running_var + norm.eps)

        # Scaling each output channel equals multiplying by diag(scale).
        w_conv = conv.weight.reshape(conv.out_channels, -1)
        fused_conv.weight.copy_((scale.view(-1, 1) * w_conv).view(fused_conv.weight.size()))

        if conv.bias is None:
            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
        else:
            b_conv = conv.bias
        # beta - gamma * running_mean / sqrt(running_var + eps)
        b_norm = norm.bias - norm.running_mean * scale
        fused_conv.bias.copy_(scale * b_conv + b_norm)

    return fused_conv
 | 
						|
 | 
						|
 | 
						|
class Conv(torch.nn.Module):
    """Bias-free convolution followed by batch normalisation and an activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        activation: module applied after normalisation (stored as ``relu``).
        k: kernel size.
        s: stride.
        p: padding.
        g: number of convolution groups.
    """

    def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            groups=g,
            bias=False,
        )
        self.norm = torch.nn.BatchNorm2d(num_features=out_ch, eps=0.001, momentum=0.03)
        self.relu = activation

    def forward(self, x):
        """Apply convolution, batch normalisation, then the activation."""
        features = self.conv(x)
        normalised = self.norm(features)
        return self.relu(normalised)

    def fuse_forward(self, x):
        """Apply convolution and activation only, skipping ``norm``."""
        features = self.conv(x)
        return self.relu(features)
 | 
						|
 | 
						|
 | 
						|
class Residual(torch.nn.Module):
    """Two stacked 3x3 ``Conv`` blocks with an identity skip connection.

    Args:
        ch: number of input and output channels.
        e: ratio giving the hidden width, ``int(ch * e)``.
    """

    def __init__(self, ch, e=0.5):
        super().__init__()
        hidden = int(ch * e)
        self.conv1 = Conv(ch, hidden, torch.nn.SiLU(), k=3, p=1)
        self.conv2 = Conv(hidden, ch, torch.nn.SiLU(), k=3, p=1)

    def forward(self, x):
        """Return ``x`` plus the output of the two convolutions."""
        branch = self.conv1(x)
        branch = self.conv2(branch)
        return x + branch
 | 
						|
 | 
						|
 | 
						|
class CSPModule(torch.nn.Module):
    """Two-branch block: one branch runs two residual units, the other is a 1x1 conv.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels; each branch is ``out_ch // 2`` wide.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        half = out_ch // 2
        self.conv1 = Conv(in_ch, half, torch.nn.SiLU())
        self.conv2 = Conv(in_ch, half, torch.nn.SiLU())
        self.conv3 = Conv(2 * half, out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(*(Residual(half, e=1.0) for _ in range(2)))

    def forward(self, x):
        """Concatenate both branches on the channel axis and project with ``conv3``."""
        deep = self.res_m(self.conv1(x))
        shortcut = self.conv2(x)
        return self.conv3(torch.cat((deep, shortcut), dim=1))
 | 
						|
 | 
						|
 | 
						|
class CSP(torch.nn.Module):
    """Split-and-aggregate block that chains ``n`` inner modules.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        n: number of inner modules.
        csp: if truthy the inner modules are ``CSPModule``, otherwise ``Residual``.
        r: divisor giving the hidden width, ``out_ch // r``.
    """

    def __init__(self, in_ch, out_ch, n, csp, r):
        super().__init__()
        hidden = out_ch // r
        self.conv1 = Conv(in_ch, 2 * hidden, torch.nn.SiLU())
        # forward concatenates the two halves of conv1 plus one output per inner module
        self.conv2 = Conv((2 + n) * hidden, out_ch, torch.nn.SiLU())

        if csp:
            blocks = [CSPModule(hidden, hidden) for _ in range(n)]
        else:
            blocks = [Residual(hidden) for _ in range(n)]
        self.res_m = torch.nn.ModuleList(blocks)

    def forward(self, x):
        """Feed each inner module the newest feature map and fuse all of them."""
        branches = [*self.conv1(x).chunk(2, 1)]
        for block in self.res_m:
            branches.append(block(branches[-1]))
        return self.conv2(torch.cat(branches, dim=1))
 | 
						|
 | 
						|
 | 
						|
class SPP(torch.nn.Module):
    """Pooling block that applies the same max-pool three times in sequence.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: max-pool kernel size (stride 1, padding ``k // 2``).
    """

    def __init__(self, in_ch, out_ch, k=5):
        super().__init__()
        hidden = in_ch // 2
        self.conv1 = Conv(in_ch, hidden, torch.nn.SiLU())
        # forward concatenates four ``hidden``-wide tensors, so the input width
        # is derived from ``hidden``. This equals ``in_ch * 2`` for even
        # ``in_ch`` and stays correct when ``in_ch`` is odd.
        self.conv2 = Conv(4 * hidden, out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        """Concatenate the reduced input with three successive poolings of it."""
        x = self.conv1(x)
        y1 = self.res_m(x)
        y2 = self.res_m(y1)
        return self.conv2(torch.cat(tensors=[x, y1, y2, self.res_m(y2)], dim=1))
 | 
						|
 | 
						|
 | 
						|
class Attention(torch.nn.Module):
    """Multi-head self-attention over spatial positions, built from ``Conv`` blocks.

    Args:
        ch: number of input and output channels.
        num_head: number of attention heads; the per-head value width is
            ``ch // num_head`` and the key/query width is half of that.
    """

    def __init__(self, ch, num_head):
        super().__init__()
        self.num_head = num_head
        self.dim_head = ch // num_head
        self.dim_key = self.dim_head // 2
        self.scale = self.dim_key**-0.5

        # One projection yields queries, keys and values for all heads.
        qkv_channels = ch + self.dim_key * num_head * 2
        self.qkv = Conv(ch, qkv_channels, torch.nn.Identity())

        self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch)
        self.conv2 = Conv(ch, ch, torch.nn.Identity())

    def forward(self, x):
        """Attend over the ``h * w`` positions of a 4-D input and project the result."""
        b, c, h, w = x.shape
        positions = h * w

        packed = self.qkv(x).view(b, self.num_head, 2 * self.dim_key + self.dim_head, positions)
        q, k, v = torch.split(packed, [self.dim_key, self.dim_key, self.dim_head], dim=2)

        scores = torch.matmul(q.transpose(-2, -1), k) * self.scale
        weights = torch.softmax(scores, dim=-1)

        attended = torch.matmul(v, weights.transpose(-2, -1)).view(b, c, h, w)
        # A depthwise 3x3 convolution of the values is added to the attention output.
        local = self.conv1(v.reshape(b, c, h, w))
        return self.conv2(attended + local)
 | 
						|
 | 
						|
 | 
						|
class PSABlock(torch.nn.Module):
    """Attention followed by a two-layer 1x1 conv block, each with a skip connection.

    Args:
        ch: number of input and output channels.
        num_head: number of attention heads passed to ``Attention``.
    """

    def __init__(self, ch, num_head):
        super().__init__()
        self.conv1 = Attention(ch, num_head)
        expand = Conv(ch, ch * 2, torch.nn.SiLU())
        project = Conv(ch * 2, ch, torch.nn.Identity())
        self.conv2 = torch.nn.Sequential(expand, project)

    def forward(self, x):
        """Add the attention output to ``x``, then add the conv-block output."""
        attended = x + self.conv1(x)
        return attended + self.conv2(attended)
 | 
						|
 | 
						|
 | 
						|
class PSA(torch.nn.Module):
    """Split the channels in two and run ``PSABlock`` modules on one half.

    Args:
        ch: number of input and output channels.
        n: number of stacked ``PSABlock`` modules.
    """

    def __init__(self, ch, n):
        super().__init__()
        half = ch // 2
        self.conv1 = Conv(ch, 2 * half, torch.nn.SiLU())
        self.conv2 = Conv(2 * half, ch, torch.nn.SiLU())
        blocks = [PSABlock(half, ch // 128) for _ in range(n)]
        self.res_m = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Refine the second half with ``res_m`` and fuse it with the untouched first half."""
        untouched, refined = self.conv1(x).chunk(2, 1)
        refined = self.res_m(refined)
        return self.conv2(torch.cat(tensors=(untouched, refined), dim=1))
 | 
						|
 | 
						|
 | 
						|
class DarkNet(torch.nn.Module):
    """Backbone of five stride-2 stages (``p1`` .. ``p5``).

    Args:
        width: channel widths, indexed 0..5.
        depth: repeat counts; indices 0..4 are used here.
        csp: two flags choosing the inner module type of the ``CSP`` blocks.
    """

    def __init__(self, width, depth, csp):
        super().__init__()

        def downsample(in_ch, out_ch):
            # 3x3 stride-2 convolution that opens every stage
            return Conv(in_ch, out_ch, torch.nn.SiLU(), k=3, s=2, p=1)

        # p1/2
        self.p1 = torch.nn.Sequential(downsample(width[0], width[1]))
        # p2/4
        self.p2 = torch.nn.Sequential(
            downsample(width[1], width[2]),
            CSP(width[2], width[3], depth[0], csp[0], r=4),
        )
        # p3/8
        self.p3 = torch.nn.Sequential(
            downsample(width[3], width[3]),
            CSP(width[3], width[4], depth[1], csp[0], r=4),
        )
        # p4/16
        self.p4 = torch.nn.Sequential(
            downsample(width[4], width[4]),
            CSP(width[4], width[4], depth[2], csp[1], r=2),
        )
        # p5/32
        self.p5 = torch.nn.Sequential(
            downsample(width[4], width[5]),
            CSP(width[5], width[5], depth[3], csp[1], r=2),
            SPP(width[5], width[5]),
            PSA(width[5], depth[4]),
        )

    def forward(self, x):
        """Return the outputs of stages ``p3``, ``p4`` and ``p5``."""
        stem = self.p2(self.p1(x))
        p3 = self.p3(stem)
        p4 = self.p4(p3)
        return p3, p4, self.p5(p4)
 | 
						|
 | 
						|
 | 
						|
class DarkFPN(torch.nn.Module):
    """Neck that fuses three feature maps top-down and then bottom-up.

    Args:
        width: channel widths; indices 3..5 are used here.
        depth: repeat counts; only index 5 is used here.
        csp: two flags choosing the inner module type of the ``CSP`` blocks.
    """

    def __init__(self, width, depth, csp):
        super().__init__()
        w3, w4, w5 = width[3], width[4], width[5]
        repeats = depth[5]

        self.up = torch.nn.Upsample(scale_factor=2)
        self.h1 = CSP(w4 + w5, w4, repeats, csp[0], r=2)
        self.h2 = CSP(w4 + w4, w3, repeats, csp[0], r=2)
        self.h3 = Conv(w3, w3, torch.nn.SiLU(), k=3, s=2, p=1)
        self.h4 = CSP(w3 + w4, w4, repeats, csp[0], r=2)
        self.h5 = Conv(w4, w4, torch.nn.SiLU(), k=3, s=2, p=1)
        self.h6 = CSP(w4 + w5, w5, repeats, csp[1], r=2)

    def forward(self, x):
        """Take ``(p3, p4, p5)`` and return the three fused maps in the same order."""
        c3, c4, c5 = x

        # top-down: upsample the coarser map and merge it into the finer one
        mid = self.h1(torch.cat(tensors=[self.up(c5), c4], dim=1))
        fine = self.h2(torch.cat(tensors=[self.up(mid), c3], dim=1))

        # bottom-up: downsample the finer map and merge it into the coarser one
        mid_out = self.h4(torch.cat(tensors=[self.h3(fine), mid], dim=1))
        coarse = self.h6(torch.cat(tensors=[self.h5(mid_out), c5], dim=1))
        return fine, mid_out, coarse
 | 
						|
 | 
						|
 | 
						|
class DFL(torch.nn.Module):
    """Turn per-side bin logits into the expected bin index.

    Reference given by the original author: Generalized Focal Loss,
    https://ieeexplore.ieee.org/document/9792391

    Args:
        ch: number of bins per box side.
    """

    def __init__(self, ch=16):
        super().__init__()
        self.ch = ch
        projection = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False)
        projection.requires_grad_(False)
        # Fixed weights 0..ch-1: the 1x1 conv computes sum(i * p_i) over the bins.
        bin_index = torch.arange(ch, dtype=torch.float)
        projection.weight.data.copy_(bin_index.view(1, ch, 1, 1))
        self.conv = projection

    def forward(self, x):
        """Map a ``(b, 4 * ch, a)`` tensor to ``(b, 4, a)`` expected bin indices."""
        b, _, a = x.shape
        # Move the bin axis to dim 1 so softmax and the conv both act on it.
        distribution = x.view(b, 4, self.ch, a).transpose(2, 1).softmax(dim=1)
        return self.conv(distribution).view(b, 4, a)
 | 
						|
 | 
						|
 | 
						|
class Head(torch.nn.Module):
    """Detection head with a box branch and a class branch per feature level.

    Args:
        nc: number of classes.
        filters: channel count of each input feature map, one entry per
            detection level (must be non-empty; ``filters[0]`` sizes the
            hidden widths).
    """

    # Placeholders; every inference-mode forward pass sets these on the instance.
    anchors = torch.empty(0)
    strides = torch.empty(0)

    def __init__(self, nc=80, filters=()):
        super().__init__()
        self.ch = 16  # DFL channels
        self.nc = nc  # number of classes
        self.nl = len(filters)  # number of detection layers
        self.no = nc + self.ch * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build

        # hidden widths of the two branches
        box = max(64, filters[0] // 4)
        cls = max(80, filters[0], self.nc)

        self.dfl = DFL(self.ch)
        self.box = torch.nn.ModuleList(
            torch.nn.Sequential(
                Conv(x, box, torch.nn.SiLU(), k=3, p=1),
                Conv(box, box, torch.nn.SiLU(), k=3, p=1),
                torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1),
            )
            for x in filters
        )
        self.cls = torch.nn.ModuleList(
            torch.nn.Sequential(
                Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x),
                Conv(x, cls, torch.nn.SiLU()),
                Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls),
                Conv(cls, cls, torch.nn.SiLU()),
                torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1),
            )
            for x in filters
        )

    def forward(self, x):
        """Predict boxes and class scores from a sequence of feature maps.

        Args:
            x: indexable sequence of feature maps, one per detection level.
                It is not modified.

        Returns:
            In training mode, a list with one tensor per level holding the
            box and class outputs concatenated on dim 1. Otherwise a single
            tensor with decoded boxes (centre and size, multiplied by the
            per-anchor stride) followed by sigmoid class scores on dim 1.
        """
        # Work on a copy so the caller's list keeps its original feature maps.
        outputs = list(x)
        for i, (box, cls) in enumerate(zip(self.box, self.cls)):
            outputs[i] = torch.cat(tensors=(box(outputs[i]), cls(outputs[i])), dim=1)
        if self.training:
            return outputs

        # NOTE(review): make_anchors is defined in utils.util; presumably it
        # returns per-anchor points and strides. Confirm against that module.
        self.anchors, self.strides = (i.transpose(0, 1) for i in make_anchors(outputs, self.stride))
        batch = outputs[0].shape[0]
        flat = torch.cat([i.view(batch, self.no, -1) for i in outputs], dim=2)
        box, cls = flat.split(split_size=(4 * self.ch, self.nc), dim=1)

        # The DFL output is split into two halves that are subtracted from and
        # added to the anchor points, giving two opposite box corners.
        a, b = self.dfl(box).chunk(2, 1)
        a = self.anchors.unsqueeze(0) - a
        b = self.anchors.unsqueeze(0) + b
        # corners -> (centre, size)
        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)

        return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1)

    def initialize_biases(self):
        """Set the biases of the final box and class convolutions.

        WARNING: requires ``self.stride`` to hold the real, non-zero strides.
        """
        for box, cls, s in zip(self.box, self.cls, self.stride):
            # box
            box[-1].bias.data[:] = 1.0
            # cls (.01 objects, 80 classes, 640 image)
            cls[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)
 | 
						|
 | 
						|
 | 
						|
class YOLO(torch.nn.Module):
    """Full detector: ``DarkNet`` backbone, ``DarkFPN`` neck and ``Head``.

    Args:
        width: channel widths, indexed 0..5 (``width[0]`` is the input channel count).
        depth: repeat counts, indexed 0..5.
        csp: two flags choosing the inner module type of the ``CSP`` blocks.
        num_classes: number of classes predicted by the head.
    """

    def __init__(self, width, depth, csp, num_classes):
        super().__init__()
        self.net = DarkNet(width, depth, csp)
        self.fpn = DarkFPN(width, depth, csp)

        img_dummy = torch.zeros(1, width[0], 256, 256)
        self.head = Head(num_classes, (width[3], width[4], width[5]))
        # The module is still in training mode here, so the head returns one
        # tensor per level. Only the output heights are needed, so the dummy
        # pass runs without building an autograd graph.
        with torch.no_grad():
            outputs = self.forward(img_dummy)
        self.head.stride = torch.tensor([256 / x.shape[-2] for x in outputs])
        self.stride = self.head.stride
        self.head.initialize_biases()

    def forward(self, x):
        """Run backbone, neck and head; the head receives a fresh list."""
        x = self.net(x)
        x = self.fpn(x)
        return self.head(list(x))

    def fuse(self):
        """Fold each ``Conv`` block's batch norm into its convolution, in place.

        Fused blocks lose their ``norm`` attribute and use ``fuse_forward``.

        Returns:
            ``self``, so the call can be chained.
        """
        for m in self.modules():
            # hasattr guard: a block that was already fused has no ``norm``
            if type(m) is Conv and hasattr(m, "norm"):
                m.conv = fuse_conv(m.conv, m.norm)
                m.forward = m.fuse_forward
                delattr(m, "norm")
        return self
 | 
						|
 | 
						|
 | 
						|
def yolo_v11_n(num_classes: int = 80):
    """Build the "n" YOLOv11 variant.

    Args:
        num_classes: number of classes passed on to ``YOLO``.

    Returns:
        A ``YOLO`` model with widths ``[3, 16, 32, 64, 128, 256]``, every
        depth entry 1 and ``csp = [False, True]``.
    """
    return YOLO(
        width=[3, 16, 32, 64, 128, 256],
        depth=[1] * 6,
        csp=[False, True],
        num_classes=num_classes,
    )
 | 
						|
 | 
						|
 | 
						|
def yolo_v11_t(num_classes: int = 80):
    """Build the "t" YOLOv11 variant.

    Args:
        num_classes: number of classes passed on to ``YOLO``.

    Returns:
        A ``YOLO`` model with widths ``[3, 24, 48, 96, 192, 384]``, every
        depth entry 1 and ``csp = [False, True]``.
    """
    return YOLO(
        width=[3, 24, 48, 96, 192, 384],
        depth=[1] * 6,
        csp=[False, True],
        num_classes=num_classes,
    )
 | 
						|
 | 
						|
 | 
						|
def yolo_v11_s(num_classes: int = 80):
    """Build the "s" YOLOv11 variant.

    Args:
        num_classes: number of classes passed on to ``YOLO``.

    Returns:
        A ``YOLO`` model with widths ``[3, 32, 64, 128, 256, 512]``, every
        depth entry 1 and ``csp = [False, True]``.
    """
    return YOLO(
        width=[3, 32, 64, 128, 256, 512],
        depth=[1] * 6,
        csp=[False, True],
        num_classes=num_classes,
    )
 | 
						|
 | 
						|
 | 
						|
def yolo_v11_m(num_classes: int = 80):
    """Build the "m" YOLOv11 variant.

    Args:
        num_classes: number of classes passed on to ``YOLO``.

    Returns:
        A ``YOLO`` model with widths ``[3, 64, 128, 256, 512, 512]``, every
        depth entry 1 and ``csp = [True, True]``.
    """
    return YOLO(
        width=[3, 64, 128, 256, 512, 512],
        depth=[1] * 6,
        csp=[True, True],
        num_classes=num_classes,
    )
 | 
						|
 | 
						|
 | 
						|
def yolo_v11_l(num_classes: int = 80):
    """Build the "l" YOLOv11 variant.

    Args:
        num_classes: number of classes passed on to ``YOLO``.

    Returns:
        A ``YOLO`` model with widths ``[3, 64, 128, 256, 512, 512]``, every
        depth entry 2 and ``csp = [True, True]``.
    """
    return YOLO(
        width=[3, 64, 128, 256, 512, 512],
        depth=[2] * 6,
        csp=[True, True],
        num_classes=num_classes,
    )
 | 
						|
 | 
						|
 | 
						|
def yolo_v11_x(num_classes: int = 80):
    """Build the "x" YOLOv11 variant.

    Args:
        num_classes: number of classes passed on to ``YOLO``.

    Returns:
        A ``YOLO`` model with widths ``[3, 96, 192, 384, 768, 768]``, every
        depth entry 2 and ``csp = [True, True]``.
    """
    return YOLO(
        width=[3, 96, 192, 384, 768, 768],
        depth=[2] * 6,
        csp=[True, True],
        num_classes=num_classes,
    )
 |