nn.Module是所有神经单元的基类
pytorch在nn.Module中实现了__call__,而在__call__方法中调用了forward函数
(内容来自于pytorch系列 ----暂时就叫5的番外吧: nn.Modlue及nn.Linear 源码理解_墨氲的博客-优快云博客pytorch系列 ----暂时就叫5的番外吧: nn.Modlue及nn.Linear 源码理解_墨氲的博客-优快云博客pytorch系列 ----暂时就叫5的番外吧: nn.Modlue及nn.Linear 源码理解_墨氲的博客-优快云博客)
因为每一个类都继承父类
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
)
print(model)
-----------------------------
Sequential(
(0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
)
Sequential继承了nn.Module父类,调用了Sequential的forward函数,将层添加进去,print(model)就可输出,Sequential也为第一种模型定义
还有一个是,使用OrderedDict,OrderedDict为按照有序插入顺序储存的有序字典,也可按照key, val进行排序,详情这篇文章
第二种:nn.ModuleList()
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1]) # 类似List的索引访问
print(net)
----------------------------
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
第三种为:nn.ModuleDict(),nn.ModuleDict能更方便为神经网络添加名称
net = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
-------------------------------
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
(act): ReLU()
(linear): Linear(in_features=784, out_features=256, bias=True)
(output): Linear(in_features=256, out_features=10, bias=True)
)
源代码前点此
二:修改模型
以一个场景为例,为了增加特征图的所蕴含的信息,可以将预处理的图像进行与神经网络卷积所得的特征图进行相加
add = image # 预处理的图片
transf = transforms.ToTensor()
img_tensor = transf(img)
class Model(nn.Model):
def __int__(self, inp, out):
super(Model,self).__int__()
self.conv1 = nn.Conv2d(inp, out, 1, 1, bias=False)
self.conv2 = nn.Conv2d(inp, out, 1, 1, bias=False)
self.bn = nn.BatchNorm2d(2 * out)
def forward(self, x, add):
x = self.conv1(x)
x1 = torch.cat((self.conv2(x), add.unsqueeze(1)), 1)
x = self.bn(x1)
return x, x1
model = Model(inp, out)
outputs, outx1 = model(inputs, img_tensor )
暂时学到这些,后续学到新的在更新