1.基本组成模块
1.1 Foucs模块
# --------------------------Foucs模块-------------——-
class Foucs(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups =1普通卷积 >1深度可分离卷积
super(Foucs, self).__init__()
self.conv = Conv(c1*4, c2, k, s, p, g, act)
def forward(self, x):
# focus层:每隔一个像素取一个,然后进行堆叠,每个通道转变为四个特征层
# torch.cat((a,b,c),dim) dim为几按第几为拼接
# (batch,3,640,640)——》(batch,12,320,320)
img = torch.cat(
[
x[..., ::2, ::2],
x[..., 1::2, ::2],
x[..., ::2, 1::2],
x[..., 1::2, 1::2]
], 1
)
# 第一层卷积
img = self.conv(img)
return img
1.2 Conv模块
# -------------------------Conv模块------------------
# Conv这个函数是整个网络中最基础的组件,由卷积层 + BN层 + 激活函数 组成
class Conv(nn.Module):
def __init__(self, c1, c2, k, s, p=None, g=1, act=True):
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # 卷积层
self.bn = nn.BatchNorm