此前我们学习了FaceNet源码的使用方法,但主要是基于train_softmax.py的,而中心损失在样本量大于一定值时实验效果会大打折扣,于是我们考虑train_tripletloss.py的一些使用方法。
1. 修改部分
与train_softmax.py略有不同,我需要在
saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)
前添加
all_vars = tf.trainable_variables()
var_to_restore = [v for v in all_vars if not v.name.startswith('Logits')]
saver_restore = tf.train.Saver(var_to_restore)
而不是直接用后者代替前者。
与train_softmax.py类似,我需要修改
print('Restoring pretrained model: %s' % pretrained_model)
saver.restore(sess, pretrained_model)
为
print('Restoring pretrained model: %s' % args.pretrained_model)
model_exp = os.path.expanduser(args.pretrained_model)
_,ckpt_file = facenet.get_model_filenames(model_exp)
saver_restore.restore(sess,os.path.join(model_exp,ckpt_file))
之后我们就可以愉快的训练自己的数据了,另外说一下,为什么train_tripletloss.py的修改方法与train_softmax.py不同,是因为直接按train_softmax.py的修改方法进行训练后,加载模型会提示“模型中没有batch_size”从而导致模型无法正常加载。
2. 训练与预训练部分
首先cd到src目录下
训练时于cmd中输入
python train_tripletloss.py --models_base_dir facenet --model_def models.inception_resnet_v1 --data_dir data\CASIA_WebFace_182 --image_size 160 --optimizer RMSPROP --learning_rate -1 --max_nrof_epochs 20 --keep_probability 0.8 --random_crop --random_flip --learning_rate_schedule_file data\learning_rate_schedule_classifier_casia.txt --weight_decay 5e-5 --alpha 0.1 --gpu_memory_fraction 0.9
有预训练模型时于cmd中输入
python train_tripletloss.py --models_base_dir facenet --pretrained_model pretrained_model\20170512-110547 --model_def models.inception_resnet_v1 --data_dir data\CASIA_WebFace_182 --image_size 160 --optimizer RMSPROP --learning_rate -1 --max_nrof_epochs 20 --keep_probability 0.8 --random_crop --random_flip --learning_rate_schedule_file data\learning_rate_schedule_classifier_casia.txt --weight_decay 5e-5 --alpha 0.1 --gpu_memory_fraction 0.9
即可
3. 基于train_tripletloss后的模型调用compare文件
使用tensorboard观察使用train_tripletloss.py后的模型就会发现,一些节点的命名方式与原本并不相同,于是我们需要做如下修改,将
images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
改为
images_placeholder = tf.get_default_graph().get_tensor_by_name("batch_join:0")
之后我们就可以调用使用三元损失训练好的模型了
首先cd到src目录下,然后于cmd输入
python compare.py 训练好的模型地址 data\test