import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
建立卷积模块,卷积模块的顺序:BN——>ReLU——>Conv
def conv_block(in_channel, out_channel):
layer = nn.Sequential(
nn.BatchNorm2d(in_channel),
nn.ReLU(True),
nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)
)
return layer
建立Dense模块,每次卷积的输出为growth_rate, growth_rate用于:每次输出channel等于上一个输入in_channel+growth_rate
class dense_block(nn.Module):
def __init__(self, in_channel, growth_rate, num_layers):
super(dense_block, self).__init__()
block = []
channel = in_channel
for i in range(num_layers):
block.append(conv_block(channel, growth_rate))
channel += growth_rate
self.net = nn.Sequential(*block)
def forward(self, x):
for layer in self.net:
out = layer(x)
x = torch.cat((out, x), dim=1)
return x
以下代码主要实现densenet的特点,这样设置才能保证每层输出与下一层输入的channel数相同
out = layer(x)
x = torch.cat((out, x), dim=1)
test_net = dense_block(3, 12, 3)
test_x = Variable(torch.zeros(1, 3, 96, 96))
print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
test_y = test_net(test_x)
print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape

本文详细介绍了如何使用PyTorch实现DenseNet网络,包括建立卷积和Dense模块,确保每层输出与下一层输入的channel数相同。通过添加过渡层来控制通道数的增长,以减少参数量并保持模型效率。最后,展示了DenseNet模型的定义过程。
最低0.47元/天 解锁文章
2067

被折叠的 条评论
为什么被折叠?



