来自:https://blog.youkuaiyun.com/awyyauqpmy/article/details/79290710
修改之后:
1、建好文件夹:
deform:存放原始训练(image+label)数据集:下载images:https://github.com/decouples/Unet
- train:30张(.tif)
- label:30张(.tif)
my_train_test:程序文件夹(3个.py(将图中的test.py移至上层路径下))+生成模型(.hdf5)的文件夹:
npydata:存放生成数据(代码从此文件夹中读入image、label、test数据)
result:存放测试结果
test:存放原始测试数据集(30个.tif)
完整文件结构:将图中的test.py移至上层路径下
2、生成训练数据
在my_train_test下运行creat_imagelabel_npy.py:生成imgs_train.npy和imgs_mask_train.npy
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import numpy as np
import os
import glob
import cv2
from libtiff import TIFF
class myAugmentation(object):
    """Augment paired image/label slices with Keras' ``ImageDataGenerator``.

    Workflow:
      1. ``Augmentation`` reads each training image and its label, stores the
         label in channel 2 of the training image so that both receive the
         same random transform, saves the merged image and augments it.
      2. ``splitMerge`` separates the augmented merged images back into
         training images and label images.
    """

    def __init__(self, train_path="train", label_path="label", merge_path="merge", aug_merge_path="aug_merge", aug_train_path="aug_train", aug_label_path="aug_label", img_type="tif"):
        """Collect every ``*.<img_type>`` file with glob and set up the generator."""
        self.train_imgs = glob.glob(train_path+"/*."+img_type)
        self.label_imgs = glob.glob(label_path+"/*."+img_type)
        self.train_path = train_path
        self.label_path = label_path
        self.merge_path = merge_path
        self.img_type = img_type
        self.aug_merge_path = aug_merge_path
        self.aug_train_path = aug_train_path
        self.aug_label_path = aug_label_path
        # Number of source slices; splitMerge iterates over this many folders.
        self.slices = len(self.train_imgs)
        self.datagen = ImageDataGenerator(
            rotation_range=0.2,
            width_shift_range=0.05,
            height_shift_range=0.05,
            shear_range=0.05,
            zoom_range=0.05,
            horizontal_flip=True,
            fill_mode='nearest')

    def Augmentation(self):
        """Merge every image/label pair and augment the merged image.

        Files are addressed by index: ``<train_path>/<i>.<img_type>`` and
        ``<label_path>/<i>.<img_type>`` for ``i`` in ``0 .. n-1``.

        Returns:
            0 when no training images were found or the number of images and
            labels differ; None after a completed run.
        """
        trains = self.train_imgs
        labels = self.label_imgs
        path_train = self.train_path
        path_label = self.label_path
        path_merge = self.merge_path
        imgtype = self.img_type
        path_aug_merge = self.aug_merge_path
        print(path_label)
        if not trains or len(trains) != len(labels):
            print ("trains can't match labels")
            return 0
        # Make sure the output folder for merged images exists before saving.
        os.makedirs(path_merge, exist_ok=True)
        for i in range(len(trains)):
            img_t = load_img(path_train+"/"+str(i)+"."+imgtype)
            img_l = load_img(path_label+"/"+str(i)+"."+imgtype)
            x_t = img_to_array(img_t)
            x_l = img_to_array(img_l)
            # Channel 2 of the merged image carries the label so that image
            # and label go through exactly the same random transform.
            x_t[:,:,2] = x_l[:,:,0]
            img_tmp = array_to_img(x_t)
            img_tmp.save(path_merge+"/"+str(i)+"."+imgtype)
            img = x_t.reshape((1,) + x_t.shape)
            savedir = path_aug_merge + "/" + str(i)
            os.makedirs(savedir, exist_ok=True)
            self.doAugmentate(img, savedir, str(i))

    def doAugmentate(self, img, save_to_dir, save_prefix, batch_size=1, save_format='tif', imgnum=30):
        """Augment one image and let the generator write the results to disk.

        Exactly ``imgnum`` batches are drawn from ``self.datagen.flow``. The
        previous loop only stopped after the counter exceeded ``imgnum``, so
        it produced one batch more than requested.
        """
        flow = self.datagen.flow(img,
                                 batch_size=batch_size,
                                 save_to_dir=save_to_dir,
                                 save_prefix=save_prefix,
                                 save_format=save_format)
        for _ in range(imgnum):
            next(flow)

    def _split_channels(self, imgname):
        """Return (base name without extension, image channel, label channel)."""
        # basename/splitext copes with both "/" and the platform separator,
        # unlike slicing at the last "/".
        midname = os.path.splitext(os.path.basename(imgname))[0]
        img = cv2.imread(imgname)
        # cv2 reads in BGR order, so index 2 here is channel 0 of the merged
        # image (the original image) and index 0 is its channel 2 (the label).
        return midname, img[:,:,2], img[:,:,0]

    def splitMerge(self):
        """Split the augmented merged images into separate image/label files."""
        path_merge = self.aug_merge_path
        path_train = self.aug_train_path
        path_label = self.aug_label_path
        for i in range(self.slices):
            path = path_merge + "/" + str(i)
            train_imgs = glob.glob(path+"/*."+self.img_type)
            os.makedirs(path_train + "/" + str(i), exist_ok=True)
            os.makedirs(path_label + "/" + str(i), exist_ok=True)
            for imgname in train_imgs:
                midname, img_train, img_label = self._split_channels(imgname)
                cv2.imwrite(path_train+"/"+str(i)+"/"+midname+"_train"+"."+self.img_type,img_train)
                cv2.imwrite(path_label+"/"+str(i)+"/"+midname+"_label"+"."+self.img_type,img_label)

    def splitTransform(self, path_merge="deform/deform_norm2", path_train="deform/train/", path_label="deform/label/"):
        """Split perspective-transformed merged images into image/label files.

        Args:
            path_merge: folder holding the merged images.
            path_train: output prefix for the image channel (used as given,
                so it should end with a path separator).
            path_label: output prefix for the label channel (same rule).
        """
        train_imgs = glob.glob(path_merge+"/*."+self.img_type)
        for imgname in train_imgs:
            midname, img_train, img_label = self._split_channels(imgname)
            cv2.imwrite(path_train+midname+"."+self.img_type,img_train)
            cv2.imwrite(path_label+midname+"."+self.img_type,img_label)
