Files
fed-yolo/nets/nn.py
2025-10-02 16:31:55 +08:00

363 lines
12 KiB
Python

"""
This file contains the model definition of YOLOv11
"""
import math
import torch
from utils.util import make_anchors
def fuse_conv(conv, norm):
    """Fold a batch-norm layer into the convolution that precedes it.

    Builds a new ``torch.nn.Conv2d`` (always with a bias) whose output equals
    ``norm(conv(x))`` when ``norm`` normalises with its running statistics.

    Args:
        conv: The ``torch.nn.Conv2d`` to fold into.
        norm: The ``torch.nn.BatchNorm2d`` applied directly after ``conv``.

    Returns:
        A new ``torch.nn.Conv2d`` on the device and in the dtype of
        ``conv.weight``, with ``requires_grad=False`` parameters.
    """
    fused_conv = (
        torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )
    # Without no_grad the in-place copies below would pull the frozen
    # parameters into the autograd graph of ``conv``/``norm`` (turning them
    # into non-leaf tensors that require grad).
    with torch.no_grad():
        # Per-output-channel factor gamma / sqrt(var + eps).
        scale = norm.weight / torch.sqrt(norm.running_var + norm.eps)
        fused_conv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
        if conv.bias is None:
            b_conv = torch.zeros_like(scale)
        else:
            b_conv = conv.bias
        # (b - mean) * scale + beta
        fused_conv.bias.copy_((b_conv - norm.running_mean) * scale + norm.bias)
    return fused_conv
class Conv(torch.nn.Module):
    """Bias-free 2-D convolution followed by batch norm and an activation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        activation: Module applied after the normalisation.
        k: Kernel size.
        s: Stride.
        p: Padding.
        g: Number of convolution groups.
    """

    def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            groups=g,
            bias=False,
        )
        self.norm = torch.nn.BatchNorm2d(num_features=out_ch, eps=1e-3, momentum=0.03)
        # The attribute is called ``relu`` but holds whatever activation was passed in.
        self.relu = activation

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        out = self.conv(x)
        out = self.norm(out)
        return self.relu(out)

    def fuse_forward(self, x):
        """Apply convolution and activation only, skipping the normalisation."""
        out = self.conv(x)
        return self.relu(out)
class Residual(torch.nn.Module):
    """Two stacked 3x3 ``Conv`` layers whose output is added back to the input.

    Args:
        ch: Number of input and output channels.
        e: Ratio used to size the channels between the two convolutions.
    """

    def __init__(self, ch, e=0.5):
        super().__init__()
        hidden = int(ch * e)
        self.conv1 = Conv(ch, hidden, torch.nn.SiLU(), k=3, p=1)
        self.conv2 = Conv(hidden, ch, torch.nn.SiLU(), k=3, p=1)

    def forward(self, x):
        """Return ``x`` plus the output of the two convolutions applied to ``x``."""
        branch = self.conv1(x)
        branch = self.conv2(branch)
        return x + branch
