tensorflow fine_tune已训练模型的部分参数

本文详细介绍了在TensorFlow框架下如何实现迁移学习,包括加载预训练模型的部分参数到新网络,处理不同namescope下的参数加载,以及如何固定加载进来的参数。提供了具体代码示例,帮助读者理解并实践迁移学习。

由于很多时候我们在一个新的网络中只会用到一个已训练模型的部分参数,即迁移学习。
那么,如何加载已训练模型的部分参数到当前网络。

一、当前网络加载已训练模型相同name scope的变量

方法1. 手动构建与预训练一样的部分图

将需要fine tune的变量的name scope命名为与模型中的name scope相同,然后使用如下代码将模型参数加载到当前网络。

tf.train.Saver([var for var in tf.global_variables() if var.name.startswith('train')]) \
            .restore(sess,' D:\\tuxiang\\hhh\\my_test_model-1')

注意:相同的name scope的变量的值是一样的,因此加载模型的参数就会覆盖已初始化的参数,实现fine tune的目的。

二、当前网络与已训练模型的name scope不一致的情况

例如:将训练模型name scope ‘a’的值赋给当前网络的name scope ‘m/b/t’。

方法1:改写已保存模型的name scope,使其与目标变量的name scope一致;或者将需要加载到当前网络的参数选取出来写入一个新的模型,然后直接加载就ok.

name scope改写代码,也可进行name scope删选,剔除不需要加载的变量。

import os
import tensorflow as tf
import numpy as np

# 新模型的存储地址
new_checkpoint_path = 'D:\\tuxiang\\hhh\\h\\'
# 旧模型的存储地址
checkpoint_path = 'D:\\tuxiang\\hhh\\my_test_model-1'
# 添加的name scope
add_prefix  = 'main/'
if not os.path.exists(new_checkpoint_path):
    os.makedirs(new_checkpoint_path)
with tf.Session() as sess:
    new_var_list = []  # 新建一个空列表存储更新后的Variable变量
    for var_name, _ in tf.contrib.framework.list_variables(checkpoint_path):  # 得到checkpoint文件中所有的参数(名字,形状)元组
        var = tf.contrib.framework.load_variable(checkpoint_path, var_name)  # 得到上述参数的值
        # var_name为变量的name scope,是一个字符串,可以进行改写
        # var 是该name scope对应的值
        print(var_name,var)
        new_name = var_name
        new_name = add_prefix + new_name  # 在这里加入了名称前缀,大家可以自由地作修改
        # 除了修改参数名称,还可以修改参数值(var)
        print('Renaming %s to %s.' % (var_name, new_name))
        renamed_var = tf.Variable(var, name=new_name)  # 使用加入前缀的新名称重新构造了参数
        new_var_list.append(renamed_var)  # 把赋予新名称的参数加入空列表
    print('starting to write new checkpoint !')
    saver = tf.train.Saver(var_list=new_var_list)  # 构造一个保存器
    sess.run(tf.global_variables_initializer())  # 初始化一下参数(这一步必做)
    model_name = 'deeplab_resnet_altered'  # 构造一个保存的模型名称
    checkpoint_path = os.path.join(new_checkpoint_path, model_name)  # 构造一下保存路径
    saver.save(sess, checkpoint_path)  # 直接进行保存
    print("done !")
方法2:直接通过tf.train.import_meta_graph()和saver.restore()将模型的所有参数加载到当前图中,然后再使用 sess.run(name scope)和sess.run(tf.assign())取出模型所在的name scope的数值赋给网络中需要fine tune的变量.

使用此方法会将预训练模型的所有参数和图加载进来并在保存的时候与当前网络一起保存,使参数更加庞大,必须在训练结束后定义需要保存的变量,避免保存所有参数。

with tf.variable_scope('train/x'):
    w1 = tf.get_variable('w1', shape = [2])  
    w2 = tf.get_variable( name='w2',shape=[2])  
    w3 = tf.get_variable( name='w3',shape=[2])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) 
    print(sess.run('train/x/w3:0'))
    saver =tf.train.import_meta_graph('D:\\tuxiang\\hhh\\my_test_model-1.meta')
    saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
    data = sess.run('train/w1:0')
    
    print(data)  
    sess.run(tf.assign(w1,data))
    print(sess.run('train/x/w1:0'))

保存的参数为:

train/w1 [-0.22390778  1.227742  ]
train/w2 [-0.8072208  -0.2457586  -1.0661628   0.65901166  0.4100374 ]
train/x/w1 [-0.22390778  1.227742  ]
train/x/w2 [ 0.9039186  -0.33169258]
train/x/w3 [-0.88651645  1.0805331 ]

通过在声明tf.train.Saver类时可以提供一个变量列表来指定需要保存或加载的变量可解决上述问题。

new idea

其实,此方法可以直接通过以下代码来获得所有参数名称,然后选取我们需要fine tune的参数名称并提取相应的值,通过sess.run(tf.assign())的方法便可fine tune,由于没导入预训练模型的图,所以不存在上述问题。

import os
import tensorflow as tf
import numpy as np
checkpoint_path = 'D:\\tuxiang\\hhh\\hh1\\my_test_model'
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_path):  
# 得到checkpoint文件中所有的参数(名字,形状)元组
    var = tf.contrib.framework.load_variable(checkpoint_path, var_name)  # 得到上述参数的值
    print(var_name,var)

三、固定加载进来的参数

对于某些fine tune的参数,如果希望将这些参数固定住,即不训练,可以通过在定义变量的时候设置trainable=False;也可以在训练过程中的minimize()中添加需要训练的参数,则未添加的参数会固定住。

用var_list = tf.contrib.framework.get_variables(scope_name)获取指定scope_name下的变量,
然后optimizer.minimize()时传入指定var_list即可。

train_op = tf.train.GradientDescentOptimizer.minimize(loss,var_list=var_list)

参考文献:

1.tensorflow 选择性fine-tune(微调)
2.输出TensorFlow 模型的变量名和对应值
3.tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)
4.Tensorflow Finetune方法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

还是少年呀

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值