Go语言加载Tensorflow2保存的pb模型
文章目录
一、环境部署
- OS:Ubantu18.04
- Tensorflow动态库版本:Tensorflow2.1
- Go语言版本:go1.14
- Go语言第三方的包:GoCV------>OpenCV 4.3.0
- Tensorflow2.1-Go版本----------->Tensorflow2.1的C动态库
1.1 Go语言安装
分为两种安装方式:
- 下载Go语言的安装解压包,解压到对应路径,配置PATH,GOROOT,GOPATH等变量。
- 直接使用
apt install
命令下载。
第一种方式安装的Go需要将Go的安装目录加入PATH变量。第二种安装无需加入PATH。
GOPATH变量说明:
该变量配置时和Go的安装路径没关系,该变量指定一个目录里面存放Go语言安装的第三方库文件,即使用go get
命令下载安装的库文件都会存放在GOPATH变量的文件夹下。(一定要配置该变量,方便以后安装Go的第三方包。)
配置环境变量后 注意使用以下命令使环境变量生效:
source /etc/profile
1.2 Go版Tensorflow安装
官方文档地址:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/go 中有安装方法
但执行以下代码时,发现无法下载
go get -d github.com/tensorflow/tensorflow/tensorflow/go
其实,go get
命令可分为两步
- 下载需要安装包的源代码,解压到
GOPATH
中src
文件夹中 - 使用
go install
命令来安装该包
所以使用手动安装Go的第三方包:
- 从github上下载Tensorflow2.1的源码解压包,解压后放在GOPATH路径下的src文件夹下
- 按照官方的下载路径检查在src文件下的路径是否一致(重要,不能随意将源码解压到src路径下)
- 按照官方的下载路径,src中文件夹中的路径为:
$GOPATH/src/github.com/tensorflow/tensorflow/tensorflow/go/
- 编译tensorflow C的链接库(不展开详细介绍)
cd ${GOPATH}/src/github.com/tensorflow/tensorflow
./configure
bazel build -c opt //tensorflow:libtensorflow.so
双击目录下的bazel-bin
,进入文件夹,进入文件下的tensorflow文件,可跳转到生成动态库文件目录,该目录下的文件有:
将除文件夹的所有文件复制到/usr/local/lib/
文件下,并添加以下环境变量
export LIBRARY_PATH=$LIBRARY_PATH:${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow
# Linux
export LD_LIBRARY_PATH=$LIBRARY_PATH:${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow
LIBRARY_PATH
是程序编译期间查找动态链接库时,共享库的查找地址
LD_LIBRARY_PATH
是程序加载运行期间除了系统默认的查找动态链接库其他的搜索地址。
5.最后安装tensorflow-go:
go install github.com/tensorflow/tensorflow/tensorflow/go
6.测试是否安装成功
package main
import (
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"github.com/tensorflow/tensorflow/tensorflow/go/op"
"fmt"
)
func main() {
// Construct a graph with an operation that produces a string constant.
s := op.NewScope()
c := op.Const(s, "Hello from TensorFlow version " + tf.Version())
graph, err := s.Finalize()
if err != nil {
panic(err)
}
// Execute the graph in a session.
sess, err := tf.NewSession(graph, nil)
if err != nil {
panic(err)
}
output, err := sess.Run(nil, []tf.Output{c}, nil)
if err != nil {
panic(err)
}
fmt.Println(output[0].Value())
}
输出tensorflow的版本则完成。
1.3 GoCV安装
GoCV是Go语言的OpenCV库,GoCV安装前提是系统已经安装了OpenCV4.3.0
官网安装地址:https://gocv.io/getting-started/linux/
若没有安装过OpenCV,使用以下命令安装.
go get -u -d gocv.io/x/gocv
cd $GOPATH/src/gocv.io/x/gocv
make install //若安装了OpenCV4.3.0 无需执行此命令
若已经安装了OpenCV3.X,需要重新安装OpenCV4.3.0 ,请参考以下设置方法共存多个版本。
下载OpenCV4.3,0的解压包,解压到指定目录,cd
到解压的文件下。
mkdir build
cd build
cmake -D CMAKE_BUILD_TYPE=Release \
-D CMAKE_INSTALL_PREFIX=/home/opencv4\ #安装目录,不能是usr/local 否则会覆盖以前的OpenCV3.X
-D OPENCV_GENERATE_PKGCONFIG=ON \ #OpenCV4 安装时一定要开启pkgconfig选项,安装目录下没有pkgconfig文件夹
..
make -j8
sudo make install
然后配置以下环境变量:
gedit ~/.bashrc
#在末尾添加安装目录lib中的路径
export PKG_CONFIG_PATH=$PKG_CONFIG_PATH:/home/opencv4/lib/pkgconfig
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/opencv4/lib
source ~/.bashrc
测试目前的版本:
#若使用该命令一直显示是OpenCV3.x的版本,一直以为是自己配置错了,这里有点小坑
pkg-config --modversion opencv
pkg-config --cflags opencv
pkg-config --libs opencv
#使用下面的命令才能先是OpenCV4的版本
pkg-config --modversion opencv4
pkg-config --cflags opencv4
pkg-config --libs opencv4
注意两个点:
- 编译OpenCV4时一定要开启
pkg-configure
选项,即cmake
时设置OPENCV_GENERATE_PKGCONFIG=ON
pkg-configure
命令 查看版本要使用opencv4
选项
GoCV安装测试代码为:
package main
import (
"fmt"
"os"
"gocv.io/x/gocv"
)
func main() {
if len(os.Args) < 2 {
fmt.Println("How to run:\n\tshowimage [imgfile]")
return
}
filename := os.Args[1]
window := gocv.NewWindow("Hello")
img := gocv.IMRead(filename, gocv.IMReadColor)
if img.Empty() {
fmt.Printf("Error reading image from: %v\n", filename)
return
}
for {
window.IMShow(img)
if window.WaitKey(1) >= 0 {
break
}
}
}
运行命令:
go run test_gocv.go image_path
二、Go代码测试
需要加载的模型是tf2.1训练的pb五位验证码识别模型,该模型的输入图像要求是经过灰度化后的四维tensor:[batch_size,height,width,channels]
,要先将图像处理模型适合的输入格式和数据。
需要引入的包:
import (
"errors"
"fmt"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"github.com/tensorflow/tensorflow/tensorflow/go/op"
"log"
"reflect"
"strings"
"gocv.io/x/gocv"
"io/ioutil"
"os"
)
2.1 使用Go读入图像,使用tf处理图像
图像处理代码主要包含两个函数:
func makeImageTensorByGocv(filename string)(*tf.Tensor,error)
func constructGraphTo4DImage()(graph *tf.Graph, input, out_put_4d tf.Output, err error)
关键步骤为:
使用ioutil.ReadFile(filename)
得到bytes的切片,
使用tf.NewTensor(string(bytes_1))
转成tensor
/*
该函数传入图像路径,读入图片,并将图像转化为tensor
并将图像处理成model需要的输入格式
返回的Tensor格式为:[batch_size,height,width,channels]
*/
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
bytes_1, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// 将bytes_1 转成tensor
tensor, err := tf.NewTensor(string(bytes_1))
if err != nil {
fmt.Println("NewTensor is failed!")
return nil, err
}
// Construct a graph to normalize the image
graph, input, output, err := constructGraphToNormalizeImage()
if err != nil {
fmt.Println("constructGraphToNormalizeImage is failed!")
return nil, err
}
// Execute that graph to normalize this one image
session, err := tf.NewSession(graph, nil)
if err != nil {
fmt.Println("New Session is failed!")
return nil, err
}
defer session.Close()
// 开启session得到处理后的图像tensor
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
fmt.Println("Run Session is failed!")
fmt.Println(err)
return nil, err
}
return normalized[0], nil
}
将input处理成适合模型输入的四维tensor:
在强转时op.Cast(s, op.DecodePng(s, input, op.DecodePngChannels(1)), tf.Float)
设置op.DecodePngChannels(1)1,便可将图像灰度化。
// 此函数构造TensorFlow操作图,该图以输入PNG编码的字符串作为输入,并返回适合作为模型输入的张量。
func constructGraphToNormalizeImage() (graph *tf.Graph, input, output_img tf.Output, err error) {
// - 这个模型需要输入图片的格式为 40(h)*100(w)
// - 图像的通道信息应该为1通道
// - 需要将图像进行灰度化处理
const (
H, W = 40, 100
Mean = float32(117)
Scale = float32(1)
)
s := op.NewScope()
input = op.Placeholder(s, tf.String)
// - 这里其实就是对图片像素进行了处理
// - 利用双线性插值的方法改变图片的大小
output_img = op.ResizeBilinear(s,
// - 将tensor扩展为4维,加入了batch_size这个维度
op.ExpandDims(s,
// - 强制转化类型,注意设置Channels为1,即做了图像的灰度化处理
op.Cast(s, op.DecodePng(s, input, op.DecodePngChannels(1)), tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{H, W}))
graph, err = s.Finalize()
return graph, input, output_img, err
}
2.2 使用GoCV处理图像并转Tensor
图像处理代码主要包含两个函数:
func makeTensorFromImage(filename string) (*tf.Tensor, error)
func constructGraphToNormalizeImage() (graph *tf.Graph, input, output_img tf.Output, err error)
关键步骤:
为使用gocv.IMRead(filename,gocv.IMReadColor)
方法读取图片
使用gocv.CvtColor(img,&img,gocv.ColorBGRToGray)
方法将图像转化为灰度图
使用img_tmp:=gocv.NewMat()
创建新的Mat,使用img.ConvertTo(&img_tmp,gocv.MatTypeCV32FC1)
方法转化。
使用slice=img_tmp.DataPtrFloat32()
方法将Mat转成float的切片,tf.NewSession(slice)
得到一维的tensor。
/*
使用GoCV处理图像并做灰度化处理
并将GoCV的Mat转成四维的tensor
*/
func makeImageTensorByGocv(filename string)(*tf.Tensor,error) {
const (
H, W = 40, 100
CHANEELS=1
)
// 读取图像
img := gocv.IMRead(filename,gocv.IMReadColor)
if img.Empty() {
fmt.Printf("Error reading image from: %v\n", filename)
return nil,errors.New("can't read image!")
}
// 转化为灰度图像
gocv.CvtColor(img,&img,gocv.ColorBGRToGray)
// 转化为需要的图像大小
// gocv.Resize(img,&img,image.Pt(H, W), 0, 0, gocv.InterpolationCubic)
img_tmp:=gocv.NewMat()
img.ConvertTo(&img_tmp,gocv.MatTypeCV32FC1)
defer img_tmp.Close()
//img_slider是图像Mat转化成的切片
img_slider,err:=img_tmp.DataPtrFloat32() //Mat转slice切片
if err!=nil{
log.Println(err)
return nil,err
}
// slice切片转成tensor
tensor,err:=tf.NewTensor(img_slider)
if err!=nil {
log.Println(err)
return nil,err
}
graph, input, output, err := constructGraphTo4DImage()
if err != nil {
fmt.Println("constructGraph4DImage is failed!")
return nil, err
}
session, err := tf.NewSession(graph, nil)
if err != nil {
fmt.Println("New Session is failed!")
return nil, err
}
defer session.Close()
//Run得到四维的tensor
img_4d, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
fmt.Println("Run Session is failed!")
fmt.Println(err)
return nil, err
}
println(img_4d[0].Shape())
return img_4d[0],nil
}
将input的一维tensor转成四维的适合模型输入的tensor
/*
定义操作图
将一维的tensor变量reshape成四维的tensor,可用于模型的输入
*/
func constructGraphTo4DImage()(graph *tf.Graph, input, out_put_4d tf.Output, err error){
const (
H, W = 40, 100
)
s := op.NewScope()
input=op.Placeholder(s,tf.Float)
out_put_4d=op.Reshape(s,input,
op.Const(s.SubScope("shape"),[]int32{-1,H,W,1}))
graph, err = s.Finalize()
return graph, input, out_put_4d, err
}
2.3 LoadSavedModel加载模型
加载保存的pb模型主要使用LoadSavedModel
方法,该方法原型为:
//exportDir是pb文件夹的路径,tags是模型的标签,默认为‘serve’
func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error)
tags
是必要的参数,所以说一定要知道模型文件的tags
,可以在训练时保存模型的tag。如果不知道模型的tag
可使用Tensorflow自带的工具saved_model_cli
查看模型tags
信息
saved_model_cli show --all --dir E:\InspurSummerJob\Project_IdentifyCode\project_weobo\saved_models\model_tf2V2
/*
加载保存的pb模型,
modelPath参数为模型文件夹路径
modelsNames参数为模型标签,一般为serve
*/
func LoadModel(modelPath string,modelsNames []string)(*tf.SavedModel){
model, err := tf.LoadSavedModel(modelPath, modelsNames, nil) // 载入模型
if err != nil {
log.Fatal("LoadSavedModel(): %v", err)
}
/* 打印出每一个op,找出模型的输出输出op
log.Println("List possible ops in graphs") // 打印出所有的Operator
for _, op := range model.Graph.Operations() {
//log.Printf("Op name: %v, on device: %v", op.Name(), op.Device())
log.Printf("Op name: %v", op.Name())
}
*/
return model
}
LoadModel
方法得到的model.Graph和model.Session
可直接生成graph
和session
img_test_tensor,err := makeImageTensorByGocv(image_path)
model :=LoadModel(model_path,[]string{"serve"})
// 从model计算图中创建Session和Graph
session := model.Session
graph := model.Graph
defer session.Close()
// 运行模型结果
output, err := session.Run(
map[tf.Output]*tf.Tensor{
graph.Operation("serving_default_inputs").Output(0): img_test_tensor,
},
[]tf.Output{
//获取5个输出
graph.Operation("StatefulPartitionedCall").Output(0),
graph.Operation("StatefulPartitionedCall").Output(1),
graph.Operation("StatefulPartitionedCall").Output(2),
graph.Operation("StatefulPartitionedCall").Output(3),
graph.Operation("StatefulPartitionedCall").Output(4),
},
nil)
if err != nil {
log.Fatal(err)
}
// 保存的模型的输出有5位
res :=""
for i:=0;i<5 ;i++ {
pro_value:=output[i].Value().([][]float32)[0]
s:=index2ch[getBestIdx(pro_value)]
res+=s
}
path_tmp:=strings.Split(image_path,"/")
image_name:=path_tmp[len(path_tmp)-1]
fmt.Println("原始结果为:%v",image_name)
fmt.Println("预测结果为:%v",res)
/*
该方法获取概率最大的一个分类的坐标
*/
func getBestIdx(pro_value []float32)(int) {
best_idx :=0
for index,value := range pro_value{
if value > pro_value[best_idx]{
best_idx=index
}
}
return best_idx
}
mt.Println(“预测结果为:%v”,res)
```go
/*
该方法获取概率最大的一个分类的坐标
*/
func getBestIdx(pro_value []float32)(int) {
best_idx :=0
for index,value := range pro_value{
if value > pro_value[best_idx]{
best_idx=index
}
}
return best_idx
}