class dataProcess(object):
    """Turn the training image/label folders into ``.npy`` arrays and load them back."""

    def __init__(self, out_rows, out_cols, data_path = "../deform/train", label_path = "../deform/label", test_path = "../test", npy_path = "../npydata", img_type = "tif"):
        """Store the target image size and the folders to read from / write to."""
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.label_path = label_path
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path

    @staticmethod
    def _numeric_name_key(path):
        """Sort key: purely numeric file names in numeric order, others after them by name."""
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdigit():
            return (0, int(stem), "")
        return (1, 0, stem)

    def create_train_data(self):
        """Read every training image and its same-named label and save both as ``.npy``.

        Writes ``imgs_train.npy`` and ``imgs_mask_train.npy`` into
        ``self.npy_path`` (created if missing). Both arrays are uint8 with
        shape ``(n_images, out_rows, out_cols, 1)``.
        """
        print('-'*30)
        print('Creating training images...')
        print('-'*30)
        # Sorted so the array order (and therefore the tail that
        # validation_split takes during training) is reproducible.
        imgs = sorted(glob.glob(self.data_path+"/*."+self.img_type), key=self._numeric_name_key)
        print(len(imgs))
        shape = (len(imgs), self.out_rows, self.out_cols, 1)
        imgdatas = np.zeros(shape, dtype=np.uint8)
        imglabels = np.zeros(shape, dtype=np.uint8)
        for i, imgname in enumerate(imgs):
            # basename works for "/" and for the platform separator, so the
            # folder name no longer has to be stripped by hand.
            midname = os.path.basename(imgname)
            print(midname)
            img_file = self.data_path + '/' + midname
            label_file = self.label_path + "/" + midname
            img = load_img(img_file, grayscale = True)
            print(img_file)  # image path
            label = load_img(label_file, grayscale = True)
            print(label_file)  # label path
            imgdatas[i] = img_to_array(img)
            imglabels[i] = img_to_array(label)
            if i % 100 == 0:
                print('Done: {0}/{1} images'.format(i, len(imgs)))
        print('loading done')
        os.makedirs(self.npy_path, exist_ok=True)
        np.save(self.npy_path + '/imgs_train.npy', imgdatas)
        np.save(self.npy_path + '/imgs_mask_train.npy', imglabels)
        print('Saving to .npy files done.')

    def load_train_data(self):
        """Load the saved training arrays and normalise them for training.

        Returns:
            (imgs_train, imgs_mask_train) as float32 arrays. Images are scaled
            to [0, 1] and have the per-pixel mean over the data set
            subtracted; masks are scaled to [0, 1] and binarised at 0.5.
        """
        print('-'*30)
        print('load train images...')
        print('-'*30)
        imgs_train = np.load(self.npy_path+"/imgs_train.npy").astype('float32')
        imgs_mask_train = np.load(self.npy_path+"/imgs_mask_train.npy").astype('float32')
        imgs_train /= 255
        imgs_train -= imgs_train.mean(axis = 0)
        imgs_mask_train /= 255
        # Threshold: values above 0.5 are the object, everything else is background.
        imgs_mask_train[imgs_mask_train > 0.5] = 1
        imgs_mask_train[imgs_mask_train <= 0.5] = 0
        return imgs_train,imgs_mask_train
