FedAvg and YOLOv11 training
nets/nn.py (new file, 362 lines)
@@ -0,0 +1,362 @@
"""
This file contains the model definition of YOLOv11.
"""

import math

import torch

from utils.util import make_anchors


def fuse_conv(conv, norm):
    fused_conv = (
        torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_norm = torch.diag(norm.weight.div(torch.sqrt(norm.eps + norm.running_var)))
    fused_conv.weight.copy_(torch.mm(w_norm, w_conv).view(fused_conv.weight.size()))

    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_norm = norm.bias - norm.weight.mul(norm.running_mean).div(torch.sqrt(norm.running_var + norm.eps))
    if fused_conv.bias is not None:
        fused_conv.bias.copy_(torch.mm(w_norm, b_conv.reshape(-1, 1)).reshape(-1) + b_norm)

    return fused_conv
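
# Folding identities used above (standard Conv + BatchNorm fusion, with BatchNorm in eval mode):
#   W_fused = diag(gamma / sqrt(running_var + eps)) @ W_conv
#   b_fused = gamma * (b_conv - running_mean) / sqrt(running_var + eps) + beta
# so fused_conv(x) reproduces norm(conv(x)) while skipping the normalization pass at inference.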


class Conv(torch.nn.Module):
    def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_ch, out_ch, k, s, p, groups=g, bias=False)
        self.norm = torch.nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03)
        self.relu = activation

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))

    def fuse_forward(self, x):
        return self.relu(self.conv(x))


class Residual(torch.nn.Module):
    def __init__(self, ch, e=0.5):
        super().__init__()
        self.conv1 = Conv(ch, int(ch * e), torch.nn.SiLU(), k=3, p=1)
        self.conv2 = Conv(int(ch * e), ch, torch.nn.SiLU(), k=3, p=1)

    def forward(self, x):
        return x + self.conv2(self.conv1(x))


class CSPModule(torch.nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
        self.conv2 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
        self.conv3 = Conv(2 * (out_ch // 2), out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(Residual(out_ch // 2, e=1.0), Residual(out_ch // 2, e=1.0))

    def forward(self, x):
        y = self.res_m(self.conv1(x))
        return self.conv3(torch.cat((y, self.conv2(x)), dim=1))


class CSP(torch.nn.Module):
    def __init__(self, in_ch, out_ch, n, csp, r):
        super().__init__()
        self.conv1 = Conv(in_ch, 2 * (out_ch // r), torch.nn.SiLU())
        self.conv2 = Conv((2 + n) * (out_ch // r), out_ch, torch.nn.SiLU())

        if not csp:
            self.res_m = torch.nn.ModuleList(Residual(out_ch // r) for _ in range(n))
        else:
            self.res_m = torch.nn.ModuleList(CSPModule(out_ch // r, out_ch // r) for _ in range(n))

    def forward(self, x):
        y = list(self.conv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.res_m)
        return self.conv2(torch.cat(y, dim=1))


class SPP(torch.nn.Module):
    def __init__(self, in_ch, out_ch, k=5):
        super().__init__()
        self.conv1 = Conv(in_ch, in_ch // 2, torch.nn.SiLU())
        self.conv2 = Conv(in_ch * 2, out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.conv1(x)
        y1 = self.res_m(x)
        y2 = self.res_m(y1)
        return self.conv2(torch.cat(tensors=[x, y1, y2, self.res_m(y2)], dim=1))
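
# Note: chaining three k=5 max-pools (the SPPF formulation) yields receptive fields
# equivalent to the parallel 5/9/13 pooling of the original SPP block, at lower cost.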


class Attention(torch.nn.Module):
    def __init__(self, ch, num_head):
        super().__init__()
        self.num_head = num_head
        self.dim_head = ch // num_head
        self.dim_key = self.dim_head // 2
        self.scale = self.dim_key**-0.5

        self.qkv = Conv(ch, ch + self.dim_key * num_head * 2, torch.nn.Identity())

        self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch)
        self.conv2 = Conv(ch, ch, torch.nn.Identity())

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv(x)
        qkv = qkv.view(b, self.num_head, self.dim_key * 2 + self.dim_head, h * w)

        q, k, v = qkv.split([self.dim_key, self.dim_key, self.dim_head], dim=2)

        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)

        x = (v @ attn.transpose(-2, -1)).view(b, c, h, w) + self.conv1(v.reshape(b, c, h, w))
        return self.conv2(x)
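
# Shape sketch for Attention.forward (illustrative): per head, q and k have
# dim_key = (ch // num_head) // 2 channels and v has dim_head channels, so attn is an
# (h*w) x (h*w) map per head; conv1 adds a depth-wise 3x3 positional term computed on v.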


class PSABlock(torch.nn.Module):
    def __init__(self, ch, num_head):
        super().__init__()
        self.conv1 = Attention(ch, num_head)
        self.conv2 = torch.nn.Sequential(Conv(ch, ch * 2, torch.nn.SiLU()), Conv(ch * 2, ch, torch.nn.Identity()))

    def forward(self, x):
        x = x + self.conv1(x)
        return x + self.conv2(x)


class PSA(torch.nn.Module):
    def __init__(self, ch, n):
        super().__init__()
        self.conv1 = Conv(ch, 2 * (ch // 2), torch.nn.SiLU())
        self.conv2 = Conv(2 * (ch // 2), ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(*(PSABlock(ch // 2, ch // 128) for _ in range(n)))

    def forward(self, x):
        x, y = self.conv1(x).chunk(2, 1)
        return self.conv2(torch.cat(tensors=(x, self.res_m(y)), dim=1))


class DarkNet(torch.nn.Module):
    def __init__(self, width, depth, csp):
        super().__init__()
        self.p1 = []
        self.p2 = []
        self.p3 = []
        self.p4 = []
        self.p5 = []

        # p1/2
        self.p1.append(Conv(width[0], width[1], torch.nn.SiLU(), k=3, s=2, p=1))
        # p2/4
        self.p2.append(Conv(width[1], width[2], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p2.append(CSP(width[2], width[3], depth[0], csp[0], r=4))
        # p3/8
        self.p3.append(Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p3.append(CSP(width[3], width[4], depth[1], csp[0], r=4))
        # p4/16
        self.p4.append(Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p4.append(CSP(width[4], width[4], depth[2], csp[1], r=2))
        # p5/32
        self.p5.append(Conv(width[4], width[5], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p5.append(CSP(width[5], width[5], depth[3], csp[1], r=2))
        self.p5.append(SPP(width[5], width[5]))
        self.p5.append(PSA(width[5], depth[4]))

        self.p1 = torch.nn.Sequential(*self.p1)
        self.p2 = torch.nn.Sequential(*self.p2)
        self.p3 = torch.nn.Sequential(*self.p3)
        self.p4 = torch.nn.Sequential(*self.p4)
        self.p5 = torch.nn.Sequential(*self.p5)

    def forward(self, x):
        p1 = self.p1(x)
        p2 = self.p2(p1)
        p3 = self.p3(p2)
        p4 = self.p4(p3)
        p5 = self.p5(p4)
        return p3, p4, p5
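
# Backbone outputs: p3, p4 and p5 are feature maps at strides 8, 16 and 32 with
# width[4], width[4] and width[5] channels respectively (see the CSP stages above).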


class DarkFPN(torch.nn.Module):
    def __init__(self, width, depth, csp):
        super().__init__()
        self.up = torch.nn.Upsample(scale_factor=2)
        self.h1 = CSP(width[4] + width[5], width[4], depth[5], csp[0], r=2)
        self.h2 = CSP(width[4] + width[4], width[3], depth[5], csp[0], r=2)
        self.h3 = Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1)
        self.h4 = CSP(width[3] + width[4], width[4], depth[5], csp[0], r=2)
        self.h5 = Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1)
        self.h6 = CSP(width[4] + width[5], width[5], depth[5], csp[1], r=2)

    def forward(self, x):
        p3, p4, p5 = x
        p4 = self.h1(torch.cat(tensors=[self.up(p5), p4], dim=1))
        p3 = self.h2(torch.cat(tensors=[self.up(p4), p3], dim=1))
        p4 = self.h4(torch.cat(tensors=[self.h3(p3), p4], dim=1))
        p5 = self.h6(torch.cat(tensors=[self.h5(p4), p5], dim=1))
        return p3, p4, p5


class DFL(torch.nn.Module):
    # Distribution Focal Loss (DFL) integral module,
    # proposed in Generalized Focal Loss: https://ieeexplore.ieee.org/document/9792391
    def __init__(self, ch=16):
        super().__init__()
        self.ch = ch
        self.conv = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False).requires_grad_(False)
        x = torch.arange(ch, dtype=torch.float).view(1, ch, 1, 1)
        self.conv.weight.data[:] = torch.nn.Parameter(x)

    def forward(self, x):
        b, c, a = x.shape
        x = x.view(b, 4, self.ch, a).transpose(2, 1)
        return self.conv(x.softmax(1)).view(b, 4, a)
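
# DFL decodes each of the 4 box sides as the expectation of a ch-bin (default 16) discrete
# distribution: softmax over the bins, then a frozen 1x1 conv whose weights are 0..ch-1
# computes the weighted sum, yielding per-side distances in stride units.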


class Head(torch.nn.Module):
    anchors = torch.empty(0)
    strides = torch.empty(0)

    def __init__(self, nc=80, filters=()):
        super().__init__()
        self.ch = 16  # DFL channels
        self.nc = nc  # number of classes
        self.nl = len(filters)  # number of detection layers
        self.no = nc + self.ch * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build

        box = max(64, filters[0] // 4)
        cls = max(80, filters[0], self.nc)

        self.dfl = DFL(self.ch)
        self.box = torch.nn.ModuleList(
            torch.nn.Sequential(
                Conv(x, box, torch.nn.SiLU(), k=3, p=1),
                Conv(box, box, torch.nn.SiLU(), k=3, p=1),
                torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1),
            )
            for x in filters
        )
        self.cls = torch.nn.ModuleList(
            torch.nn.Sequential(
                Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x),
                Conv(x, cls, torch.nn.SiLU()),
                Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls),
                Conv(cls, cls, torch.nn.SiLU()),
                torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1),
            )
            for x in filters
        )

    def forward(self, x):
        for i, (box, cls) in enumerate(zip(self.box, self.cls)):
            x[i] = torch.cat(tensors=(box(x[i]), cls(x[i])), dim=1)
        if self.training:
            return x

        self.anchors, self.strides = (i.transpose(0, 1) for i in make_anchors(x, self.stride))
        x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2)
        box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1)

        a, b = self.dfl(box).chunk(2, 1)
        a = self.anchors.unsqueeze(0) - a
        b = self.anchors.unsqueeze(0) + b
        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)

        return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1)

    def initialize_biases(self):
        # Initialize biases
        # WARNING: requires stride availability
        for box, cls, s in zip(self.box, self.cls, self.stride):
            # box
            box[-1].bias.data[:] = 1.0
            # cls (.01 objects, 80 classes, 640 image)
            cls[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)
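
# Inference output layout (the non-training branch above): a tensor of shape
# (batch, 4 + nc, num_anchors), where the first 4 channels are (cx, cy, w, h) boxes in
# input-image pixels (decoded DFL distances scaled by the per-level stride) and the
# remaining nc channels are per-class sigmoid scores.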


class YOLO(torch.nn.Module):
    def __init__(self, width, depth, csp, num_classes):
        super().__init__()
        self.net = DarkNet(width, depth, csp)
        self.fpn = DarkFPN(width, depth, csp)

        img_dummy = torch.zeros(1, width[0], 256, 256)
        self.head = Head(num_classes, (width[3], width[4], width[5]))
        self.head.stride = torch.tensor([256 / x.shape[-2] for x in self.forward(img_dummy)])
        self.stride = self.head.stride
        self.head.initialize_biases()

    def forward(self, x):
        x = self.net(x)
        x = self.fpn(x)
        return self.head(list(x))

    def fuse(self):
        for m in self.modules():
            if type(m) is Conv and hasattr(m, "norm"):
                m.conv = fuse_conv(m.conv, m.norm)
                m.forward = m.fuse_forward
                delattr(m, "norm")
        return self
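
# Note: fuse() folds each BatchNorm into its preceding Conv via fuse_conv and switches the
# layer to fuse_forward; it is intended for inference or export, not for further training.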


def yolo_v11_n(num_classes: int = 80):
    csp = [False, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 16, 32, 64, 128, 256]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_t(num_classes: int = 80):
    csp = [False, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 24, 48, 96, 192, 384]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_s(num_classes: int = 80):
    csp = [False, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 32, 64, 128, 256, 512]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_m(num_classes: int = 80):
    csp = [True, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 64, 128, 256, 512, 512]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_l(num_classes: int = 80):
    csp = [True, True]
    depth = [2, 2, 2, 2, 2, 2]
    width = [3, 64, 128, 256, 512, 512]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_x(num_classes: int = 80):
    csp = [True, True]
    depth = [2, 2, 2, 2, 2, 2]
    width = [3, 96, 192, 384, 768, 768]
    return YOLO(width, depth, csp, num_classes)
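

# Minimal usage sketch (illustrative only, not part of the training pipeline in this commit):
# build the nano variant and run a dummy forward pass in eval mode.
if __name__ == "__main__":
    model = yolo_v11_n(num_classes=80)
    model.eval()
    with torch.no_grad():
        out = model(torch.zeros(1, 3, 640, 640))
    # With an 80-class head and a 640x640 input this should be (1, 84, 8400):
    # 4 box values + 80 class scores for 80*80 + 40*40 + 20*20 anchors.
    print(out.shape)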