class CSPModule(torch.nn.Module):
    """Two-branch block: one branch runs two ``Residual`` units, the other a 1x1 ``Conv``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels; each branch works on ``out_ch // 2``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        half = out_ch // 2
        self.conv1 = Conv(in_ch, half, torch.nn.SiLU())
        self.conv2 = Conv(in_ch, half, torch.nn.SiLU())
        self.conv3 = Conv(2 * half, out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(
            Residual(half, e=1.0),
            Residual(half, e=1.0),
        )

    def forward(self, x):
        """Concatenate both branches on the channel axis and project with ``conv3``."""
        deep = self.res_m(self.conv1(x))
        shortcut = self.conv2(x)
        merged = torch.cat((deep, shortcut), dim=1)
        return self.conv3(merged)
class CSP(torch.nn.Module):
    """Split the input in two, chain ``n`` blocks off the last piece, then merge.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        n: Number of chained blocks.
        csp: When truthy the blocks are ``CSPModule``; otherwise ``Residual``.
        r: Divisor applied to ``out_ch`` to get the per-piece channel count.
    """

    def __init__(self, in_ch, out_ch, n, csp, r):
        super().__init__()
        hidden = out_ch // r
        self.conv1 = Conv(in_ch, 2 * hidden, torch.nn.SiLU())
        # The merge sees the two halves of conv1 plus one output per block.
        self.conv2 = Conv((2 + n) * hidden, out_ch, torch.nn.SiLU())
        if csp:
            blocks = [CSPModule(hidden, hidden) for _ in range(n)]
        else:
            blocks = [Residual(hidden) for _ in range(n)]
        self.res_m = torch.nn.ModuleList(blocks)

    def forward(self, x):
        """Run the block on ``x`` and return the merged feature map."""
        pieces = list(self.conv1(x).chunk(2, 1))
        for block in self.res_m:
            # Each block consumes the most recently produced piece.
            pieces.append(block(pieces[-1]))
        return self.conv2(torch.cat(pieces, dim=1))
class SPP(torch.nn.Module):
    """Pooling block that concatenates a feature map with three chained max-pools of it.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        k: Max-pool kernel size; stride is 1 and padding is ``k // 2``.
    """

    def __init__(self, in_ch, out_ch, k=5):
        super().__init__()
        hidden = in_ch // 2
        self.conv1 = Conv(in_ch, hidden, torch.nn.SiLU())
        # forward() concatenates four maps of ``hidden`` channels each.  Sizing
        # this as ``in_ch * 2`` is only equal for even ``in_ch``; for odd
        # ``in_ch`` it would not match the concatenated tensor.
        self.conv2 = Conv(4 * hidden, out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        """Return ``conv2`` applied to ``[x, pool(x), pool²(x), pool³(x)]`` after ``conv1``."""
        x = self.conv1(x)
        y1 = self.res_m(x)
        y2 = self.res_m(y1)
        y3 = self.res_m(y2)
        return self.conv2(torch.cat(tensors=[x, y1, y2, y3], dim=1))
class Attention(torch.nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Args:
        ch: Number of input and output channels.
        num_head: Number of attention heads.
    """

    def __init__(self, ch, num_head):
        super().__init__()
        self.num_head = num_head
        self.dim_head = ch // num_head
        self.dim_key = self.dim_head // 2
        self.scale = self.dim_key**-0.5
        # One projection yields queries, keys (dim_key each) and values (dim_head) per head.
        qkv_channels = ch + 2 * self.dim_key * num_head
        self.qkv = Conv(ch, qkv_channels, torch.nn.Identity())
        # Depthwise 3x3 convolution applied to the values.
        self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch)
        self.conv2 = Conv(ch, ch, torch.nn.Identity())

    def forward(self, x):
        """Return the attention output for ``x`` (same shape as ``x``)."""
        batch, channels, height, width = x.shape
        per_head = 2 * self.dim_key + self.dim_head
        qkv = self.qkv(x).view(batch, self.num_head, per_head, height * width)
        query, key, value = qkv.split([self.dim_key, self.dim_key, self.dim_head], dim=2)

        scores = (query.transpose(-2, -1) @ key) * self.scale
        weights = scores.softmax(dim=-1)

        attended = (value @ weights.transpose(-2, -1)).view(batch, channels, height, width)
        depthwise = self.conv1(value.reshape(batch, channels, height, width))
        return self.conv2(attended + depthwise)
class PSABlock(torch.nn.Module):
    """Attention followed by a two-layer 1x1 convolution stack, each with a skip connection.

    Args:
        ch: Number of input and output channels.
        num_head: Number of heads passed to ``Attention``.
    """

    def __init__(self, ch, num_head):
        super().__init__()
        # Attribute names are kept as ``conv1``/``conv2`` although the first one is attention.
        self.conv1 = Attention(ch, num_head)
        expand = Conv(ch, ch * 2, torch.nn.SiLU())
        project = Conv(ch * 2, ch, torch.nn.Identity())
        self.conv2 = torch.nn.Sequential(expand, project)

    def forward(self, x):
        """Add the attention output to ``x``, then add the conv stack output."""
        attended = x + self.conv1(x)
        return attended + self.conv2(attended)