if __name__ == "__main__":
    # Cache the 512x512 training set as .npy files, then reload it as a
    # sanity check and report the resulting array shapes.
    processor = dataProcess(512, 512)
    processor.create_train_data()
    train_images, train_masks = processor.load_train_data()
    print (train_images.shape, train_masks.shape)
3、train
在my_train_test下运行train_unet.py: (159行可改参数) 生成模型:my_unet.hdf5
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
from keras.models import *
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D, concatenate
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras
from creat_imagelabel_npy import *
class myUnet(object):
    """U-Net for single-channel binary segmentation (training script)."""

    def __init__(self, img_rows = 512, img_cols = 512):
        """Remember the input image size fed to the network."""
        self.img_rows = img_rows
        self.img_cols = img_cols

    def load_data(self):
        """Return ``(imgs_train, imgs_mask_train)`` as loaded by ``dataProcess``."""
        mydata = dataProcess(self.img_rows, self.img_cols)
        imgs_train, imgs_mask_train = mydata.load_train_data()
        return imgs_train, imgs_mask_train

    def get_unet(self):
        """Build and compile the U-Net ('same' padding, so no cropping is needed).

        The input is ``(img_rows, img_cols, 1)`` and the output is a
        single-channel sigmoid map. The encoder pools four times by 2 and the
        decoder up-samples four times by 2, so both image dimensions should
        be divisible by 16 for the skip concatenations to line up.

        Returns:
            A model compiled with Adam (lr 1e-4), binary cross-entropy loss
            and the accuracy metric.
        """
        conv_args = dict(activation='relu', padding='same', kernel_initializer='he_normal')
        inputs = Input((self.img_rows, self.img_cols, 1))

        # Contracting path.
        conv1 = Conv2D(64, 3, **conv_args)(inputs)
        print ("conv1 shape:",conv1.shape)
        conv1 = Conv2D(64, 3, **conv_args)(conv1)
        print ("conv1 shape:",conv1.shape)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        print ("pool1 shape:",pool1.shape)
        conv2 = Conv2D(128, 3, **conv_args)(pool1)
        print ("conv2 shape:",conv2.shape)
        conv2 = Conv2D(128, 3, **conv_args)(conv2)
        print ("conv2 shape:",conv2.shape)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        print ("pool2 shape:",pool2.shape)
        conv3 = Conv2D(256, 3, **conv_args)(pool2)
        print ("conv3 shape:",conv3.shape)
        conv3 = Conv2D(256, 3, **conv_args)(conv3)
        print ("conv3 shape:",conv3.shape)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        print ("pool3 shape:",pool3.shape)
        conv4 = Conv2D(512, 3, **conv_args)(pool3)
        conv4 = Conv2D(512, 3, **conv_args)(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

        # Bottleneck.
        conv5 = Conv2D(1024, 3, **conv_args)(pool4)
        conv5 = Conv2D(1024, 3, **conv_args)(conv5)
        drop5 = Dropout(0.5)(conv5)

        # Expanding path: up-sample, then concatenate with the encoder
        # feature map of the same resolution.
        up6 = Conv2D(512, 2, **conv_args)(UpSampling2D(size = (2,2))(drop5))
        merge6 = concatenate([drop4,up6],axis = 3)
        conv6 = Conv2D(512, 3, **conv_args)(merge6)
        conv6 = Conv2D(512, 3, **conv_args)(conv6)
        up7 = Conv2D(256, 2, **conv_args)(UpSampling2D(size = (2,2))(conv6))
        merge7 = concatenate([conv3,up7],axis = 3)
        conv7 = Conv2D(256, 3, **conv_args)(merge7)
        conv7 = Conv2D(256, 3, **conv_args)(conv7)
        up8 = Conv2D(128, 2, **conv_args)(UpSampling2D(size = (2,2))(conv7))
        merge8 = concatenate([conv2,up8],axis = 3)
        conv8 = Conv2D(128, 3, **conv_args)(merge8)
        conv8 = Conv2D(128, 3, **conv_args)(conv8)
        up9 = Conv2D(64, 2, **conv_args)(UpSampling2D(size = (2,2))(conv8))
        merge9 = concatenate([conv1,up9],axis = 3)
        conv9 = Conv2D(64, 3, **conv_args)(merge9)
        conv9 = Conv2D(64, 3, **conv_args)(conv9)
        conv9 = Conv2D(2, 3, **conv_args)(conv9)
        conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

        # Keras 2 keyword names; the Keras 1 spellings `input=` / `output=`
        # are deprecated.
        model = Model(inputs = inputs, outputs = conv10)
        model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
        return model

    # To change the input format, start from here; the network definition
    # above does not need to be touched.
    def train(self, batch_size=2, epochs=3, validation_split=0.2, model_path='my_unet.hdf5'):
        """Train the network and keep the checkpoint with the lowest training loss.

        Args:
            batch_size: samples per gradient update.
            epochs: number of passes over the training data.
            validation_split: fraction of the data held out for validation.
            model_path: where ``ModelCheckpoint`` writes the best model.
        """
        print("loading data")
        imgs_train, imgs_mask_train = self.load_data()
        print("loading data done")
        model = self.get_unet()
        print("got unet")
        model_checkpoint = ModelCheckpoint(model_path, monitor='loss',verbose=1, save_best_only=True)
        print('Fitting model...')
        # `epochs` is the Keras 2 name of the former `nb_epoch` argument.
        model.fit(imgs_train, imgs_mask_train, batch_size=batch_size, epochs=epochs, verbose=0,validation_split=validation_split, shuffle=True, callbacks=[model_checkpoint])
if __name__ == '__main__':
    # Train with the default 512x512 input size; prediction and image export
    # are not part of this script.
    unet_trainer = myUnet()
    unet_trainer.train()
4、生成测试数据
在my_train_test下运行create_test_npy.py: 生成imgs_test.npy
'''from ../test create testdata(.npy), save to ../npydata/imgs_test.npy'''
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import numpy as np
import os
import glob
import cv2
from libtiff import TIFF
class dataProcess(object):
    """Turn the test image folder into a single ``imgs_test.npy`` array."""

    def __init__(self, out_rows, out_cols, data_path = "../deform/train", label_path = "../deform/label", test_path = "../test", npy_path = "../npydata", img_type = "tif"):
        """Store the target image size and the folders to read from / write to."""
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.label_path = label_path
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path

    @staticmethod
    def _numeric_name_key(path):
        """Sort key: purely numeric file names in numeric order, others after them by name."""
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdigit():
            return (0, int(stem), "")
        return (1, 0, stem)

    def create_test_data(self):
        """Read every test image as grayscale and save them as ``imgs_test.npy``.

        The array is uint8 with shape ``(n_images, out_rows, out_cols, 1)``
        and is written into ``self.npy_path`` (created if missing).
        """
        print('-'*30)
        print('Creating test images...')
        print('-'*30)
        # Sorted so that row i of the array has a predictable source file;
        # for files named 0..n-1 row i is simply "<i>.<img_type>".
        imgs = sorted(glob.glob(self.test_path+"/*."+self.img_type), key=self._numeric_name_key)
        print(self.test_path)
        print(len(imgs))
        imgdatas = np.zeros((len(imgs),self.out_rows,self.out_cols,1), dtype=np.uint8)
        for i, imgname in enumerate(imgs):
            # basename works for "/" and for the platform separator, so the
            # folder name no longer has to be stripped by hand.
            midname = os.path.basename(imgname)
            print(midname)
            img = load_img(self.test_path + "/" + midname,grayscale = True)
            imgdatas[i] = img_to_array(img)
        print('loading done')
        os.makedirs(self.npy_path, exist_ok=True)
        np.save(self.npy_path + '/imgs_test.npy', imgdatas)
        print('Saving to imgs_test.npy files done.')
if __name__ == "__main__":
    # Only the test set is converted here; the training arrays are produced
    # by the training-data script.
    test_builder = dataProcess(512, 512)
    test_builder.create_test_data()
5、test
在u-net下运行,test.py: 结果会保存在u-net下
'''read ../npydata/imgs_test.npy '''
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
from keras.models import *
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D, concatenate
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras
from keras.preprocessing.image import array_to_img
#from data import *
class dataProcess(object):
    """Load the prepared test array for prediction."""

    def __init__(self, out_rows, out_cols, test_path = "../test", npy_path = "npydata", img_type = "tif"):
        """Store the image size and data folders.

        ``npy_path`` defaults to ``"npydata"``, the location this class has
        always loaded from (the script is run from the parent folder), and
        the argument is now actually honoured by ``load_test_data``.
        """
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path

    def load_test_data(self):
        """Load ``<npy_path>/imgs_test.npy`` and normalise it.

        Returns:
            float32 array scaled to [0, 1] with the per-pixel mean over the
            test set subtracted.
        """
        print('-'*30)
        print('load test images...')
        print('-'*30)
        imgs_test = np.load(self.npy_path + "/imgs_test.npy").astype('float32')
        imgs_test /= 255
        imgs_test -= imgs_test.mean(axis = 0)
        return imgs_test
class myUnet(object):
    """U-Net for single-channel binary segmentation (prediction script)."""

    def __init__(self, img_rows = 512, img_cols = 512):
        """Remember the input image size fed to the network."""
        self.img_rows = img_rows
        self.img_cols = img_cols

    def load_data(self):
        """Return the normalised test images as loaded by ``dataProcess``."""
        mydata = dataProcess(self.img_rows, self.img_cols)
        imgs_test = mydata.load_test_data()
        return imgs_test

    def get_unet(self):
        """Build and compile the U-Net ('same' padding, so no cropping is needed).

        The input is ``(img_rows, img_cols, 1)`` and the output is a
        single-channel sigmoid map. The encoder pools four times by 2 and the
        decoder up-samples four times by 2, so both image dimensions should
        be divisible by 16 for the skip concatenations to line up.

        Returns:
            A model compiled with Adam (lr 1e-4), binary cross-entropy loss
            and the accuracy metric.
        """
        conv_args = dict(activation='relu', padding='same', kernel_initializer='he_normal')
        inputs = Input((self.img_rows, self.img_cols, 1))

        # Contracting path.
        conv1 = Conv2D(64, 3, **conv_args)(inputs)
        print ("conv1 shape:",conv1.shape)
        conv1 = Conv2D(64, 3, **conv_args)(conv1)
        print ("conv1 shape:",conv1.shape)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        print ("pool1 shape:",pool1.shape)
        conv2 = Conv2D(128, 3, **conv_args)(pool1)
        print ("conv2 shape:",conv2.shape)
        conv2 = Conv2D(128, 3, **conv_args)(conv2)
        print ("conv2 shape:",conv2.shape)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        print ("pool2 shape:",pool2.shape)
        conv3 = Conv2D(256, 3, **conv_args)(pool2)
        print ("conv3 shape:",conv3.shape)
        conv3 = Conv2D(256, 3, **conv_args)(conv3)
        print ("conv3 shape:",conv3.shape)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        print ("pool3 shape:",pool3.shape)
        conv4 = Conv2D(512, 3, **conv_args)(pool3)
        conv4 = Conv2D(512, 3, **conv_args)(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

        # Bottleneck.
        conv5 = Conv2D(1024, 3, **conv_args)(pool4)
        conv5 = Conv2D(1024, 3, **conv_args)(conv5)
        drop5 = Dropout(0.5)(conv5)

        # Expanding path: up-sample, then concatenate with the encoder
        # feature map of the same resolution.
        up6 = Conv2D(512, 2, **conv_args)(UpSampling2D(size = (2,2))(drop5))
        merge6 = concatenate([drop4,up6],axis=3)
        conv6 = Conv2D(512, 3, **conv_args)(merge6)
        conv6 = Conv2D(512, 3, **conv_args)(conv6)
        up7 = Conv2D(256, 2, **conv_args)(UpSampling2D(size = (2,2))(conv6))
        merge7 = concatenate([conv3,up7],axis=3)
        conv7 = Conv2D(256, 3, **conv_args)(merge7)
        conv7 = Conv2D(256, 3, **conv_args)(conv7)
        up8 = Conv2D(128, 2, **conv_args)(UpSampling2D(size = (2,2))(conv7))
        merge8 = concatenate([conv2,up8],axis=3)
        conv8 = Conv2D(128, 3, **conv_args)(merge8)
        conv8 = Conv2D(128, 3, **conv_args)(conv8)
        up9 = Conv2D(64, 2, **conv_args)(UpSampling2D(size = (2,2))(conv8))
        merge9 = concatenate([conv1,up9],axis=3)
        conv9 = Conv2D(64, 3, **conv_args)(merge9)
        conv9 = Conv2D(64, 3, **conv_args)(conv9)
        conv9 = Conv2D(2, 3, **conv_args)(conv9)
        conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

        # Keras 2 keyword names; the Keras 1 spellings `input=` / `output=`
        # are deprecated.
        model = Model(inputs = inputs, outputs = conv10)
        model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
        return model

    def train(self, model_path="my_train_test/my_unet.hdf5", out_file='imgs_mask_test.npy'):
        """Predict masks for the test set with a previously trained model.

        Despite its name (kept for compatibility) this method does not fit
        anything: it loads the saved model from ``model_path``, runs
        prediction on the test images and stores the result in ``out_file``.
        The saved model is loaded directly, so no fresh network is built
        first.
        """
        print("loading data")
        imgs_test = self.load_data()
        print("loading data done")
        model = load_model(model_path)
        print('predict test data')
        imgs_mask_test = model.predict(imgs_test, batch_size=2, verbose=1)
        np.save(out_file, imgs_mask_test)

    def save_img(self, npy_file='imgs_mask_test.npy', out_dir=None):
        """Write each predicted mask in ``npy_file`` as ``<index>.jpg``.

        Args:
            npy_file: array of predictions saved by ``train``.
            out_dir: target folder (created if missing); None keeps the
                previous behaviour of writing into the current directory.
        """
        print("array to image")
        imgs = np.load(npy_file)
        if out_dir is not None:
            os.makedirs(out_dir, exist_ok=True)
        for i in range(imgs.shape[0]):
            img = array_to_img(imgs[i])
            filename = "%d.jpg"%(i)
            if out_dir is not None:
                filename = os.path.join(out_dir, filename)
            img.save(filename)
if __name__ == '__main__':
    # Run prediction on the prepared test set, then export each predicted
    # mask as an image.
    unet_runner = myUnet()
    unet_runner.train()
    unet_runner.save_img()
报错记录:
1、如果test.py在my_train_test下运行且
第184行:np.save('results/imgs_mask_test.npy', imgs_mask_test)
189行:imgs = np.load('results/imgs_mask_test.npy')
193行:img.save("../results/%d.jpg"%(i))
会报错:
改为
184行:np.save('imgs_mask_test.npy', imgs_mask_test)
189行:imgs = np.load('imgs_mask_test.npy')
后在u-net下生成:
但未生成测试图片报错:
再把193行改为:img.save("%d.jpg"%(i)),于是:
在u-net下生成(不是很懂生成的图片及其顺序?):