TensorFlow 预训练权重导入
1. 代码
def load(data_path, session):
    """Load pretrained VGG16 parameters from a ``.npy`` file into the graph.

    The file is read with ``np.load(...).item()``, so it must contain a
    single pickled dict (as written by ``np.save`` on a dict). For every
    top-level key, in sorted order, the variables ``weights`` and ``biases``
    under the variable scope of the same name are overwritten with the
    stored arrays.

    :param data_path: path to the ``.npy`` parameter file.
    :param session: TensorFlow session used to run the assignment ops.
    :return: None
    :raises ValueError: propagated from ``tf.get_variable`` (called with
        ``reuse=True``) when a matching variable has not been created in
        the current graph.
    """
    # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
    # objects, which can execute code; only load parameter files from
    # trusted sources.
    data_dict = np.load(data_path, encoding='latin1', allow_pickle=True).item()

    assign_ops = []
    for key in sorted(data_dict):
        # reuse=True: look up existing variables instead of creating new ones.
        with tf.variable_scope(key, reuse=True):
            # zip pairs names with stored arrays by position: the first entry
            # goes to 'weights', the second to 'biases', and any extra
            # entries are ignored. Presumably data_dict[key] is a sequence
            # such as [weights, biases] -- TODO confirm against the file.
            for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                assign_ops.append(tf.get_variable(subkey).assign(data))

    # Run every assignment in one session call rather than one round trip
    # per variable; the ops touch distinct variables, so order is irrelevant.
    if assign_ops:
        session.run(assign_ops)
def load_with_skip(data_path, session, skip_layer