class PSA(torch.nn.Module):
    """Split the channels in two and run ``n`` ``PSABlock`` units on the second half.

    Args:
        ch: Number of input and output channels.
        n: Number of chained ``PSABlock`` units.
    """

    def __init__(self, ch, n):
        super().__init__()
        half = ch // 2
        # ``ch // 128`` is 0 for ch < 128, which makes Attention divide by zero;
        # fall back to a single head for narrow models.
        num_head = max(1, ch // 128)
        self.conv1 = Conv(ch, 2 * half, torch.nn.SiLU())
        self.conv2 = Conv(2 * half, ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(*(PSABlock(half, num_head) for _ in range(n)))

    def forward(self, x):
        """Return ``conv2`` applied to the untouched half concatenated with the attended half."""
        x, y = self.conv1(x).chunk(2, 1)
        return self.conv2(torch.cat(tensors=(x, self.res_m(y)), dim=1))
class DarkNet(torch.nn.Module):
    """Backbone made of five stride-2 stages ``p1`` .. ``p5``.

    Args:
        width: Channel counts, indexed 0..5 (``width[0]`` is the input channel count).
        depth: Block counts; entries 0..3 size the CSP blocks, entry 4 the PSA block.
        csp: Two flags selecting the CSP block type (``csp[0]`` for p2/p3, ``csp[1]`` for p4/p5).
    """

    def __init__(self, width, depth, csp):
        super().__init__()

        def downsample(in_ch, out_ch):
            # 3x3 convolution with stride 2.
            return Conv(in_ch, out_ch, torch.nn.SiLU(), k=3, s=2, p=1)

        # p1/2
        self.p1 = torch.nn.Sequential(
            downsample(width[0], width[1]),
        )
        # p2/4
        self.p2 = torch.nn.Sequential(
            downsample(width[1], width[2]),
            CSP(width[2], width[3], depth[0], csp[0], r=4),
        )
        # p3/8
        self.p3 = torch.nn.Sequential(
            downsample(width[3], width[3]),
            CSP(width[3], width[4], depth[1], csp[0], r=4),
        )
        # p4/16
        self.p4 = torch.nn.Sequential(
            downsample(width[4], width[4]),
            CSP(width[4], width[4], depth[2], csp[1], r=2),
        )
        # p5/32
        self.p5 = torch.nn.Sequential(
            downsample(width[4], width[5]),
            CSP(width[5], width[5], depth[3], csp[1], r=2),
            SPP(width[5], width[5]),
            PSA(width[5], depth[4]),
        )

    def forward(self, x):
        """Run all five stages and return the outputs of ``p3``, ``p4`` and ``p5``."""
        stage2 = self.p2(self.p1(x))
        stage3 = self.p3(stage2)
        stage4 = self.p4(stage3)
        stage5 = self.p5(stage4)
        return stage3, stage4, stage5
class DarkFPN(torch.nn.Module):
    """Neck that mixes three feature maps top-down (upsampling) and then bottom-up (stride-2 convs).

    Args:
        width: Channel counts; entries 3, 4 and 5 are used.
        depth: Block counts; only ``depth[5]`` is used.
        csp: Two flags selecting the CSP block type (``csp[1]`` only for ``h6``).
    """

    def __init__(self, width, depth, csp):
        super().__init__()
        small, medium, large = width[3], width[4], width[5]
        repeats = depth[5]
        self.up = torch.nn.Upsample(scale_factor=2)
        self.h1 = CSP(medium + large, medium, repeats, csp[0], r=2)
        self.h2 = CSP(medium + medium, small, repeats, csp[0], r=2)
        self.h3 = Conv(small, small, torch.nn.SiLU(), k=3, s=2, p=1)
        self.h4 = CSP(small + medium, medium, repeats, csp[0], r=2)
        self.h5 = Conv(medium, medium, torch.nn.SiLU(), k=3, s=2, p=1)
        self.h6 = CSP(medium + large, large, repeats, csp[1], r=2)

    def forward(self, x):
        """Fuse the three input maps and return three maps in the same order."""
        c3, c4, c5 = x
        # Top-down: upsample the coarser map and concatenate with the finer one.
        mid = self.h1(torch.cat(tensors=[self.up(c5), c4], dim=1))
        out3 = self.h2(torch.cat(tensors=[self.up(mid), c3], dim=1))
        # Bottom-up: downsample the finer map and concatenate with the coarser one.
        out4 = self.h4(torch.cat(tensors=[self.h3(out3), mid], dim=1))
        out5 = self.h6(torch.cat(tensors=[self.h5(out4), c5], dim=1))
        return out3, out4, out5
class DFL(torch.nn.Module):
    """Turn per-side logits over ``ch`` bins into the softmax-weighted mean bin index.

    Args:
        ch: Number of bins per box side.
    """

    # Generalized Focal Loss
    # https://ieeexplore.ieee.org/document/9792391
    def __init__(self, ch=16):
        super().__init__()
        self.ch = ch
        self.conv = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False).requires_grad_(False)
        # Fixed weights 0, 1, ..., ch - 1: the 1x1 conv becomes a weighted sum over the bins.
        bin_index = torch.arange(ch, dtype=torch.float)
        with torch.no_grad():
            self.conv.weight.copy_(bin_index.view(1, ch, 1, 1))

    def forward(self, x):
        """Map ``x`` of shape ``(b, 4 * ch, a)`` to ``(b, 4, a)``."""
        batch, _, anchors = x.shape
        logits = x.view(batch, 4, self.ch, anchors).transpose(2, 1)
        probabilities = logits.softmax(1)
        return self.conv(probabilities).view(batch, 4, anchors)
class Head(torch.nn.Module):
    """Detection head with one box branch and one class branch per input feature map.

    Args:
        nc: Number of classes.
        filters: Channel count of each input feature map.
    """

    # Overwritten on the instance by every non-training forward pass.
    anchors = torch.empty(0)
    strides = torch.empty(0)

    def __init__(self, nc=80, filters=()):
        super().__init__()
        self.ch = 16  # DFL channels
        self.nc = nc  # number of classes
        self.nl = len(filters)  # number of detection layers
        self.no = nc + self.ch * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build

        box = max(64, filters[0] // 4)
        cls = max(80, filters[0], self.nc)

        self.dfl = DFL(self.ch)

        box_branches = []
        for x in filters:
            box_branches.append(
                torch.nn.Sequential(
                    Conv(x, box, torch.nn.SiLU(), k=3, p=1),
                    Conv(box, box, torch.nn.SiLU(), k=3, p=1),
                    torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1),
                )
            )
        self.box = torch.nn.ModuleList(box_branches)

        cls_branches = []
        for x in filters:
            cls_branches.append(
                torch.nn.Sequential(
                    Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x),
                    Conv(x, cls, torch.nn.SiLU()),
                    Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls),
                    Conv(cls, cls, torch.nn.SiLU()),
                    torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1),
                )
            )
        self.cls = torch.nn.ModuleList(cls_branches)

    def forward(self, x):
        """Run the head on a list of feature maps.

        The list ``x`` is modified in place: each entry is replaced by the
        channel-wise concatenation of its box and class branch outputs.

        Returns:
            In training mode, the modified list. Otherwise a single tensor
            concatenating the decoded boxes (scaled by ``self.strides``) and
            the sigmoid of the class outputs along dim 1.
        """
        for i, (box_branch, cls_branch) in enumerate(zip(self.box, self.cls)):
            feature = x[i]
            x[i] = torch.cat(tensors=(box_branch(feature), cls_branch(feature)), dim=1)
        if self.training:
            return x

        anchor_points, stride_values = make_anchors(x, self.stride)
        self.anchors = anchor_points.transpose(0, 1)
        self.strides = stride_values.transpose(0, 1)

        batch = x[0].shape[0]
        flat = torch.cat([level.view(batch, self.no, -1) for level in x], dim=2)
        box, cls = flat.split(split_size=(4 * self.ch, self.nc), dim=1)

        # The first half of the DFL output is subtracted from the anchors, the second half added.
        low, high = self.dfl(box).chunk(2, 1)
        anchors = self.anchors.unsqueeze(0)
        low = anchors - low
        high = anchors + high
        # Emit the midpoint of the two corners followed by their difference.
        box = torch.cat(tensors=((low + high) / 2, high - low), dim=1)
        return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1)

    def initialize_biases(self):
        """Set the bias of the last convolution in every box and class branch."""
        # Initialize biases
        # WARNING: requires stride availability
        for box_branch, cls_branch, stride in zip(self.box, self.cls, self.stride):
            # box
            box_branch[-1].bias.data[:] = 1.0
            # cls (.01 objects, 80 classes, 640 image)
            cls_bias = math.log(5 / self.nc / (640 / stride) ** 2)
            cls_branch[-1].bias.data[: self.nc] = cls_bias
