Training a Network Model with Python TensorFlow and Running Predictions from C++ OpenCV: A ResNet Example on the MNIST Dataset

This article shows how to build a ResNet for image classification and how to run predictions from the resulting .pb model file through OpenCV. It covers building the network, training the model, generating the .pb file, and testing it from both Python and C++.


I won't go into the theory here. The overall workflow is:

1. Build the network.

2. Train the model and generate the ckpt files.

3. Save the graph, freeze the trained parameters into it, and generate the .pb file.

4. Load the .pb file with OpenCV and run predictions.

A few points to keep in mind:

1. Pay attention to how you name the graph nodes when building the network; the names are needed later.

2. OpenCV does not support every TensorFlow operator. If the network uses an operator that OpenCV's DNN module does not support, prediction will fail with an error (I have not found documentation listing which operators OpenCV does or does not support; if you know of one, please leave a comment). A sketch for inspecting the node names and operator types in the frozen graph follows this list.

3. The freeze step is essential. If you save the graph directly, the same input gives a different prediction every time. At first, following examples found online, I saved the model right after training with graph_util.convert_variables_to_constants, and the results were random on every call.

4. After generating the .pb file, run a prediction from Python first to make sure the model file is correct before calling it from OpenCV.
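
As a quick sanity check for points 1 and 2, the following sketch (my own addition, assuming the frozen graph has already been written to ./model/frozen_model.pb as in section 2 below) prints every node name together with its operator type, so you can confirm the input/output node names and see which TensorFlow ops the graph actually contains before handing it to OpenCV:

import tensorflow as tf

# Load the frozen GraphDef from disk
graph_def = tf.GraphDef()
with tf.gfile.GFile('./model/frozen_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Print each node name with its op type, and collect the set of op types used
op_types = set()
for node in graph_def.node:
    print(node.name, '->', node.op)
    op_types.add(node.op)

print('operator types used:', sorted(op_types))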

Enough talk, here is the code.

1. Building the ResNet

import tensorflow as tf

def Block(X_input, kernel_size, in_filter, out_filters, stride, choice):
    X_shortcut = X_input
    f1, f2, f3 = out_filters
    conv_w1 = tf.get_variable('weight1', [1, 1, in_filter, f1], initializer=tf.truncated_normal_initializer(stddev=0.1))
    conv_b1 = tf.get_variable('bias1', [f1],initializer=tf.constant_initializer(0.0))

    conv_w2 = tf.get_variable('weight2', [kernel_size, kernel_size, f1, f2], initializer=tf.truncated_normal_initializer(stddev=0.1))
    conv_b2 = tf.get_variable('bias2', [f2], initializer=tf.constant_initializer(0.0))

    conv_w3 = tf.get_variable('weight3', [1, 1, f2, f3], initializer=tf.truncated_normal_initializer(stddev=0.1))
    conv_b3 = tf.get_variable('bias3', [f3], initializer=tf.constant_initializer(0.0))

    if choice:
        # first
        y = tf.nn.relu(tf.nn.conv2d(X_input, conv_w1, strides=[1, 1, 1, 1], padding="SAME") + conv_b1)
        # second
        y = tf.nn.relu(tf.nn.conv2d(y, conv_w2, strides=[1, 1, 1, 1], padding="SAME") + conv_b2)
        # third
        y = tf.nn.relu(tf.nn.conv2d(y, conv_w3, strides=[1, 1, 1, 1], padding="SAME") + conv_b3)

        add_result = tf.nn.relu(y + X_shortcut)

        return add_result
    else:
        # first
        y = tf.nn.relu(
            tf.nn.conv2d(X_input, conv_w1, strides=[1, stride, stride, 1], padding="SAME") + conv_b1)
        # second
        y = tf.nn.relu(tf.nn.conv2d(y, conv_w2, strides=[1, 1, 1, 1], padding="SAME") + conv_b2)
        # third
        y = tf.nn.relu(tf.nn.conv2d(y, conv_w3, strides=[1, 1, 1, 1], padding="SAME") + conv_b3)
        # final step: project the shortcut with a 1x1 conv so its channel count matches
        # the main path (note: this reuses conv_w3, so it only works when in_filter == f2)
        add = tf.nn.conv2d(X_shortcut, conv_w3, strides=[1, 1, 1, 1], padding="SAME")

        add_result = tf.nn.relu(y + add)

        return add_result

def ResNet_18(input_tensor):
    with tf.variable_scope('conv1'):
        conv1_weights = tf.get_variable('weight', [3, 3, 1, 16],
                                        initializer=tf.truncated_normal_initializer(stddev=0.1))
        conv1_biases = tf.get_variable('bias', [16],
                                       initializer=tf.constant_initializer(0.0))
        conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
        relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
        pool1 = tf.nn.max_pool(relu1, ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

    with tf.variable_scope('block1'):
        Rt1 = Block(pool1, 3, 16, [8,8,16],1,True)

    with tf.variable_scope('block2'):
        Rt2 = Block(Rt1, 3, 16, [8,8,16],1,True)

    with tf.variable_scope('block3'):
        Rt3 = Block(Rt2, 3, 16, [8,8,16],1,True)

    with tf.variable_scope('block4'):
        Rt4 = Block(Rt3, 3, 16, [8,8,16],1,True)

    with tf.variable_scope('block5'):
        Rt5 = Block(Rt4, 3, 16, [16,16,32],1,False)

    with tf.variable_scope('Flat'):
        pool2 = tf.nn.max_pool(Rt5, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
        flat = tf.reshape(pool2, [-1, 7 * 7 * 32])
    # --------------------------------- fully connected layers ------------------------------------
    with tf.variable_scope('Dense'):
        W = tf.Variable(tf.random_normal([7 * 7 * 32, 128], dtype=tf.float32, stddev=0.1))
        B = tf.Variable(tf.zeros([128]))

        W1 = tf.Variable(tf.random_normal([128, 10], dtype=tf.float32, stddev=0.1))
        B1 = tf.Variable(tf.zeros([10]))

        y0 = tf.nn.relu(tf.matmul(flat, W) + B)
        # Keep the raw logits for the training loss and expose the softmax under a
        # fixed name ('Dense/soft'), which is the output node used at inference time.
        logits = tf.matmul(y0, W1) + B1
        tf.nn.softmax(logits, name='soft')

        return logits
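
To double-check the node naming and the 7x7x32 flatten size assumed above, a minimal sketch along these lines (my own check, not part of the original script; it assumes the ResNet_18 function above is in scope) builds the graph once and prints the relevant shapes:

import tensorflow as tf

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x-input')
logits = ResNet_18(x)
print(logits.shape)  # expect (?, 10)

# The named softmax output that OpenCV will read later
soft = tf.get_default_graph().get_tensor_by_name('Dense/soft:0')
print(soft.name, soft.shape)  # expect Dense/soft:0 (?, 10)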

 

2. Training the model and saving the .pb file

import os

import numpy as np
import tensorflow as tf
from tensorflow.python.tools import freeze_graph

# Example hyperparameters and paths (not given in the original post; adjust as needed)
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99
TRAIN_STEPS = 30000
MODEL_SAVE_PATH = './model'
MODEL_NAME = 'model.ckpt'

def train(mnist):
    x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')

    y = ResNet_18(x)

    global_step = tf.Variable(0, trainable=False)

    variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_average_op = variable_average.apply(
        tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_, 1), logits=y)
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step=global_step,
                                               decay_steps=mnist.train.num_examples / BATCH_SIZE,
                                               decay_rate=LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    with tf.control_dependencies([train_step, variable_average_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.train.write_graph(sess.graph_def, "./model", 'test_graph.pb')
        tf.global_variables_initializer().run()

        for i in range(TRAIN_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            xs = np.reshape(xs, [-1, 28,28,1])
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training steps, loss on training"
                      "batch is %g" % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            tf.train.write_graph(sess.graph_def, './model', 'model.pb')
            freeze_graph.freeze_graph('./model/model.pb', '', False, ckpt.model_checkpoint_path, 'Dense/soft',
                                      '', '', './model/frozen_model.pb',
                                      False, "")

3. Running predictions with the .pb model

Python test code:

import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
from PIL import Image

def main():
    with tf.Session() as sess:
        with gfile.FastGFile("model/frozen_model.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')

        sess.run(tf.global_variables_initializer())
        input_x = sess.graph.get_tensor_by_name('x-input:0')
        op = sess.graph.get_tensor_by_name('Dense/soft:0')
        num = tf.argmax(op, 1)
        # '3.bmp' is expected to be a 28x28 single-channel (grayscale) image
        img = Image.open('3.bmp')
        x_img_array = np.float32(np.asarray(img)) / 255.0
        reshape_img = np.reshape(x_img_array, [-1, 28, 28, 1])
        validate_feed = {input_x: reshape_img}
        y_value,result = sess.run([op,num], feed_dict=validate_feed)
        print(result[0])

main()
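
Before moving on to C++, the same .pb file can also be checked through OpenCV's Python bindings, which exercise the same DNN importer as the C++ code below. A minimal sketch, assuming cv2 was built with the DNN module and that '3.bmp' is a 28x28 grayscale image:

import cv2
import numpy as np

net = cv2.dnn.readNetFromTensorflow('model/frozen_model.pb')
img = cv2.imread('3.bmp', cv2.IMREAD_GRAYSCALE)
# Same preprocessing as the TensorFlow test above: scale pixel values to [0, 1]
blob = cv2.dnn.blobFromImage(img, scalefactor=1.0 / 255.0, size=(28, 28))
net.setInput(blob, 'x-input')
prob = net.forward('Dense/soft')
print('predicted digit:', int(np.argmax(prob)))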

C++ OpenCV test code:

#include "stdafx.h"
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <Windows.h>
#include <fstream>
#include <iostream>
#include <cstdlib>
using namespace cv;
using namespace cv::dnn;
using namespace std;
// Create a txt file with the class labels, one label per line (for MNIST this is simply the digits 0 through 9; for a binary classifier it could be, e.g., "good" on the first line and "bad" on the second)
String labels_txt_file = "G:\\CODE\\Lesson\\lesson_leNet5\\model\\label.txt";
String tf_pb_file = "G:\\CODE\\Lesson\\lesson_ResNet\\model\\frozen_model.pb";
//String tf_pbtxt_file = "G:\\CODE\\Lesson\\lesson_leNet5\\model\\protobuf.pbtxt";
vector <String> readClassNames();
int main()
{
	Mat src = imread("D:\\2.bmp");
	if (src.empty())
	{
		cout << "error: no img" << endl;
		return -1;
	}

	vector <String> labels = readClassNames();
	Mat rgb;
	int w = 28;
	int h = 28;
	resize(src, src, Size(w, h));
	cvtColor(src, rgb, COLOR_BGR2GRAY);
	int n = src.channels();
	Net net = readNetFromTensorflow(tf_pb_file);
	DWORD timestart = GetTickCount();
	if (net.empty())
	{
		cout << "error: no model" << endl;
		return -1;
	}
	Mat inputBlob = blobFromImage(rgb, 0.003921, Size(w, h), Scalar(), false, false);
	// run the classification
	Mat prob;
	net.setInput(inputBlob, "x-input");
	prob = net.forward("Dense/soft");
	cout << prob << endl;
	// find the class with the highest probability
	Mat probMat = prob.reshape(1, 1);
	Point classNumber;
	double classProb;
	minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber);
	DWORD timeend = GetTickCount();
	int classidx = classNumber.x;
	printf("\n current image classification : %s, possible : %.2f\n", labels.at(classidx).c_str(), classProb);
	cout << "用时(毫秒):" << timeend - timestart << endl;
	// 显示文本
	putText(src, labels.at(classidx), Point(20, 20), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255), 2, 8);
	imshow("Image Classfication", src);
	waitKey(0);

}

vector <String>readClassNames()
{
	vector <String>classNames;
	ifstream fp(labels_txt_file);
	if (!fp.is_open())
	{
		cout << "error: could not open label file" << endl;
		exit(-1);
	}
	string name;
	while (getline(fp, name))
	{
		if (name.length())
			classNames.push_back(name);
	}
	fp.close();
	return classNames;
}

 
