目录
1. 下载预训练模型
百度mxnet model zoo下载相应的pre-train model:
http://mxnet.incubator.apache.org/model_zoo/index.html
2. 转换数据格式
把数据转为.rec,可参照官方例子的第一块内容:
http://mxnet.incubator.apache.org/how_to/finetune.html
3. 定义数据读取函数
定义数据迭代器生成函数:
def get_fine_tune_model(model_name):
# load model
symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/"+model_name, 0)
# model tuning
all_layers = symbol.get_internals()
if model_name=="vgg16":
net = all_layers['drop7_output']
else:
net = all_layers['flatten0_output']
net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='newfc1')
net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
# eliminate weights of new layer
new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
return (net, new_args,aux_params)
4. 定义模型读取函数
定义pre-train模型读取函数以及模型修改函数
def get_fine_tune_model(model_name):
# load model
symbol, arg_params, aux_params = mx.model.load_checkpoint("data/pre_train/"+model_name, 0)
# model tuning
all_layers = symbol.get_internals()
if model_name=="vgg16":
net = all_layers['drop7_output']
else:
net = all_layers['flatten0_output']