# YOLOv5 model construction.
# Reference: https://zhuanlan.zhihu.com/p/242456389
# Reference: https://github.com/ultralytics/yolov5.git
import torch
from torch import nn
def autopad(k, p=None):  # kernel, padding
    """Return the padding to use for a kernel of size ``k``.

    An explicitly supplied ``p`` is returned untouched. When ``p`` is None the
    padding is half the kernel size (floor division): a single int for an int
    ``k``, otherwise a list holding one value per element of ``k``.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [dim // 2 for dim in k]
class Conv(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; None lets ``autopad`` derive it from ``k``.
        g: convolution groups.
        act: True selects ``nn.SiLU``; an ``nn.Module`` instance is used as
            given; anything else selects ``nn.Identity``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            activation = nn.SiLU()
        elif isinstance(act, nn.Module):
            activation = act
        else:
            activation = nn.Identity()
        self.act = activation

    def forward(self, x):
        """Apply convolution, batch norm and activation in that order."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

    def fuseforward(self, x):
        """Apply convolution and activation only, skipping ``self.bn``."""
        out = self.conv(x)
        return self.act(out)
class Focus(nn.Module):
    """Fold 2x2 spatial neighbourhoods into the channel axis, then convolve.

    The constructor arguments are forwarded to ``Conv``, except that the
    convolution is built with ``4 * c1`` input channels because ``forward``
    concatenates four strided slices of the input along dim 1.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = Conv(4 * c1, c2, k, s, p, g, act)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        """Take every second element of the last two dims at the four offsets,
        stack the slices on dim 1 and run the convolution on the result."""
        even_even = x[..., ::2, ::2]
        odd_even = x[..., 1::2, ::2]
        even_odd = x[..., ::2, 1::2]
        odd_odd = x[..., 1::2, 1::2]
        stacked = torch.cat([even_even, odd_even, even_odd, odd_odd], 1)
        return self.conv(stacked)
class Bottleneck(nn.Module):
    """1x1 Conv followed by a 3x3 Conv, with an optional residual add.

    Args:
        c1: input channels.
        c2: output channels.
        shortcut: request the residual connection; it is only used when
            ``c1 == c2``.
        g: groups of the 3x3 convolution.
        e: ratio of hidden channels to ``c2``.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Return ``x + cv2(cv1(x))`` when the shortcut is active, else ``cv2(cv1(x))``."""
        if self.add:
            return x + self.cv2(self.cv1(x))
        return self.cv2(self.cv1(x))
class BottleneckCSP(nn.Module):
    """CSP bottleneck block.

    See https://github.com/WongKinYiu/CrossStagePartialNetworks

    Two branches are computed from the input: one goes through ``cv1``, the
    ``n`` stacked ``Bottleneck`` modules and ``cv3``; the other goes through
    ``cv2`` only. They are concatenated on dim 1, passed through batch norm
    and LeakyReLU(0.1), and merged by ``cv4``.

    Args:
        c1: input channels.
        c2: output channels.
        n: number of stacked ``Bottleneck`` modules.
        shortcut: forwarded to each ``Bottleneck``.
        g: groups, forwarded to each ``Bottleneck``.
        e: ratio of hidden channels to ``c2``.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = nn.Conv2d(c1, hidden, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden, hidden, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden)  # normalises the concatenated branches
        self.act = nn.LeakyReLU(0.1, inplace=True)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        """Run both branches, concatenate them and merge with ``cv4``."""
        deep = self.cv1(x)
        deep = self.m(deep)
        deep = self.cv3(deep)
        skip = self.cv2(x)
        merged = torch.cat((deep, skip), dim=1)
        merged = self.act(self.bn(merged))
        return self.cv4(merged)
