8,训练模型
进入 models/research路径
python deeplab/train.py \
--logtostderr \
--training_number_of_steps=1000 \
--train_split="train" \
--model_variant="xception_65" \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--train_crop_size="513,513" \
--train_batch_size=2 \
--fine_tune_batch_norm=false \
--dataset="mydata" \
--tf_initial_checkpoint='/home/lw/data/cityscapes/deeplabv3_cityscapes_train/model.ckpt' \
--train_logdir='/home/lw/data/mydata/train' \
--dataset_dir='/home/lw/data/mydata/tfrecord'
模型存在于/home/lw/data/mydata/train
训练过程很烧GPU,发热,注意散热
9.验证模型
python deeplab/eval.py \
--logtostderr \
--eval_split="val" \
--model_variant="xception_65" \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--dataset="mydata" \
--checkpoint_dir='/home/lw/data/mydata/train' \
--eval_logdir='/home/lw/data/mydata/eval' \
--dataset_dir='/home/lw/data/mydata/tfrecord'
默认只有miou评价标准的值,读者可自行加入其他评价指标。比如,accuracy,precision,recall,f1_score的值
运行结果如下:
10, 预测模型
python deeplab/vis.py \
--logtostderr \
--vis_split="val" \
--model_variant="xception_65" \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--dataset="mydata" \
--checkpoint_dir='/home/lw/data/mydata/train' \
--vis_logdir="/home/lw/data/mydata/vis" \
--dataset_dir="/home/lw/data/mydata/tfrecord"
11.模型导出.pd格式,并用前面博客显示代码显示出来
python deeplab/export_model.py \
--logtostderr \
--checkpoint_path="/home/lw/data/mydata/train/model.ckpt-1000" \
--export_path="/home/lw/data/mydata/pb/frozen_inference_graph.pb" \
--model_variant="xception_65" \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--num_classes=2 \
--inference_scales=1.0
我们发现问题:预测结果和实际位置不匹配。
import os
import cv2
file_path = "/home/lw/data/mydata/mask/"
list_path = os.listdir(file_path)
for i in range(0, len(list_path)):
path = os.path.join(file_path, list_path[i])
if os.path.isfile(path) & path.endswith('.png'):
image = cv2.imread(path, -1)
image = cv2.resize(image, (512, 512) )
cv2.imwrite(path, image)
把原始图像image和标签图像mask重新整理成大小一样的,再重新生成tfrecord,训练模型,预测,再做一遍。
第一张比较准,是因为这是训练集中的数据进行测试。
第二张不准,说明模型训练的还不好,可能的原因是数据量太小了,总共才10张训练数据。后序完善数据。
12,边缘提取
import cv2
import numpy as np
img = cv2.imread("/home/lw/data/mydata/vis/segmentation_results/000000_prediction.png", 0)
# #(3, 3)表示高斯矩阵的长与宽都是3,标准差取0
# img = cv2.GaussianBlur(img,(3,3),0)
#image:源图像 threshold1:阈值1 threshold2:阈值2
#其中,较大的阈值2用于检测图像中明显的边缘,但一般情况下检测的效果不会那么完美,边缘检测出来是断断续续的。所以这时候用较小的第一个阈值用于将这些间断的边缘连接起来。
canny = cv2.Canny(img, 0, 150)
cv2.imshow('Canny', canny)
cv2.waitKey(0)
cv2.destroyAllWindows()
参考博客https://www.jianshu.com/p/4f33821d28ba