Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数

本文深入解析Faster-RCNN Tensorflow版本的train_net.py中关键函数,包括parse_args(), get_imdb(), get_training_roidb()和get_output_dir()。这些函数分别负责参数解析、训练数据加载、数据转换为minibatch和模型保存路径设置。通过理解这些函数,有助于更好地理解和使用Faster-RCNN训练流程。" 137841724,22793133,Redis 中存储与操作 JSON 数据,"['Redis', 'JSON', '数据库']

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF

1. 背景交代

Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。我们已经在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中对该文件进行了源码解析,现在来解析一下该文件中用到的函数。

Faster-RCNN_TF/tools/train_net.py中用到的函数有以下几个:

def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义

2. 下面来分析一下上面每个函数的源码
2.1. def parse_args():

该函数在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中已经进行了解析。

2.2. get_imdb():

作用:加载训练数据。
定义位置:该函数在Faster-RCNN/lib/datasetes/factory.py中被定义。
文件factory.py是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用。
factory.py 源码如下:

# coding=utf-8 #有中文注释的时候,记得加上这个
# --------------------------------------------------------
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值