class YOLO(torch.nn.Module):
    """Full detector: ``DarkNet`` backbone, ``DarkFPN`` neck and detection ``Head``.

    Args:
        width: Channel counts, indexed 0..5 (``width[0]`` is the input channel count).
        depth: Block counts forwarded to the backbone and neck.
        csp: Two flags forwarded to the backbone and neck.
        num_classes: Number of classes predicted by the head.
    """

    def __init__(self, width, depth, csp, num_classes):
        super().__init__()
        self.net = DarkNet(width, depth, csp)
        self.fpn = DarkFPN(width, depth, csp)
        self.head = Head(num_classes, (width[3], width[4], width[5]))
        self.head.stride = self._measure_strides(in_ch=width[0], size=256)
        self.stride = self.head.stride
        self.head.initialize_biases()

    def _measure_strides(self, in_ch, size):
        """Return the stride of each head output, measured with a dummy forward pass."""
        # Run in eval mode and without autograd so the all-zero dummy image
        # neither updates the BatchNorm running statistics nor builds a graph.
        self.eval()
        # Only the head's own flag is flipped: Head.forward returns the raw
        # per-level maps when ``training`` is set.
        self.head.training = True
        try:
            with torch.no_grad():
                outputs = self.forward(torch.zeros(1, in_ch, size, size))
        finally:
            # Back to the default mode of a freshly constructed module.
            self.train()
        return torch.tensor([size / x.shape[-2] for x in outputs])

    def forward(self, x):
        """Run backbone, neck and head on the image batch ``x``."""
        x = self.net(x)
        x = self.fpn(x)
        # The head modifies its argument in place, so hand it a fresh list.
        return self.head(list(x))

    @torch.no_grad()
    def fuse(self):
        """Fold every ``Conv``'s batch norm into its convolution, in place.

        Returns:
            ``self``, to allow chaining.
        """
        # Snapshot the module list because the loop replaces and deletes submodules.
        for m in list(self.modules()):
            if type(m) is Conv and hasattr(m, "norm"):
                m.conv = fuse_conv(m.conv, m.norm)
                m.forward = m.fuse_forward
                delattr(m, "norm")
        return self
