ResNet 网络由多个残差块(残差结构)堆叠而成。残差连接(shortcut)保证第 n+1 层的特征图至少保留了第 n 层的特征,因此加深网络不会使特征变差。
这里主要实现的是resnet101网络的迁移学习。
kaggle猴子分类的数据源:https://pan.baidu.com/s/1VVixGKyafn2qr9nC-oqSdA 提取码:lz1p
首先是构建resnet101网络
残差网络的构建
其中 expansion=4 表示每个 Bottleneck 最后输出的通道数是 out_channel 的 4 倍;当输入通道数不等于 out_channel * 4(或 stride 不为 1)时,shortcut 分支需要先经过 1*1 的卷积(downsample)调整通道数和尺寸,才能与主分支相加。
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv -> 1x1 conv plus shortcut.

    Each convolution is followed by batch normalisation. The final 1x1
    convolution emits ``out_channel * expansion`` channels, and the shortcut
    is added to that result before the last ReLU.

    Args:
        in_channel: Number of channels of the incoming tensor.
        out_channel: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the
            shortcut; when ``None`` the input itself is the shortcut.
            Presumably it must produce the same shape as the main branch so
            the addition is valid -- confirm against the caller.
    """

    # Output channels of the block = out_channel * expansion.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        widened = out_channel * self.expansion

        # 1x1 convolution: in_channel -> out_channel.
        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)

        # 3x3 convolution; this is the only layer that uses ``stride``.
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)

        # 1x1 convolution: out_channel -> out_channel * expansion.
        self.conv3 = nn.Conv2d(out_channel, widened,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)
model.py文件(train.py 中通过 from model import ... 导入,因此文件名应为 model.py)
import torch
import torch.nn as nn
# Residual (bottleneck) block used to build ResNet-50 / ResNet-101.
class Bottleneck(nn.Module):
    """Bottleneck residual unit made of three conv + batch-norm stages.

    The stages are a 1x1 convolution, a 3x3 convolution (which carries the
    stride) and a 1x1 convolution that widens the result to
    ``out_channel * expansion`` channels. The shortcut is added after the
    third batch norm and the sum goes through a final ReLU.

    Args:
        in_channel: Channel count of the input tensor.
        out_channel: Channel count produced by the first two convolutions.
        stride: Stride applied by the 3x3 convolution.
        downsample: Optional module that transforms the input into the
            shortcut tensor; ``None`` means the input is used unchanged.
    """

    # Multiplier between ``out_channel`` and the block's output channels.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        final_channel = out_channel * self.expansion

        # Stage 1: 1x1 convolution, in_channel -> out_channel.
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # Stage 2: 3x3 convolution with padding 1, strided by ``stride``.
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # Stage 3: 1x1 convolution, out_channel -> out_channel * expansion.
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=final_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(final_channel)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the three stages, add the shortcut and apply the final ReLU."""
        if self.downsample is None:
            residual = x
        else:
            residual = self.downsample(x)

        features = self.conv1(x)
        features = self.relu(self.bn1(features))

        features = self.conv2(features)
        features = self.relu(self.bn2(features))

        features = self.conv3(features)
        features = self.bn3(features)

        # No ReLU before the addition: the activation is applied to the sum.
        features += residual
        return self.relu(features)
