tensorflow2解析模型参数

本文详细介绍如何在TensorFlow2中构建与保存CNN模型,并演示如何加载已保存的模型,解析其参数,适用于模型剪枝、量化等后续操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tensorflow2 解析模型参数

introduction

CNN模型是计算机视觉里面常用的工具了,训模师训好模型可能还需要其他的操作,比如可能做剪枝,或者量化,需要对模型的参数做一些操作。这时候就需要解析模型的参数了。这篇文章主要叙述一下,在tensorflow2 下怎么解析模型参数。

step by step

1. 先构建并保存一个模型

tensorflow2 构建模型还是首选keras接口, 在模型保存方面有好几个接口可以选择,触类旁通, 此处我习惯使用tf.save_model。

class PlainCNN(tf.keras.Model):
    def __init__(self,
                 kernel_initializer='glorot_normal'):
        super(PlainCNN, self).__init__()


        self.conv1= tf.keras.layers.Conv2D( filters=32,
                                           kernel_size=(3, 3),
                                           strides=2,
                                           padding='same',
                                           use_bias=False,
                                           kernel_initializer=kernel_initializer)

        self.conv2 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding='same',
                                            use_bias=False,
                                            kernel_initializer=kernel_initializer)

        self.conv3 = tf.keras.layers.Conv2D(filters=128,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding='same',
                                            use_bias=False,
                                            kernel_initializer=kernel_initializer)
    @tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
    def call(self, inputs, training=False):
        x1 = self.conv1(inputs)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        return x3
import numpy as np
model = PlainCNN()
image=np.zeros(shape=(1,48,48,3),dtype=np.float32)
x=model(image)

# 将会打印出模型的信息:
# the result shape is : (1, 6, 6, 128)
print('the result shape is :', x.shape)
##保存模型为 tmp_model
tf.saved_model.save(model,'tmp_model')

此时我们有了tmp_model 模型目录

2. 加载模型
model=tf.saved_model.load('./tmp_model')
###这时model可以理解为上面的代码,可以直接inference, 也可以依次获取每个变量的variables,####或者trainable_variables,等等
print(model.conv1.variables)

输出实例
‘ListWrapper([<tf.Variable ‘conv2d/kernel:0’ shape=(3, 3, 3, 32) dtype=float32, numpy=
array([[[[ 8.40592012e-02, -1.82091370e-02, 6.67047426e-02,
-1.57765336e-02, 5.53220101e-02, 6.69927672e-02,
1.18893180e-02, -3.55695821e-02, 2.96269413e-02,
6.24449886e-02, 8.93794820e-02, 1.48759335e-01,
-6.80063153e-03, -2.44185757e-02, -1.68685019e-01,


然后就可以想做什么就做什么了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值