1.数据集标注
ssd训练自己的模型
参考 https://blog.csdn.net/u014696921/article/details/53353896
2.用别的模型进行微调,并根据自己的数据类别调整参数
如果仅仅调整程序参数,这时加载预训练模型会出错(重新开始训练则不会报错)。这是因为预训练模型的类别数与调整后的类别数不一样,导致某些层的张量维度不一致,因此出错。修正方法有两种:
方法一:
不加载这些层的参数:
# Skip the conv_cls weights and biases of block11 down to block1 when
# restoring from a checkpoint. The default is the comma-separated list of
# those variable scopes, biases before weights for each block.
tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes',
    ','.join(
        'ssd_300_vgg/block%d_box/conv_cls/%s' % (block, param)
        for block in range(11, 0, -1)
        for param in ('biases', 'weights')),
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')
方法二:
修改预训练模型(checkpoint)中的参数,使其张量维度与新模型保持一致:
import os
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
def readcheckpoint(model_dir="../checkpoints/ssd_300_vgg.ckpt"):
    """Print the name and shape of every tensor stored in a checkpoint.

    Args:
        model_dir: Location of the checkpoint file to inspect. It is passed
            unchanged to ``pywrap_tensorflow.NewCheckpointReader``.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key, shape in var_to_shape_map.items():
        print("tensor_name: ", key)
        # The shape map already carries every shape, so the tensor data does
        # not need to be loaded just to report it. tuple() keeps the printed
        # form identical to a NumPy ``.shape``.
        print(tuple(shape))
        # To inspect the values as well, print reader.get_tensor(key) here.
def savecheckpoint(ckpt_path="../checkpoints/ssd_300_vgg.ckpt",
                   output_path='./test.ckpt',
                   num_classes=15):
    """Copy a checkpoint, truncating its conv_cls weights and biases.

    Every variable found in ``ckpt_path`` is loaded and re-created under its
    original name. Variables whose name contains ``_box/conv_cls/biases`` or
    ``conv_cls/weights`` are cut along their last axis to the first
    ``num_classes * anchors`` entries, where ``anchors`` is 6 for block7,
    block8 and block9 and 4 for every other block. All variables are then
    saved to ``output_path``.

    Args:
        ckpt_path: Checkpoint to read the variables from.
        output_path: Path the modified checkpoint is written to.
        num_classes: Number of classes the truncated layers are sized for
            (presumably including the background class -- TODO confirm
            against the training configuration).
    """
    six_anchor_blocks = ('block7', 'block8', 'block9')

    def _anchors(name, kind):
        # Number of anchors used to size the conv_cls layer called `name`.
        for block in six_anchor_blocks:
            if '%s_box/conv_cls/%s' % (block, kind) in name:
                return 6
        return 4

    # Build the variables in a private graph so that repeated calls, or
    # variables already present in the default graph, cannot end up in the
    # saved checkpoint or force TensorFlow to uniquify the variable names.
    with tf.Graph().as_default(), tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(ckpt_path):
            # Load the variable value from the source checkpoint.
            var = tf.contrib.framework.load_variable(ckpt_path, var_name)
            # The name is kept unchanged; this is the place to rename.
            new_name = var_name
            print('Renaming %s to %s.' % (var_name, new_name))
            # Biases are sliced on their only axis, weights on their last
            # (output channel) axis; `...` covers both cases.
            # NOTE(review): only the leading channels are kept. If the layer
            # orders its channels anchor-major, they are not "the first
            # num_classes classes of each anchor" -- confirm this is
            # acceptable for fine-tuning.
            for marker, kind in (('_box/conv_cls/biases', 'biases'),
                                 ('conv_cls/weights', 'weights')):
                if marker in new_name:
                    var = var[..., :num_classes * _anchors(new_name, kind)]
            # Register the (possibly truncated) value in the graph so the
            # Saver below picks it up.
            tf.Variable(var, name=new_name)
        # Save the variables.
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess, output_path)
# Script entry: write the truncated copy of the checkpoint.
savecheckpoint()
# Uncomment to list the tensor names and shapes stored in a checkpoint:
# readcheckpoint()