class resnet(nn.Module):
    """ResNet-style network assembled from a caller-supplied residual block.

    The network is a stem (7x7 stride-2 convolution, batch norm, ReLU,
    3x3 stride-2 max pool) followed by four stages of blocks with base
    widths 64, 128, 256 and 512. When ``include_top`` is true, an adaptive
    average pool and a linear classifier are appended.

    Args:
        block: Callable used to create each residual block. It is called as
            ``block(in_channel, channel, downsample=..., stride=...)`` for
            the first block of a stage and ``block(in_channel, channel)``
            for the rest, and must expose an ``expansion`` attribute.
        blocks_num: Indexable with four entries; the number of blocks in
            each of the four stages.
        num_classes: Output size of the final linear layer.
        include_top: Whether to build and apply the pooling + classifier.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super().__init__()
        self.include_top = include_top
        # Running channel count; _make_layer reads it and updates it.
        self.in_channel = 64

        # Stem.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (base width, stride) of the four stages; after layer1 built from a
        # block with expansion 4, in_channel is 256.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, blocks_num[idx], stride=stride)
            setattr(self, "layer{}".format(idx + 1), stage)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal initialisation for every convolution in the network.
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and update ``self.in_channel``.

        Only the first block receives ``stride`` and, when the stride is not
        1 or the channel count changes, a 1x1 conv + batch-norm projection
        for its shortcut.
        """
        stage_out = channel * block.expansion

        projection = None
        if not (stride == 1 and self.in_channel == stage_out):
            projection = nn.Sequential(
                nn.Conv2d(self.in_channel, stage_out,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_out),
            )

        # First block of the stage.
        members = [block(self.in_channel, channel, downsample=projection, stride=stride)]
        self.in_channel = stage_out
        # Remaining blocks keep the resolution and the channel count.
        members.extend(block(self.in_channel, channel) for _ in range(block_num - 1))
        return nn.Sequential(*members)

    def forward(self, x):
        """Apply the stem and the four stages, then the classifier if enabled."""
        pipeline = (self.conv1, self.bn1, self.relu, self.maxpool,
                    self.layer1, self.layer2, self.layer3, self.layer4)
        for module in pipeline:
            x = module(x)

        if not self.include_top:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: Bottleneck blocks with 3, 4, 23 and 3 per stage.

    Args:
        num_classes: Output size of the classifier layer.
        include_top: Whether the pooling + classifier head is built.

    Returns:
        A freshly constructed ``resnet`` instance.
    """
    stage_depths = [3, 4, 23, 3]
    return resnet(Bottleneck, stage_depths,
                  num_classes=num_classes, include_top=include_top)


if __name__ == '__main__':
    # Running this file directly prints the layer structure of the network.
    model = resnet101()
    print(model)
train.py文件
resnet模型参数下载:https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from model import resnet101
def main(image_path="./data"):
    """Fine-tune an ImageNet-pretrained ResNet-101 on a 10-class image dataset.

    The function loads the training and validation images, loads the
    pretrained weights from ``./model_p/resnet101-5d3b4d8f.pth``, replaces
    the classifier with a 10-way linear layer, trains for 3 epochs with Adam
    and writes the weights with the best validation accuracy to
    ``./model_p/monkey_ResNet101.pth``.

    Args:
        image_path: Dataset root that contains ``training`` and
            ``validation`` sub-directories, each laid out with one folder
            per class. NOTE(review): the script previously read
            ``image_path`` without ever defining it (NameError); the default
            here is only a placeholder -- point it at the real dataset.

    Raises:
        AssertionError: If the dataset root or the pretrained weight file
            does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Mean / std passed to Normalize, shared by both pipelines.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(norm_mean, norm_std)]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize(norm_mean, norm_std)])}

    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "training"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    batch_size = 4
    # os.cpu_count() may return None, which would make min() raise TypeError.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Build the network with its default 1000-way head so the pretrained
    # state dict matches the model's parameters when it is loaded.
    net = resnet101()
    # load pretrain weights
    model_weight_path = "./model_p/resnet101-5d3b4d8f.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    net.load_state_dict(torch.load(model_weight_path, map_location=device))

    # change fc layer structure: swap the head for a 10-class classifier
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 10)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer over every trainable parameter
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 3
    best_acc = 0.0
    save_path = './model_p/monkey_ResNet101.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics; read the scalar once and reuse it
            loss_value = loss.item()
            running_loss += loss_value

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # validate
        net.eval()
        acc = 0.0  # number of correct predictions in this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()