363 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			363 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								"""
							 | 
						||
| 
								 | 
							
								This file contains the model definition of YOLOv11
							 | 
						||
| 
								 | 
							
								"""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import math
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import torch
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from utils.util import make_anchors
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def fuse_conv(conv, norm):
							 | 
						||
| 
								 | 
							
								    fused_conv = (
							 | 
						||
| 
								 | 
							
								        torch.nn.Conv2d(
							 | 
						||
| 
								 | 
							
								            conv.in_channels,
							 | 
						||
| 
								 | 
							
								            conv.out_channels,
							 | 
						||
| 
								 | 
							
								            kernel_size=conv.kernel_size,
							 | 
						||
| 
								 | 
							
								            stride=conv.stride,
							 | 
						||
| 
								 | 
							
								            padding=conv.padding,
							 | 
						||
| 
								 | 
							
								            groups=conv.groups,
							 | 
						||
| 
								 | 
							
								            bias=True,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        .requires_grad_(False)
							 | 
						||
| 
								 | 
							
								        .to(conv.weight.device)
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    w_conv = conv.weight.clone().view(conv.out_channels, -1)
							 | 
						||
| 
								 | 
							
								    w_norm = torch.diag(norm.weight.div(torch.sqrt(norm.eps + norm.running_var)))
							 | 
						||
| 
								 | 
							
								    fused_conv.weight.copy_(torch.mm(w_norm, w_conv).view(fused_conv.weight.size()))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
							 | 
						||
| 
								 | 
							
								    b_norm = norm.bias - norm.weight.mul(norm.running_mean).div(torch.sqrt(norm.running_var + norm.eps))
							 | 
						||
| 
								 | 
							
								    if fused_conv.bias is not None:
							 | 
						||
| 
								 | 
							
								        fused_conv.bias.copy_(torch.mm(w_norm, b_conv.reshape(-1, 1)).reshape(-1) + b_norm)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return fused_conv
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Conv(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv = torch.nn.Conv2d(in_ch, out_ch, k, s, p, groups=g, bias=False)
							 | 
						||
| 
								 | 
							
								        self.norm = torch.nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03)
							 | 
						||
| 
								 | 
							
								        self.relu = activation
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        return self.relu(self.norm(self.conv(x)))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def fuse_forward(self, x):
							 | 
						||
| 
								 | 
							
								        return self.relu(self.conv(x))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Residual(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, ch, e=0.5):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv1 = Conv(ch, int(ch * e), torch.nn.SiLU(), k=3, p=1)
							 | 
						||
| 
								 | 
							
								        self.conv2 = Conv(int(ch * e), ch, torch.nn.SiLU(), k=3, p=1)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        return x + self.conv2(self.conv1(x))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class CSPModule(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, in_ch, out_ch):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv1 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.conv2 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.conv3 = Conv(2 * (out_ch // 2), out_ch, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.res_m = torch.nn.Sequential(Residual(out_ch // 2, e=1.0), Residual(out_ch // 2, e=1.0))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        y = self.res_m(self.conv1(x))
							 | 
						||
| 
								 | 
							
								        return self.conv3(torch.cat((y, self.conv2(x)), dim=1))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class CSP(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, in_ch, out_ch, n, csp, r):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv1 = Conv(in_ch, 2 * (out_ch // r), torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.conv2 = Conv((2 + n) * (out_ch // r), out_ch, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not csp:
							 | 
						||
| 
								 | 
							
								            self.res_m = torch.nn.ModuleList(Residual(out_ch // r) for _ in range(n))
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            self.res_m = torch.nn.ModuleList(CSPModule(out_ch // r, out_ch // r) for _ in range(n))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        y = list(self.conv1(x).chunk(2, 1))
							 | 
						||
| 
								 | 
							
								        y.extend(m(y[-1]) for m in self.res_m)
							 | 
						||
| 
								 | 
							
								        return self.conv2(torch.cat(y, dim=1))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SPP(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, in_ch, out_ch, k=5):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv1 = Conv(in_ch, in_ch // 2, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.conv2 = Conv(in_ch * 2, out_ch, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        x = self.conv1(x)
							 | 
						||
| 
								 | 
							
								        y1 = self.res_m(x)
							 | 
						||
| 
								 | 
							
								        y2 = self.res_m(y1)
							 | 
						||
| 
								 | 
							
								        return self.conv2(torch.cat(tensors=[x, y1, y2, self.res_m(y2)], dim=1))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Attention(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, ch, num_head):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.num_head = num_head
							 | 
						||
| 
								 | 
							
								        self.dim_head = ch // num_head
							 | 
						||
| 
								 | 
							
								        self.dim_key = self.dim_head // 2
							 | 
						||
| 
								 | 
							
								        self.scale = self.dim_key**-0.5
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.qkv = Conv(ch, ch + self.dim_key * num_head * 2, torch.nn.Identity())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch)
							 | 
						||
| 
								 | 
							
								        self.conv2 = Conv(ch, ch, torch.nn.Identity())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        b, c, h, w = x.shape
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        qkv = self.qkv(x)
							 | 
						||
| 
								 | 
							
								        qkv = qkv.view(b, self.num_head, self.dim_key * 2 + self.dim_head, h * w)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        q, k, v = qkv.split([self.dim_key, self.dim_key, self.dim_head], dim=2)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        attn = (q.transpose(-2, -1) @ k) * self.scale
							 | 
						||
| 
								 | 
							
								        attn = attn.softmax(dim=-1)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        x = (v @ attn.transpose(-2, -1)).view(b, c, h, w) + self.conv1(v.reshape(b, c, h, w))
							 | 
						||
| 
								 | 
							
								        return self.conv2(x)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class PSABlock(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, ch, num_head):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv1 = Attention(ch, num_head)
							 | 
						||
| 
								 | 
							
								        self.conv2 = torch.nn.Sequential(Conv(ch, ch * 2, torch.nn.SiLU()), Conv(ch * 2, ch, torch.nn.Identity()))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        x = x + self.conv1(x)
							 | 
						||
| 
								 | 
							
								        return x + self.conv2(x)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class PSA(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, ch, n):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.conv1 = Conv(ch, 2 * (ch // 2), torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.conv2 = Conv(2 * (ch // 2), ch, torch.nn.SiLU())
							 | 
						||
| 
								 | 
							
								        self.res_m = torch.nn.Sequential(*(PSABlock(ch // 2, ch // 128) for _ in range(n)))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        x, y = self.conv1(x).chunk(2, 1)
							 | 
						||
| 
								 | 
							
								        return self.conv2(torch.cat(tensors=(x, self.res_m(y)), dim=1))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class DarkNet(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, width, depth, csp):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.p1 = []
							 | 
						||
| 
								 | 
							
								        self.p2 = []
							 | 
						||
| 
								 | 
							
								        self.p3 = []
							 | 
						||
| 
								 | 
							
								        self.p4 = []
							 | 
						||
| 
								 | 
							
								        self.p5 = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # p1/2
							 | 
						||
| 
								 | 
							
								        self.p1.append(Conv(width[0], width[1], torch.nn.SiLU(), k=3, s=2, p=1))
							 | 
						||
| 
								 | 
							
								        # p2/4
							 | 
						||
| 
								 | 
							
								        self.p2.append(Conv(width[1], width[2], torch.nn.SiLU(), k=3, s=2, p=1))
							 | 
						||
| 
								 | 
							
								        self.p2.append(CSP(width[2], width[3], depth[0], csp[0], r=4))
							 | 
						||
| 
								 | 
							
								        # p3/8
							 | 
						||
| 
								 | 
							
								        self.p3.append(Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1))
							 | 
						||
| 
								 | 
							
								        self.p3.append(CSP(width[3], width[4], depth[1], csp[0], r=4))
							 | 
						||
| 
								 | 
							
								        # p4/16
							 | 
						||
| 
								 | 
							
								        self.p4.append(Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1))
							 | 
						||
| 
								 | 
							
								        self.p4.append(CSP(width[4], width[4], depth[2], csp[1], r=2))
							 | 
						||
| 
								 | 
							
								        # p5/32
							 | 
						||
| 
								 | 
							
								        self.p5.append(Conv(width[4], width[5], torch.nn.SiLU(), k=3, s=2, p=1))
							 | 
						||
| 
								 | 
							
								        self.p5.append(CSP(width[5], width[5], depth[3], csp[1], r=2))
							 | 
						||
| 
								 | 
							
								        self.p5.append(SPP(width[5], width[5]))
							 | 
						||
| 
								 | 
							
								        self.p5.append(PSA(width[5], depth[4]))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.p1 = torch.nn.Sequential(*self.p1)
							 | 
						||
| 
								 | 
							
								        self.p2 = torch.nn.Sequential(*self.p2)
							 | 
						||
| 
								 | 
							
								        self.p3 = torch.nn.Sequential(*self.p3)
							 | 
						||
| 
								 | 
							
								        self.p4 = torch.nn.Sequential(*self.p4)
							 | 
						||
| 
								 | 
							
								        self.p5 = torch.nn.Sequential(*self.p5)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        p1 = self.p1(x)
							 | 
						||
| 
								 | 
							
								        p2 = self.p2(p1)
							 | 
						||
| 
								 | 
							
								        p3 = self.p3(p2)
							 | 
						||
| 
								 | 
							
								        p4 = self.p4(p3)
							 | 
						||
| 
								 | 
							
								        p5 = self.p5(p4)
							 | 
						||
| 
								 | 
							
								        return p3, p4, p5
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class DarkFPN(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, width, depth, csp):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.up = torch.nn.Upsample(scale_factor=2)
							 | 
						||
| 
								 | 
							
								        self.h1 = CSP(width[4] + width[5], width[4], depth[5], csp[0], r=2)
							 | 
						||
| 
								 | 
							
								        self.h2 = CSP(width[4] + width[4], width[3], depth[5], csp[0], r=2)
							 | 
						||
| 
								 | 
							
								        self.h3 = Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1)
							 | 
						||
| 
								 | 
							
								        self.h4 = CSP(width[3] + width[4], width[4], depth[5], csp[0], r=2)
							 | 
						||
| 
								 | 
							
								        self.h5 = Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1)
							 | 
						||
| 
								 | 
							
								        self.h6 = CSP(width[4] + width[5], width[5], depth[5], csp[1], r=2)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        p3, p4, p5 = x
							 | 
						||
| 
								 | 
							
								        p4 = self.h1(torch.cat(tensors=[self.up(p5), p4], dim=1))
							 | 
						||
| 
								 | 
							
								        p3 = self.h2(torch.cat(tensors=[self.up(p4), p3], dim=1))
							 | 
						||
| 
								 | 
							
								        p4 = self.h4(torch.cat(tensors=[self.h3(p3), p4], dim=1))
							 | 
						||
| 
								 | 
							
								        p5 = self.h6(torch.cat(tensors=[self.h5(p4), p5], dim=1))
							 | 
						||
| 
								 | 
							
								        return p3, p4, p5
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class DFL(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    # Generalized Focal Loss
							 | 
						||
| 
								 | 
							
								    # https://ieeexplore.ieee.org/document/9792391
							 | 
						||
| 
								 | 
							
								    def __init__(self, ch=16):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.ch = ch
							 | 
						||
| 
								 | 
							
								        self.conv = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False).requires_grad_(False)
							 | 
						||
| 
								 | 
							
								        x = torch.arange(ch, dtype=torch.float).view(1, ch, 1, 1)
							 | 
						||
| 
								 | 
							
								        self.conv.weight.data[:] = torch.nn.Parameter(x)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        b, c, a = x.shape
							 | 
						||
| 
								 | 
							
								        x = x.view(b, 4, self.ch, a).transpose(2, 1)
							 | 
						||
| 
								 | 
							
								        return self.conv(x.softmax(1)).view(b, 4, a)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Head(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    anchors = torch.empty(0)
							 | 
						||
| 
								 | 
							
								    strides = torch.empty(0)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __init__(self, nc=80, filters=()):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.ch = 16  # DFL channels
							 | 
						||
| 
								 | 
							
								        self.nc = nc  # number of classes
							 | 
						||
| 
								 | 
							
								        self.nl = len(filters)  # number of detection layers
							 | 
						||
| 
								 | 
							
								        self.no = nc + self.ch * 4  # number of outputs per anchor
							 | 
						||
| 
								 | 
							
								        self.stride = torch.zeros(self.nl)  # strides computed during build
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        box = max(64, filters[0] // 4)
							 | 
						||
| 
								 | 
							
								        cls = max(80, filters[0], self.nc)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.dfl = DFL(self.ch)
							 | 
						||
| 
								 | 
							
								        self.box = torch.nn.ModuleList(
							 | 
						||
| 
								 | 
							
								            torch.nn.Sequential(
							 | 
						||
| 
								 | 
							
								                Conv(x, box, torch.nn.SiLU(), k=3, p=1),
							 | 
						||
| 
								 | 
							
								                Conv(box, box, torch.nn.SiLU(), k=3, p=1),
							 | 
						||
| 
								 | 
							
								                torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1),
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								            for x in filters
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self.cls = torch.nn.ModuleList(
							 | 
						||
| 
								 | 
							
								            torch.nn.Sequential(
							 | 
						||
| 
								 | 
							
								                Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x),
							 | 
						||
| 
								 | 
							
								                Conv(x, cls, torch.nn.SiLU()),
							 | 
						||
| 
								 | 
							
								                Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls),
							 | 
						||
| 
								 | 
							
								                Conv(cls, cls, torch.nn.SiLU()),
							 | 
						||
| 
								 | 
							
								                torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1),
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								            for x in filters
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        for i, (box, cls) in enumerate(zip(self.box, self.cls)):
							 | 
						||
| 
								 | 
							
								            x[i] = torch.cat(tensors=(box(x[i]), cls(x[i])), dim=1)
							 | 
						||
| 
								 | 
							
								        if self.training:
							 | 
						||
| 
								 | 
							
								            return x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.anchors, self.strides = (i.transpose(0, 1) for i in make_anchors(x, self.stride))
							 | 
						||
| 
								 | 
							
								        x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2)
							 | 
						||
| 
								 | 
							
								        box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        a, b = self.dfl(box).chunk(2, 1)
							 | 
						||
| 
								 | 
							
								        a = self.anchors.unsqueeze(0) - a
							 | 
						||
| 
								 | 
							
								        b = self.anchors.unsqueeze(0) + b
							 | 
						||
| 
								 | 
							
								        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def initialize_biases(self):
							 | 
						||
| 
								 | 
							
								        # Initialize biases
							 | 
						||
| 
								 | 
							
								        # WARNING: requires stride availability
							 | 
						||
| 
								 | 
							
								        for box, cls, s in zip(self.box, self.cls, self.stride):
							 | 
						||
| 
								 | 
							
								            # box
							 | 
						||
| 
								 | 
							
								            box[-1].bias.data[:] = 1.0
							 | 
						||
| 
								 | 
							
								            # cls (.01 objects, 80 classes, 640 image)
							 | 
						||
| 
								 | 
							
								            cls[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class YOLO(torch.nn.Module):
							 | 
						||
| 
								 | 
							
								    def __init__(self, width, depth, csp, num_classes):
							 | 
						||
| 
								 | 
							
								        super().__init__()
							 | 
						||
| 
								 | 
							
								        self.net = DarkNet(width, depth, csp)
							 | 
						||
| 
								 | 
							
								        self.fpn = DarkFPN(width, depth, csp)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        img_dummy = torch.zeros(1, width[0], 256, 256)
							 | 
						||
| 
								 | 
							
								        self.head = Head(num_classes, (width[3], width[4], width[5]))
							 | 
						||
| 
								 | 
							
								        self.head.stride = torch.tensor([256 / x.shape[-2] for x in self.forward(img_dummy)])
							 | 
						||
| 
								 | 
							
								        self.stride = self.head.stride
							 | 
						||
| 
								 | 
							
								        self.head.initialize_biases()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def forward(self, x):
							 | 
						||
| 
								 | 
							
								        x = self.net(x)
							 | 
						||
| 
								 | 
							
								        x = self.fpn(x)
							 | 
						||
| 
								 | 
							
								        return self.head(list(x))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def fuse(self):
							 | 
						||
| 
								 | 
							
								        for m in self.modules():
							 | 
						||
| 
								 | 
							
								            if type(m) is Conv and hasattr(m, "norm"):
							 | 
						||
| 
								 | 
							
								                m.conv = fuse_conv(m.conv, m.norm)
							 | 
						||
| 
								 | 
							
								                m.forward = m.fuse_forward
							 | 
						||
| 
								 | 
							
								                delattr(m, "norm")
							 | 
						||
| 
								 | 
							
								        return self
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def yolo_v11_n(num_classes: int = 80):
							 | 
						||
| 
								 | 
							
								    csp = [False, True]
							 | 
						||
| 
								 | 
							
								    depth = [1, 1, 1, 1, 1, 1]
							 | 
						||
| 
								 | 
							
								    width = [3, 16, 32, 64, 128, 256]
							 | 
						||
| 
								 | 
							
								    return YOLO(width, depth, csp, num_classes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def yolo_v11_t(num_classes: int = 80):
							 | 
						||
| 
								 | 
							
								    csp = [False, True]
							 | 
						||
| 
								 | 
							
								    depth = [1, 1, 1, 1, 1, 1]
							 | 
						||
| 
								 | 
							
								    width = [3, 24, 48, 96, 192, 384]
							 | 
						||
| 
								 | 
							
								    return YOLO(width, depth, csp, num_classes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def yolo_v11_s(num_classes: int = 80):
							 | 
						||
| 
								 | 
							
								    csp = [False, True]
							 | 
						||
| 
								 | 
							
								    depth = [1, 1, 1, 1, 1, 1]
							 | 
						||
| 
								 | 
							
								    width = [3, 32, 64, 128, 256, 512]
							 | 
						||
| 
								 | 
							
								    return YOLO(width, depth, csp, num_classes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def yolo_v11_m(num_classes: int = 80):
							 | 
						||
| 
								 | 
							
								    csp = [True, True]
							 | 
						||
| 
								 | 
							
								    depth = [1, 1, 1, 1, 1, 1]
							 | 
						||
| 
								 | 
							
								    width = [3, 64, 128, 256, 512, 512]
							 | 
						||
| 
								 | 
							
								    return YOLO(width, depth, csp, num_classes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def yolo_v11_l(num_classes: int = 80):
							 | 
						||
| 
								 | 
							
								    csp = [True, True]
							 | 
						||
| 
								 | 
							
								    depth = [2, 2, 2, 2, 2, 2]
							 | 
						||
| 
								 | 
							
								    width = [3, 64, 128, 256, 512, 512]
							 | 
						||
| 
								 | 
							
								    return YOLO(width, depth, csp, num_classes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def yolo_v11_x(num_classes: int = 80):
							 | 
						||
| 
								 | 
							
								    csp = [True, True]
							 | 
						||
| 
								 | 
							
								    depth = [2, 2, 2, 2, 2, 2]
							 | 
						||
| 
								 | 
							
								    width = [3, 96, 192, 384, 768, 768]
							 | 
						||
| 
								 | 
							
								    return YOLO(width, depth, csp, num_classes)
							 |