FCN-tensorflow版本训练测试并用C++调用python检测
这里插入一个参考博主的链接: https://blog.youkuaiyun.com/qq_40994943/article/details/85041493.
由于博主是个git小白,正在研究上传github。文中代码下载先借助万能的百度盘
网盘下载点我
提取码:qs01
1.配置环境
由于win10环境下tensorflow的FCN。推荐大家下载使用anaconda进行python环境管理,我这里安装的tensorflow是cpu版本。下载装好anaconda后。
cmd输入
pip install keras
anaconda会自动匹配安装keras与tensorflow。这里博主使用的是labelme进行数据标注。
安装好keras与tensorflow之后。
cmd输入
pip install pyqt5
pip install labelme
安装完成之后cmd输入labelme即可运行标注软件labelme。
2.数据准备
数据集的格式是这个样子:
这里盗用一张参考博主的图
annotation文件夹放的是train和valid的label文件,具体形式是图片(png),image文件夹放的是原照片(jpg),两个文件夹的内容除了图片格式有区别外,文件名等必须一一对应。打开labelme进行数据标注,点击Save后会生成改图片对应的json文件。
运行convert_json.py,注意修改自己的jsonfile地址。将json文件 转换为五个文件。我们训练所需要的只是label文件。有可能是全黑的,不过没关系。
转换完之后,运行16-8.py,将我们生成的16位深的label图像转换为8位。注意修改自己的图像地址。
如果你不确定是否成功将这些黑黑的图像转换成功,运行(检测是否为8位.py)给你确定一下,输出两个数则正确。运行之前注意修改图像地址。
然后将图片按照数据集说明,放进去。对,千万别放错了。
代码跑不通,十有八九是数据集出了问题————卜芷叨誰朔得
数据集准备完毕,我们先把它压缩一下。对,就是先压缩他。像这样:
为什么要压缩呢,因为FCN.py里面有一个函数,找不到zip就下载。我们为了蒙蔽他,我们自己先压缩一个,这样避免了自动下载原FCN的数据集。嗯嗯。。。
3.开始训练
当你看到这里,你基本已经成功了。原VGG网络的权重依然依赖百度网盘:
VGG权重下载点我
提取码:m388
下载好后放入Model_zoo中即可。
训练时把FCN.py中的全局变量mode改为“train”,运行该文件,测试时修改测试函数里的图片地址,并把mode改为“test”运行即可。
4.C++调用python检测脚本
训练结束后会生成ckpt文件与checkpoint以及权重参数。为了在C++中部署与调用模型。直接用tensorflow 的C++API目前博主还没研究明白。采用了C++中的Python接口进行调用模型测试。
首先呢贴一个参考博主的连接https://blog.youkuaiyun.com/xiaomu_347/article/details/81040855
需要注意的是。win10系统+vs2015 或者vc2017也可以,亲测可用。
具体c++工程的创建与配置请参考刚才贴的那位博主。感谢感谢。
废话不多说,C++代码:
#include<iostream>
#include<Python.h>
#include<windows.h>
#include<ctime>
#pragma execution_character_set("utf-8")
using namespace std;
PyGILState_STATE gstate;
PyObject*pred = NULL;
PyObject*valid_images = NULL;
PyObject*pFunc_read_batch_image = NULL;
PyObject*Return_loadmodel = NULL;
PyObject*pFunc_test_image = NULL;
PyObject*pFunc_save = NULL;
void testImage(int W, int H, int batch_size,char * data_dir,char * meta_graph_path,char * model_ckpt_path)
{
try {
/*Py_Initialize();
PyEval_InitThreads();*/
PyObject*pred_annotation = NULL;
PyObject*image = NULL;
PyObject*keep_probability = NULL;
PyObject*new_saver = NULL;
PyObject*sess = NULL;
PyObject*valid_images = NULL;
PyObject*pFunc_loadmodel = NULL;
PyObject*pArg_loadmodel = PyTuple_New(6);
PyTuple_SetItem(pArg_loadmodel, 0, Py_BuildValue("s", meta_graph_path));
PyTuple_SetItem(pArg_loadmodel, 1, Py_BuildValue("s", model_ckpt_path));
PyTuple_SetItem(pArg_loadmodel, 2, Py_BuildValue("s", data_dir));
PyTuple_SetItem(pArg_loadmodel, 3, Py_BuildValue("i", batch_size));
PyTuple_SetItem(pArg_loadmodel, 4, Py_BuildValue("i", W));
PyTuple_SetItem(pArg_loadmodel, 5, Py_BuildValue("i", H));
PyObject*Pymodule = NULL;
Pymodule = PyImport_ImportModule("call_function");//调用py文件进行模型载入和预测
if (!Pymodule) {
printf("cannot open Pymodule!");
Py_Finalize();
return;
}
pFunc_loadmodel = PyObject_GetAttrString(Pymodule, "load_model");//从本地文件中载入模型
pFunc_read_batch_image = PyObject_GetAttrString(Pymodule, "read_batch_image");//读取一批次图片
pFunc_test_image = PyObject_GetAttrString(Pymodule, "test_image");//检测
pFunc_save = PyObject_GetAttrString(Pymodule, "save");//保存
if (!pFunc_loadmodel) {
printf("cannot open FUNC!");
Py_Finalize();
return;
}
gstate = PyGILState_Ensure();
Return_loadmodel = PyEval_CallObject(pFunc_loadmodel, pArg_loadmodel);
//Py_Finalize();
}
catch (exception& e)
{
cout << "Standard exception: " << e.what() << endl;
}
return;
}
void readImage(int W, int H, int batch_size, char * data_dir)
{
PyObject*pArg_read_batch_image = PyTuple_New(4);
PyTuple_SetItem(pArg_read_batch_image, 0, Py_BuildValue("s", data_dir));
PyTuple_SetItem(pArg_read_batch_image, 1, Py_BuildValue("i", batch_size));
PyTuple_SetItem(pArg_read_batch_image, 2, Py_BuildValue("i", W));
PyTuple_SetItem(pArg_read_batch_image, 3, Py_BuildValue("i", H));
valid_images = PyEval_CallObject(pFunc_read_batch_image, pArg_read_batch_image);
}
void detectImage(int batch_size)
{
pred = PyEval_CallObject(pFunc_test_image, Return_loadmodel);
PyObject*pArg_save = PyTuple_New(3);
PyTuple_SetItem(pArg_save, 0, pred);
PyTuple_SetItem(pArg_save, 1, Py_BuildValue("i", batch_size));
PyTuple_SetItem(pArg_save, 2, valid_images);
PyEval_CallObject(pFunc_save, pArg_save);
}
int main()
{
int batch_size = 10; //batch 大小
int W = 200; //图像大小
int H = 200;
char * data_dir = "C:\\testimg"; //存放数据集的路径
char * meta_graph_path = "C:\\testmodel\\model.ckpt-99500.meta";//tensor图结构地址
char * model_ckpt_path = "C:\\testmodel\\model.ckpt-99500";//模型权重文件地址
//clock_t startTime, endTime;
//startTime = clock();//计时开始
Py_Initialize();
PyEval_InitThreads();
testImage(W,H,batch_size,data_dir,meta_graph_path,model_ckpt_path);
readImage(W, H, batch_size, data_dir);
detectImage(batch_size);
//endTime = clock();//计时结束
//cout << "The run time is :" << (double)(endTime - startTime) / CLOCKS_PER_SEC << "s" << endl;
//PyGILState_Release(gstate);
Py_Finalize();
system("pause");
return 0;
}
然后,调用的python代码,就是你之前下载的FCN中的call_function.py。
#from __future__ import print_function
import tensorflow as tf
import numpy as np
import scipy.misc as misc
import time
import os
import cv2
#import TensorflowUtils as utils
#import read_MITSceneParsingData as scene_parsing
#import BatchDatsetReader as dataset
#import datetime
#from six.moves import xrange
#from random import randint
#from PIL import Image
def load_model(meta_graph_path,model_ckpt_path,data_dir,batch_size,W,H):
global sess
global new_saver
sess = tf.Session()
new_saver = tf.train.import_meta_graph(meta_graph_path)
new_saver.restore(sess, model_ckpt_path)
# tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
pred_annotation = tf.get_collection('pred_annotation')[0]
graph = tf.get_default_graph()
# 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
keep_probability = graph.get_operation_by_name('keep_probabilty').outputs[0]
image = graph.get_operation_by_name('input_image').outputs[0]
#annotation = graph.get_operation_by_name('annotation').outputs[0]
valid_images = read_batch_image(data_dir,batch_size,W,H)
return valid_images,pred_annotation, image, keep_probability, new_saver, sess
def read_batch_image(data_dir,batch_size,W,H):
valid_images = []
batch_size = batch_size
os.chdir(data_dir)
imgList = os.listdir(data_dir)
for pic in imgList:
img = cv2.imread(pic)
size = (int(W), int(H))
img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
valid_images.append(img)
batch_size = batch_size - 1
if batch_size <= 0:
valid_images = np.array(valid_images)
return valid_images
else:
continue
def test_image(valid_images,pred_annotation,image,keep_probability,new_saver,sess):
sess = sess
new_saver = new_saver
since = time.time() # 时间模块
pred = sess.run(pred_annotation, feed_dict={image: valid_images, keep_probability:1.0})
time_elapsed = time.time() - since
print("Check batch image time is :%f s "%time_elapsed)
#print("Check one image time is :%f s"%(time_elapsed/batch_size))
sess.close()
return pred
def save(pred,batch_size,valid_images):
print("Saving Image...")
pred = np.squeeze(pred, axis=3) #从数组的形状中删除单维条目,即把shape中为1的维度去掉
for itr in range(batch_size):
save_image(pred[itr].astype(np.uint8),'C:/Users/wz/Desktop/1/', name="pred_" + str(1 + itr))
save_image(valid_images[itr].astype(np.uint8), 'C:/Users/wz/Desktop/', name="inp_" + str(1 + itr))
#save_image(valid_annotations[itr].astype(np.uint8), 'C:/Users/wz/Desktop/', name="gt_" + str(1 + itr))
print("Saved image: succeed")
def save_image(image, save_dir, name, mean=None):
if mean:
image = unprocess_image(image, mean)
misc.imsave(os.path.join(save_dir, name + ".png"), image)
def unprocess_image(image, mean_pixel):
return image + mean_pixel
# if __name__ == '__main__':
# batch_size = 10 # batch 大小
# IMAGE_SIZE = 200 # 图像尺寸
# data_dir = "C:/Users/wz/Desktop/FCN-tensorflow-hzp-master/Data_zoo1/MIT_SceneParsing/ADEChallengeData2016/images/training/" # 存放数据集的路径,需要提前下载
# #data_name = "ADEChallengeData2016"
# meta_graph_path = 'logs/model.ckpt-99500.meta' #tensor图结构地址
# model_ckpt_path = 'logs/model.ckpt-99500'#模型权重文件
# W = 200#图像大小
# H = 200
#pred_annotation,image,keep_probability,new_saver,sess = load_model(meta_graph_path,model_ckpt_path)
#valid_images = read_batch_image(data_dir,batch_size)
# since = time.time() # 时间模块
# pred = test_image(valid_images,pred_annotation,image,keep_probability,new_saver,sess)
# time_elapsed = time.time() - since
# print("Detection function time is :%f s"%time_elapsed)
# save(pred)
# sess.close()
调用的时候注意修改上面的路径为你自己的路径。将py文件放入x64的debug文件中,生成。OK,大功告成。
第一次写博客好紧张。机械系小白新人以后会将自己采坑与学习的过程分享给大家。
之后会有YOLOv3、SSD、Mask_Rcnn等cv算法的实现与windows上面的C++部署。
希望下次不会借助百度网盘了。GIT正在下载中,有帮助的话,期待你们的关注与点星哦!