前言
\quad 之前在Keras下训练Unet十分方便,但是想要平台移植和嵌入到C++代码却成为了一个很困难的问题,我花费了几天时间完成了Caffe版本的Unet在Windows下的训练,以及利用VS2015进行前向推理的过程,由于这个过程在网络上没有任何资料,所以打算将这个过程分享一下。
步骤
- 编译windows下的caffe并编译python接口,大家可以用BVLC下windows分支的caffe代码,使用scripts/build_win.cmd来编译,这里需要配置cmake和python环境,如果使用gpu的话还需要配置cuda8.0和cudnn5.1,环境配置很无脑,就不说了,自行配置即可。
- 制作数据集,准备一份原始图片,然后mask图片(通过labelme把json文件转成png的图片),注意由于python脚本里面img.txt只存图片的名字,要求原始图片和mask图片的名字一样,且后缀名也一样,所以需要把labelme转出来的mask图片png格式转为jpg格式。最后得到的数据集就是3个东西,一是原始图片,一个是mask图片(名字和原始图片对应),一个img.txt,这个保存图片名字,例如000001.jpg,每行保存一张图片名字,多少张图片就多少行,这个可以自己写一个简单的脚本实现。
- 下载caffe版本unet的网络结构,https://github.com/warden3344/unet, 使用这里面的train_val.prototxt进行训练,这个prototxt需要更改的地方就是最后的分类数,我是两个分类,所以score层的num_output=2。然后输入层是python实现的,克隆这个工程下面有一个mydatalayer.py,默认是一个分类,由于我们是二分类需要更改一下代码,数据集3个东西的路径自己相应更改即可,如下:
import caffe
import numpy as np
import cv2
import numpy.random as random
class DataLayer(caffe.Layer):
def setup(self, bottom, top):
self.imgdir = "/home/pic/zxy-project/caffe-unet/data/image/"
self.maskdir = "/home/pic/zxy-project/caffe-unet/data/mask_png/"
self.imgtxt = "/home/pic/zxy-project/caffe-unet/data/img.txt"
self.random = True
self.seed = None
if len(top) != 2:
raise Exception("Need to define two tops: data and mask.")
if len(bottom) != 0:
raise Exception("Do not define a bottom.")
self.lines = open(self.imgtxt, 'r').readlines()
self.idx = 0
if self.random:
random.seed(self.seed)
self.idx = random.randint(0, len(self.lines) - 1)
def reshape(self, bottom, top):
# load image + label image pair
self.data = self.load_image(self.idx)
self.mask = self.load_mask(self.idx)
# reshape tops to fit (leading 1 is for batch dimension)
top[0].reshape(1, *self.data.shape)
top[1].reshape(1, *self.mask.shape)
def forward(self, bottom, top):
# assign output
top[0].data[...] = self.data
top[1].data[...] = self.mask
# pick next input
if self.random:
self.idx = random.randint(0, len(self.lines) - 1)
else:
self.idx += 1
if self.idx == len(self.lines):
self.idx = 0
def backward(self, top, propagate_down, bottom):
pass
def load_image(self, idx):
imname = self.imgdir + self.lines[idx]
imname = imname[:-1]
#print 'load img %s' %imname
im = cv2.imread(imname)
#im = cv2.imread(imname)
#print im.shape
im = cv2.resize(im,(512,512))
im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
#im = np.arra

最低0.47元/天 解锁文章
1335

被折叠的 条评论
为什么被折叠?