def yolo_v11_n(num_classes: int = 80):
    """Build the ``n`` configuration of the model.

    Args:
        num_classes: Number of classes passed to ``YOLO``.

    Returns:
        A ``YOLO`` with widths 3/16/32/64/128/256, every depth entry 1 and
        CSP flags ``[False, True]``.
    """
    return YOLO(
        width=[3, 16, 32, 64, 128, 256],
        depth=[1] * 6,
        csp=[False, True],
        num_classes=num_classes,
    )
def yolo_v11_t(num_classes: int = 80):
    """Build the ``t`` configuration of the model.

    Args:
        num_classes: Number of classes passed to ``YOLO``.

    Returns:
        A ``YOLO`` with widths 3/24/48/96/192/384, every depth entry 1 and
        CSP flags ``[False, True]``.
    """
    return YOLO(
        width=[3, 24, 48, 96, 192, 384],
        depth=[1] * 6,
        csp=[False, True],
        num_classes=num_classes,
    )
def yolo_v11_s(num_classes: int = 80):
    """Build the ``s`` configuration of the model.

    Args:
        num_classes: Number of classes passed to ``YOLO``.

    Returns:
        A ``YOLO`` with widths 3/32/64/128/256/512, every depth entry 1 and
        CSP flags ``[False, True]``.
    """
    return YOLO(
        width=[3, 32, 64, 128, 256, 512],
        depth=[1] * 6,
        csp=[False, True],
        num_classes=num_classes,
    )
def yolo_v11_m(num_classes: int = 80):
    """Build the ``m`` configuration of the model.

    Args:
        num_classes: Number of classes passed to ``YOLO``.

    Returns:
        A ``YOLO`` with widths 3/64/128/256/512/512, every depth entry 1 and
        CSP flags ``[True, True]``.
    """
    return YOLO(
        width=[3, 64, 128, 256, 512, 512],
        depth=[1] * 6,
        csp=[True, True],
        num_classes=num_classes,
    )
def yolo_v11_l(num_classes: int = 80):
    """Build the ``l`` configuration of the model.

    Args:
        num_classes: Number of classes passed to ``YOLO``.

    Returns:
        A ``YOLO`` with widths 3/64/128/256/512/512, every depth entry 2 and
        CSP flags ``[True, True]``.
    """
    return YOLO(
        width=[3, 64, 128, 256, 512, 512],
        depth=[2] * 6,
        csp=[True, True],
        num_classes=num_classes,
    )
def yolo_v11_x(num_classes: int = 80):
    """Build the ``x`` configuration of the model.

    Args:
        num_classes: Number of classes passed to ``YOLO``.

    Returns:
        A ``YOLO`` with widths 3/96/192/384/768/768, every depth entry 2 and
        CSP flags ``[True, True]``.
    """
    return YOLO(
        width=[3, 96, 192, 384, 768, 768],
        depth=[2] * 6,
        csp=[True, True],
        num_classes=num_classes,
    )