class SPP(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel sizes of the stride-1 max-pool branches; each pool is padded
            with ``kernel // 2``.
    """

    def __init__(self, c1, c2, k=(5, 9, 13)):
        super().__init__()
        hidden = c1 // 2
        branches = len(k) + 1  # the un-pooled tensor plus one tensor per pool
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * branches, c2, 1, 1)
        pools = []
        for size in k:
            pools.append(nn.MaxPool2d(kernel_size=size, stride=1, padding=size // 2))
        self.m = nn.ModuleList(pools)

    def forward(self, x):
        """Reduce with ``cv1``, concatenate the result with each pooled copy on
        dim 1 and project with ``cv2``."""
        reduced = self.cv1(x)
        pyramid = [reduced]
        for pool in self.m:
            pyramid.append(pool(reduced))
        return self.cv2(torch.cat(pyramid, 1))
class Detect(nn.Module):
    """Detection layer: one 1x1 convolution per feature map plus box decoding.

    Args:
        nc: number of classes.
        anchors: one flat sequence ``(w, h, w, h, ...)`` per detection layer;
            the first entry fixes the number of anchors per layer.
        ch: input channel count of each feature map, one entry per layer.
        inplace: decode boxes with in-place slice assignment when True,
            otherwise build the decoded tensor with ``torch.cat``.

    ``stride`` must be assigned (one value per layer) before the module is
    used in eval mode, because decoding multiplies by ``self.stride[i]``.
    """

    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter: when True the grid is rebuilt on every call

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # placeholder grids, replaced on first inference
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        """Run the detection convolutions and, in eval mode, decode the boxes.

        Args:
            x: list of ``nl`` feature maps of shape (bs, ch[i], ny, nx). The
                list is modified in place: each entry is replaced by its raw
                prediction reshaped to (bs, na, ny, nx, no).

        Returns:
            In training mode, the list ``x``. In eval mode, a tuple
            ``(pred, x)`` where ``pred`` concatenates the decoded predictions
            of all layers along dim 1, giving shape (bs, total, no).
        """
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                # Rebuild the cached grid when the spatial size changed (or always for ONNX dynamic export).
                if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    # Must be a staticmethod: forward calls self._make_grid(nx, ny), and without
    # the decorator the instance would be passed as an extra positional argument.
    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a float tensor of shape (1, 1, ny, nx, 2) whose last dim holds
        the (x, y) index of each cell."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
def check_anchor_order(m):
    """Make the anchor order of a Detect()-like module ``m`` follow its stride order.

    The sign of (last anchor area - first anchor area) is compared with the
    sign of (last stride - first stride). When the signs differ, ``m.anchors``
    and ``m.anchor_grid`` are both reversed along dim 0 in place.
    """
    areas = m.anchor_grid.prod(-1).view(-1)  # anchor area
    area_delta = areas[-1] - areas[0]
    stride_delta = m.stride[-1] - m.stride[0]
    if area_delta.sign() == stride_delta.sign():
        return  # already ordered like the strides
    print('Reversing anchor order')
    m.anchors[:] = m.anchors.flip(0)
    m.anchor_grid[:] = m.anchor_grid.flip(0)
class YoloV5(nn.Module):
    """YOLOv5 detector: CSP backbone with SPP, PANet-style head and a Detect layer.

    Args:
        num_cls: number of classes predicted by the detection layer.
        ch: number of channels of the input image.
        anchors: one flat anchor sequence ``(w, h, w, h, ...)`` per detection
            layer. Required; three layers are fed to the detection layer.
    """

    def __init__(self, num_cls=80, ch=3, anchors=None):
        super().__init__()
        assert anchors is not None, 'anchor must be provided'
        cd = 2  # divisor applied to every base channel count
        wd = 3  # divisor applied to every base bottleneck repeat count

        # Backbone (channel counts in the comments are for cd = 2).
        self.focus = Focus(ch, 64 // cd)  # ch >>> 32
        self.conv1 = Conv(64 // cd, 128 // cd, 3, 2)  # 32 >>> 64
        self.csp1 = BottleneckCSP(128 // cd, 128 // cd, n=3 // wd)  # 64
        self.conv2 = Conv(128 // cd, 256 // cd, 3, 2)  # 64 >>> 128
        self.csp2 = BottleneckCSP(256 // cd, 256 // cd, n=9 // wd)  # 128
        self.conv3 = Conv(256 // cd, 512 // cd, 3, 2)  # 128 >>> 256
        self.csp3 = BottleneckCSP(512 // cd, 512 // cd, n=9 // wd)  # 256
        self.conv4 = Conv(512 // cd, 1024 // cd, 3, 2)  # 256 >>> 512
        self.spp = SPP(1024 // cd, 1024 // cd)  # 512
        self.csp4 = BottleneckCSP(1024 // cd, 1024 // cd, n=3 // wd, shortcut=False)  # 512

        # PANet head.
        self.conv5 = Conv(1024 // cd, 512 // cd)  # 512 >>> 256
        self.up1 = nn.Upsample(scale_factor=2)
        self.csp5 = BottleneckCSP(1024 // cd, 512 // cd, n=3 // wd, shortcut=False)  # 512 >>> 256
        self.conv6 = Conv(512 // cd, 256 // cd)  # 256 >>> 128
        self.up2 = nn.Upsample(scale_factor=2)
        self.csp6 = BottleneckCSP(512 // cd, 256 // cd, n=3 // wd, shortcut=False)  # 256 >>> 128
        self.conv7 = Conv(256 // cd, 256 // cd, 3, 2)  # 128, stride 2
        self.csp7 = BottleneckCSP(512 // cd, 512 // cd, n=3 // wd, shortcut=False)  # 256
        self.conv8 = Conv(512 // cd, 512 // cd, 3, 2)  # 256, stride 2
        # csp8 consumes cat(conv8 output, head P5) = 2 * (512 // cd) channels.
        # NOTE: it was previously declared with 512 // cd inputs and fed the
        # un-concatenated tensor, so csp8.cv1 / csp8.cv2 weights saved by that
        # version have a different shape and will not load into this layer.
        self.csp8 = BottleneckCSP(1024 // cd, 1024 // cd, n=3 // wd, shortcut=False)  # 512

        # Detection layer: class count follows num_cls, channels follow cd.
        self.detect = Detect(nc=num_cls, anchors=anchors, ch=[256 // cd, 512 // cd, 1024 // cd])

        # Derive the stride of each detection layer from a dummy forward pass.
        # A freshly constructed module is in training mode, so forward returns
        # the list of per-layer prediction maps.
        s = 256  # 2x min stride
        self.detect.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])

    def _build_backbone(self, x):
        """Run the backbone.

        Returns:
            ``(x_p3, x_p4, x_p5, x)``: the outputs of ``conv2``, ``conv3`` and
            ``conv4`` and the final ``csp4`` output.
        """
        # NOTE(review): the lateral features are taken right after the strided
        # convs, before the following CSP blocks - confirm this is intended.
        x = self.focus(x)
        x = self.conv1(x)
        x = self.csp1(x)
        x_p3 = self.conv2(x)  # P3
        x = self.csp2(x_p3)
        x_p4 = self.conv3(x)  # P4
        x = self.csp3(x_p4)
        x_p5 = self.conv4(x)  # P5
        x = self.spp(x_p5)
        x = self.csp4(x)
        return x_p3, x_p4, x_p5, x

    def _build_head(self, p3, p4, p5, feas):
        """Run the head and return the three feature maps fed to the detector.

        Channel counts with cd = 2: p3 128, p4 256, p5 512, feas 512.
        ``p5`` is accepted for symmetry with ``_build_backbone`` but not used.

        Returns:
            ``(x_small, x_medium, x_large)``: outputs of ``csp6``, ``csp7``
            and ``csp8``.
        """
        # Top-down path.
        h_p5 = self.conv5(feas)  # head P5
        x = self.up1(h_p5)
        x_concat = torch.cat([x, p4], dim=1)
        x = self.csp5(x_concat)

        h_p4 = self.conv6(x)  # head P4
        x = self.up2(h_p4)
        x_concat = torch.cat([x, p3], dim=1)
        x_small = self.csp6(x_concat)

        # Bottom-up path.
        x = self.conv7(x_small)
        x_concat = torch.cat([x, h_p4], dim=1)
        x_medium = self.csp7(x_concat)

        x = self.conv8(x_medium)
        x_concat = torch.cat([x, h_p5], dim=1)
        x_large = self.csp8(x_concat)  # was csp8(x): the concat result was discarded
        return x_small, x_medium, x_large

    def forward(self, x):
        """Return the output of the detection layer for the image batch ``x``."""
        p3, p4, p5, feas = self._build_backbone(x)
        xs, xm, xl = self._build_head(p3, p4, p5, feas)
        res = self.detect([xs, xm, xl])
        return res
if __name__ == "__main__":
    # Build the model on the GPU when one is available and list its parameters.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Sample input batch; it is not passed to the model in this script.
    inputimg = torch.randn([1, 3, 640, 640]).to(device)
    anchors = [[10, 13, 16, 30, 33, 23],
               [30, 61, 62, 45, 59, 119],
               [116, 90, 156, 198, 373, 326]]
    model = YoloV5(num_cls=80, ch=3, anchors=anchors).to(device)
    # The loop previously had only a commented-out print as its body. Print one
    # line per parameter: name, element count, shape.
    for name, parameters in model.named_parameters():
        print(name, parameters.numel(), parameters.size())
focus.conv.conv.weight 384 torch.Size([32, 12, 1, 1])
focus.conv.bn.weight 32 torch.Size([32])
focus.conv.bn.bias 32 torch.Size([32])
conv1.conv.weight 18432 torch.Size([64, 32, 3, 3])
conv1.bn.weight 64 torch.Size([64])
conv1.bn.bias 64 torch.Size([64])
csp1.cv1.conv.weight 2048 torch.Size([32, 64, 1, 1])
csp1.cv1.bn.weight 32 torch.Size([32])
csp1.cv1.bn.bias 32 torch.Size([32])
csp1.cv2.weight 2048 torch.Size([32, 64, 1, 1])
csp1.cv3.weight 1024 torch.Size([32, 32, 1, 1])
csp1.cv4.conv.weight 4096 torch.Size([64, 64, 1, 1])
csp1.cv4.bn.weight 64 torch.Size([64])
csp1.cv4.bn.bias 64 torch.Size([64])
csp1.bn.weight 64 torch.Size([64])
csp1.bn.bias 64 torch.Size([64])
csp1.m.0.cv1.conv.weight 1024 torch.Size([32, 32, 1, 1])
csp1.m.0.cv1.bn.weight 32 torch.Size([32])
csp1.m.0.cv1.bn.bias 32 torch.Size([32])
csp1.m.0.cv2.conv.weight 9216 torch.Size([32, 32, 3, 3])
csp1.m.0.cv2.bn.weight 32 torch.Size([32])
csp1.m.0.cv2.bn.bias 32 torch.Size([32])
conv2.conv.weight 73728 torch.Size([128, 64, 3, 3])
conv2.bn.weight 128 torch.Size([128])
conv2.bn.bias 128 torch.Size([128])
csp2.cv1.conv.weight 8192 torch.Size([64, 128, 1, 1])
csp2.cv1.bn.weight 64 torch.Size([64])
csp2.cv1.bn.bias 64 torch.Size([64])
csp2.cv2.weight 8192 torch.Size([64, 128, 1, 1])
csp2.cv3.weight 4096 torch.Size([64, 64, 1, 1])
csp2.cv4.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp2.cv4.bn.weight 128 torch.Size([128])
csp2.cv4.bn.bias 128 torch.Size([128])
csp2.bn.weight 128 torch.Size([128])
csp2.bn.bias 128 torch.Size([128])
csp2.m.0.cv1.conv.weight 4096 torch.Size([64, 64, 1, 1])
csp2.m.0.cv1.bn.weight 64 torch.Size([64])
csp2.m.0.cv1.bn.bias 64 torch.Size([64])
csp2.m.0.cv2.conv.weight 36864 torch.Size([64, 64, 3, 3])
csp2.m.0.cv2.bn.weight 64 torch.Size([64])
csp2.m.0.cv2.bn.bias 64 torch.Size([64])
csp2.m.1.cv1.conv.weight 4096 torch.Size([64, 64, 1, 1])
csp2.m.1.cv1.bn.weight 64 torch.Size([64])
csp2.m.1.cv1.bn.bias 64 torch.Size([64])
csp2.m.1.cv2.conv.weight 36864 torch.Size([64, 64, 3, 3])
csp2.m.1.cv2.bn.weight 64 torch.Size([64])
csp2.m.1.cv2.bn.bias 64 torch.Size([64])
csp2.m.2.cv1.conv.weight 4096 torch.Size([64, 64, 1, 1])
csp2.m.2.cv1.bn.weight 64 torch.Size([64])
csp2.m.2.cv1.bn.bias 64 torch.Size([64])
csp2.m.2.cv2.conv.weight 36864 torch.Size([64, 64, 3, 3])
csp2.m.2.cv2.bn.weight 64 torch.Size([64])
csp2.m.2.cv2.bn.bias 64 torch.Size([64])
conv3.conv.weight 294912 torch.Size([256, 128, 3, 3])
conv3.bn.weight 256 torch.Size([256])
conv3.bn.bias 256 torch.Size([256])
csp3.cv1.conv.weight 32768 torch.Size([128, 256, 1, 1])
csp3.cv1.bn.weight 128 torch.Size([128])
csp3.cv1.bn.bias 128 torch.Size([128])
csp3.cv2.weight 32768 torch.Size([128, 256, 1, 1])
csp3.cv3.weight 16384 torch.Size([128, 128, 1, 1])
csp3.cv4.conv.weight 65536 torch.Size([256, 256, 1, 1])
csp3.cv4.bn.weight 256 torch.Size([256])
csp3.cv4.bn.bias 256 torch.Size([256])
csp3.bn.weight 256 torch.Size([256])
csp3.bn.bias 256 torch.Size([256])
csp3.m.0.cv1.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp3.m.0.cv1.bn.weight 128 torch.Size([128])
csp3.m.0.cv1.bn.bias 128 torch.Size([128])
csp3.m.0.cv2.conv.weight 147456 torch.Size([128, 128, 3, 3])
csp3.m.0.cv2.bn.weight 128 torch.Size([128])
csp3.m.0.cv2.bn.bias 128 torch.Size([128])
csp3.m.1.cv1.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp3.m.1.cv1.bn.weight 128 torch.Size([128])
csp3.m.1.cv1.bn.bias 128 torch.Size([128])
csp3.m.1.cv2.conv.weight 147456 torch.Size([128, 128, 3, 3])
csp3.m.1.cv2.bn.weight 128 torch.Size([128])
csp3.m.1.cv2.bn.bias 128 torch.Size([128])
csp3.m.2.cv1.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp3.m.2.cv1.bn.weight 128 torch.Size([128])
csp3.m.2.cv1.bn.bias 128 torch.Size([128])
csp3.m.2.cv2.conv.weight 147456 torch.Size([128, 128, 3, 3])
csp3.m.2.cv2.bn.weight 128 torch.Size([128])
csp3.m.2.cv2.bn.bias 128 torch.Size([128])
conv4.conv.weight 1179648 torch.Size([512, 256, 3, 3])
conv4.bn.weight 512 torch.Size([512])
conv4.bn.bias 512 torch.Size([512])
spp.cv1.conv.weight 131072 torch.Size([256, 512, 1, 1])
spp.cv1.bn.weight 256 torch.Size([256])
spp.cv1.bn.bias 256 torch.Size([256])
spp.cv2.conv.weight 524288 torch.Size([512, 1024, 1, 1])
spp.cv2.bn.weight 512 torch.Size([512])
spp.cv2.bn.bias 512 torch.Size([512])
csp4.cv1.conv.weight 131072 torch.Size([256, 512, 1, 1])
csp4.cv1.bn.weight 256 torch.Size([256])
csp4.cv1.bn.bias 256 torch.Size([256])
csp4.cv2.weight 131072 torch.Size([256, 512, 1, 1])
csp4.cv3.weight 65536 torch.Size([256, 256, 1, 1])
csp4.cv4.conv.weight 262144 torch.Size([512, 512, 1, 1])
csp4.cv4.bn.weight 512 torch.Size([512])
csp4.cv4.bn.bias 512 torch.Size([512])
csp4.bn.weight 512 torch.Size([512])
csp4.bn.bias 512 torch.Size([512])
csp4.m.0.cv1.conv.weight 65536 torch.Size([256, 256, 1, 1])
csp4.m.0.cv1.bn.weight 256 torch.Size([256])
csp4.m.0.cv1.bn.bias 256 torch.Size([256])
csp4.m.0.cv2.conv.weight 589824 torch.Size([256, 256, 3, 3])
csp4.m.0.cv2.bn.weight 256 torch.Size([256])
csp4.m.0.cv2.bn.bias 256 torch.Size([256])
conv5.conv.weight 131072 torch.Size([256, 512, 1, 1])
conv5.bn.weight 256 torch.Size([256])
conv5.bn.bias 256 torch.Size([256])
csp5.cv1.conv.weight 65536 torch.Size([128, 512, 1, 1])
csp5.cv1.bn.weight 128 torch.Size([128])
csp5.cv1.bn.bias 128 torch.Size([128])
csp5.cv2.weight 65536 torch.Size([128, 512, 1, 1])
csp5.cv3.weight 16384 torch.Size([128, 128, 1, 1])
csp5.cv4.conv.weight 65536 torch.Size([256, 256, 1, 1])
csp5.cv4.bn.weight 256 torch.Size([256])
csp5.cv4.bn.bias 256 torch.Size([256])
csp5.bn.weight 256 torch.Size([256])
csp5.bn.bias 256 torch.Size([256])
csp5.m.0.cv1.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp5.m.0.cv1.bn.weight 128 torch.Size([128])
csp5.m.0.cv1.bn.bias 128 torch.Size([128])
csp5.m.0.cv2.conv.weight 147456 torch.Size([128, 128, 3, 3])
csp5.m.0.cv2.bn.weight 128 torch.Size([128])
csp5.m.0.cv2.bn.bias 128 torch.Size([128])
conv6.conv.weight 32768 torch.Size([128, 256, 1, 1])
conv6.bn.weight 128 torch.Size([128])
conv6.bn.bias 128 torch.Size([128])
csp6.cv1.conv.weight 16384 torch.Size([64, 256, 1, 1])
csp6.cv1.bn.weight 64 torch.Size([64])
csp6.cv1.bn.bias 64 torch.Size([64])
csp6.cv2.weight 16384 torch.Size([64, 256, 1, 1])
csp6.cv3.weight 4096 torch.Size([64, 64, 1, 1])
csp6.cv4.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp6.cv4.bn.weight 128 torch.Size([128])
csp6.cv4.bn.bias 128 torch.Size([128])
csp6.bn.weight 128 torch.Size([128])
csp6.bn.bias 128 torch.Size([128])
csp6.m.0.cv1.conv.weight 4096 torch.Size([64, 64, 1, 1])
csp6.m.0.cv1.bn.weight 64 torch.Size([64])
csp6.m.0.cv1.bn.bias 64 torch.Size([64])
csp6.m.0.cv2.conv.weight 36864 torch.Size([64, 64, 3, 3])
csp6.m.0.cv2.bn.weight 64 torch.Size([64])
csp6.m.0.cv2.bn.bias 64 torch.Size([64])
conv7.conv.weight 147456 torch.Size([128, 128, 3, 3])
conv7.bn.weight 128 torch.Size([128])
conv7.bn.bias 128 torch.Size([128])
csp7.cv1.conv.weight 32768 torch.Size([128, 256, 1, 1])
csp7.cv1.bn.weight 128 torch.Size([128])
csp7.cv1.bn.bias 128 torch.Size([128])
csp7.cv2.weight 32768 torch.Size([128, 256, 1, 1])
csp7.cv3.weight 16384 torch.Size([128, 128, 1, 1])
csp7.cv4.conv.weight 65536 torch.Size([256, 256, 1, 1])
csp7.cv4.bn.weight 256 torch.Size([256])
csp7.cv4.bn.bias 256 torch.Size([256])
csp7.bn.weight 256 torch.Size([256])
csp7.bn.bias 256 torch.Size([256])
csp7.m.0.cv1.conv.weight 16384 torch.Size([128, 128, 1, 1])
csp7.m.0.cv1.bn.weight 128 torch.Size([128])
csp7.m.0.cv1.bn.bias 128 torch.Size([128])
csp7.m.0.cv2.conv.weight 147456 torch.Size([128, 128, 3, 3])
csp7.m.0.cv2.bn.weight 128 torch.Size([128])
csp7.m.0.cv2.bn.bias 128 torch.Size([128])
conv8.conv.weight 589824 torch.Size([256, 256, 3, 3])
conv8.bn.weight 256 torch.Size([256])
conv8.bn.bias 256 torch.Size([256])
csp8.cv1.conv.weight 65536 torch.Size([256, 256, 1, 1])
csp8.cv1.bn.weight 256 torch.Size([256])
csp8.cv1.bn.bias 256 torch.Size([256])
csp8.cv2.weight 65536 torch.Size([256, 256, 1, 1])
csp8.cv3.weight 65536 torch.Size([256, 256, 1, 1])
csp8.cv4.conv.weight 262144 torch.Size([512, 512, 1, 1])
csp8.cv4.bn.weight 512 torch.Size([512])
csp8.cv4.bn.bias 512 torch.Size([512])
csp8.bn.weight 512 torch.Size([512])
csp8.bn.bias 512 torch.Size([512])
csp8.m.0.cv1.conv.weight 65536 torch.Size([256, 256, 1, 1])
csp8.m.0.cv1.bn.weight 256 torch.Size([256])
csp8.m.0.cv1.bn.bias 256 torch.Size([256])
csp8.m.0.cv2.conv.weight 589824 torch.Size([256, 256, 3, 3])
csp8.m.0.cv2.bn.weight 256 torch.Size([256])
csp8.m.0.cv2.bn.bias 256 torch.Size([256])
detect.0.m.0.weight 32640 torch.Size([255, 128, 1, 1])
detect.0.m.0.bias 255 torch.Size([255])
detect.0.m.1.weight 65280 torch.Size([255, 256, 1, 1])
detect.0.m.1.bias 255 torch.Size([255])
detect.0.m.2.weight 130560 torch.Size([255, 512, 1, 1])
detect.0.m.2.bias 255 torch.Size([255])
