深度学习小白实现残差网络resnet18 ——pytorch
利用闲暇时间写了resnet18 的实现代码,可能存在错误,看官可以给与指正。
pytorch中给与了resnet的实现模型,可以供小白调用,这里不赘述方法。下面所有代码的实现都是使用pytorch框架书写,采用python语言。
网络上搜索到的resne18的网络结构图如下。resnet18只看图中左侧网络结构就可以。(ps:使用的是简书上一个博主的图,如有冒犯,请谅解)
接下来,根据如图的网络结构进行搭建网络。通过观察网络结构,发现在网络结构中存在两种不同基础块,第一种是实现标注跳跃连接的部分,如下:在这个块中具体实现工作流程如下图:
实现方法如下:
import torch.nn as nn
import torch.nn.functional as F
class basic_block(nn.Module):
'''定义了带实线部分的残差块'''
def __init__(self,in_channels):
super(basic_block, self).__init__()
self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1)
self.conv2 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1