这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF。
1. 背景交代
Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。我们已经在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中对该文件进行了源码解析,现在来解析一下该文件中用到的函数。
Faster-RCNN_TF/tools/train_net.py中用到的函数有以下几个:
def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义
2. 下面来分析一下上面每个函数的源码
2.1. def parse_args():
该函数在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中已经进行了解析。
2.2. get_imdb():
作用:加载训练数据。
定义位置:该函数在Faster-RCNN/lib/datasetes/factory.py中被定义。
文件factory.py是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用。
factory.py 源码如下:
# coding=utf-8 #有中文注释的时候,记得加上这个
# --------------------------------------------------------