最近开始学pytorch相关的网络模型训练,作为基础练习训练一个关键点检测模型。
入坑看了这个项目 说明的还是很详细的
https://github.com/tensor-yu/PyTorch_Tutorial
里面的loss函数介绍以处理分类模型为主,而数据处理和读取,pytorch训练流程都讲的很全。因此我这次练点检测的模型只要稍微改一下数据读取和loss的计算即可。
我目前生成的数据保存在txt里,一行包含的信息为“图像路径+4个关键点的八个坐标”,因此在自定义数据模块只做了如下处理:(一般放在utils.py里定义自己的data类)
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
points_tensor = torch.Tensor([[float(words[1]),float(words[2])],[float(words[3]),float(words[4])],[float(words[5]),float(words[6])],[float(words[7]),float(words[8])]])
imgs.append((words[0], points_tensor))
这里把8个数分成4组用于对应4个点 其他读取部分照抄教程里的例子即可。
接下来要选择loss函数,我这边就简单的采用模型输出的4个点和label里4个点欧氏距离的和作为loss,查完资料发现计算两点距离在torch里已经实现了,因此可以直接拿来用于loss的计算部分(教程里例子都是分类的一些loss,用不上了) 复杂一点的loss也要自己建一个loss的类,但是pytorch下欧氏距离的实现比较容易,这里就直接改main.py里的loss计算代码了
import torch.nn.functional as F
/******************教程中其他代码略*****************/
outputs = net(inputs) #输出网络推理结果,reshape后计算欧氏距离
outputs = outputs.reshape(-1,4,2)
loss = torch.sum(F.pairwise_distance(outputs,labels, p=2))
loss.backward()
optimizer.step()
最后是模型的选择,由于我要实现的点图像比较单一,就把教程例子里的RESNET砍一砍拿来用了,这里只列出动过的部分(resnet.py可以自己在model文件夹里创建一个,在main.py里调用即可)
class ResidualBlock(nn.Module):
/***************略************/
class ResNets(BasicModule):
def __init__(self, num_classes=8): #输出8个数代表四个坐标
super(ResNets, self).__init__()
self.model_name = 'resnets'
# 前几层: 图像转换
self.pre = nn.Sequential(
nn.Conv2d(3, 32, 7, 2, 3, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 1))
# 重复的layer,residual block都被我砍成1个了
self.layer1 = self._make_layer( 32, 64, 1)
self.layer2 = self._make_layer( 64, 128, 1, stride=2)
self.layer3 = self._make_layer( 128, 256, 1, stride=2)
self.layer4 = self._make_layer( 256, 256, 1, stride=2)
#分类用的全连接
self.fc = nn.Linear(256, num_classes)
def _make_layer(self, inchannel, outchannel, block_num, stride=1):
/***************略************/
def forward(self, x):
/***************略************/
其他:lr和优化器optimizer都可以照抄不动,上述修改完后再去除main.py里和分类模型有关的计算和测试代码就可以跑了。