转换darknet yolov3.weights为pb模型并测试pb模型

1.去giithub下载项目:https://github.com/YunYang1994/tensorflow-yolov3.git,解压成功后使用编译器打开为以下目录结构
在这里插入图片描述

2.编辑from_darknet_weights_to_ckpt.py脚本,将load_weights函数以外的部分注释掉。

import tensorflow as tf
from core.yolov3 import YOLOV3
import numpy as np

# iput_size = 416
# darknet_weights = 'yolov3.weights'
# ckpt_file = './checkpoint/yolov3_coco.ckpt'

def load_weights(var_list, weights_file):
    """
    Loads and converts pre-trained weights.
    :param var_list: list of network variables.
    :param weights_file: name of the binary file.
    :return: list of assign ops
    """
    with open(weights_file, "rb") as fp:
        _ = np.fromfile(fp, dtype=np.int32, count=5)
        weights = np.fromfile(fp, dtype=np.float32)  # np.ndarray
    print('weights_num:', weights.shape[0])
    ptr = 0
    i = 0
    assign_ops = []
    while i < len(var_list) - 1:
        var1 = var_list[i]
        var2 = var_list[i + 1]
        # do something only if we process conv layer
        if 'conv' in var1.name.split('/')[-2]:
            # check type of next layer
            if 'batch_normalization' in var2.name.split('/')[-2]:
                # load batch norm params
                gamma, beta, mean, var = var_list[i + 1:i + 5]
                batch_norm_vars = [beta, gamma, mean, var]
                for vari in batch_norm_vars:
                    shape = vari.shape.as_list()
                    num_params = np.prod(shape)
                    vari_weights = weights[ptr:ptr + num_params].reshape(shape)
                    ptr += num_params
                    assign_ops.append(
                        tf.assign(vari, vari_weights, validate_shape=True))
                i += 4
            elif 'conv' in var2.name.split('/')[-2]:
                # load biases
                bias = var2
                bias_shape = bias.shape.as_list()
                bias_params = np.prod(bias_shape)
                bias_weights = weights[ptr:ptr +
                                           bias_params].reshape(bias_shape)
                ptr += bias_params
                assign_ops.append(
                    tf.assign(bias, bias_weights, validate_shape=True))
                i += 1
            shape = var1.shape.as_list()
            num_params = np.prod(shape)

            var_weights = weights[ptr:ptr + num_params].reshape(
                (shape[3], shape[2], shape[0], shape[1]))
            # remember to transpose to column-major
            var_weights = np.transpose(var_weights, (2, 3, 1, 0))
            ptr += num_params
            assign_ops.append(
                tf.assign(var1, var_weights, validate_shape=True))
            i += 1
    print('ptr:', ptr)
    return assign_ops
    
# with tf.name_scope('input'):
#     input_data = tf.placeholder(dtype=tf.float32,shape=(None, iput_size, iput_size, 3), name='input_data')
# model = YOLOV3(input_data, trainable=False)
# load_ops = load_weights(tf.global_variables(), darknet_weights)
#
# saver = tf.train.Saver(tf.global_variables())
#
# with tf.Session() as sess:
#     sess.run(load_ops)
#     save_path = saver.save(sess, save_path=ckpt_file)
#     print('Model saved in path: {}'.format(save_path))

3.修改from_darknet_weights_to_pb.py脚本,并运行脚本会生成yolov3.pb文件。

import tensorflow as tf
from core.yolov3 import YOLOV3
from from_darknet_weights_to_ckpt import load_weights

input_size = 416
darknet_weights = 'yolov3.weights'      #将此处修改为自己的darknet yolov3.weights路径
pb_file = './yolov3.pb'
output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]

with tf.name_scope('input'):
    input_data = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')
model = YOLOV3(input_data, trainable=False)
load_ops = load_weights(tf.global_variables(), darknet_weights)

with tf.Session() as sess:
    sess.run(load_ops)
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        output_node_names=output_node_names
    )

    with tf.gfile.GFile(pb_file, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("{} ops written to {}.".format(len(output_graph_def.node), pb_file))

4.修改image_demo.py文件中pb_file为自己生成的pb模型路径,然后运行脚本。(video_demo.py文件与其操作相同)
在这里插入图片描述

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值