keras faster rcnn中train.py解读

本文详细解读了Keras FasterRCNN中的train.py文件,介绍了如何使用命令行选项进行模型配置,包括训练数据路径、解析器选择、ROI数量、网络类型等。文章还讲解了数据增强、模型构建、权重加载及编译过程,以及训练流程,包括损失计算、预测框筛选和分类器训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

keras faster rcnn中train.py解读

学习笔记,以备注形式,将持续更新。

from __future__ import division
import random
import pprint
import sys
import time
import numpy as np
from optparse import OptionParser #optparse是专门用来在命令行添加选项的一个模块。
import pickle

from keras import backend as K
from keras.optimizers import Adam, SGD, RMSprop
from keras.layers import Input
from keras.models import Model
from keras_frcnn import config, data_generators
from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils

sys.setrecursionlimit(40000) #设置最大深度为4000,作用是防止无限递归导致堆栈溢出崩溃

parser = OptionParser() #optparse是专门用来在命令行添加选项的一个模块。实例化对象parser
#下面的代码通过parser.add_option() 作用于对象parser 
#主要参数有:dest 变量名(字典的键) default 变量的值(字典的值) help 存储帮助信息 
parser.add_option("-p", "--path", dest="train_path", help="Path to training data.") # -p在指令中用于指定训练文件的路径
parser.add_option("-o", "--parser", dest="parser", help="Parser to use. One of simple or pascal_voc",   #simple为一个解析器
				default="pascal_voc")
parser.add_option("-n", "--num_rois", type="int", dest="num_rois", help="Number of RoIs to process at once.", default=32)
parser.add_option("--network", dest="network", help="Base network to use. Supports vgg or resnet50.", default='resnet50') #定义训练网络,默认为resnet50,但是也可以是voc
parser.add_option("--hf", dest="horizontal_flips", help="Augment with horizontal flips in training. (Default=false).", action="store_true", default=False) #在训练中增加水平翻转
parser.add_option("--vf", dest="vertical_flips", help="Augment with vertical flips in training. (Default=false).", action="store_true", default=False) #在训练中增加垂直翻转
parser.add_option("--rot", "--rot_90", dest="rot_90", help="Augment with 90 degree rotations in training. (Default=false).",                  #在训练中增加90度旋转
				  action="store_true", default=False)
parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.", default=20)  #定义训练20轮
parser.add_option("--config_filename", dest="config_filename", help=
				"Location to store all the metadata related to the training (to be used when testing).",
				default="config.pickle")  #存储与训练相关的所有元数据的位置(在测试时使用) config_为配置文件
parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5') #权重的输出路径,在训练结束后,根目录下会生成一个hdf5文件
parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.") #权重的输入路径。 如果未指定,将尝试加载keras提供的默认权重

(options, args) = parser.parse_args() #parser.parse_args() 可以获取键值对组成的字典,列表[]这里没涉及,可以添加的

if not options.train_path:   # if filename is not given
	parser.error('Error: path to training data must be specified. Pass --path to command line')

if options.parser == 'pascal_voc': #采用voc解析器
	from keras_frcnn.pascal_voc_parser import get_data
elif options.parser == 'simple':   #采用simple解析器
	from keras_frcnn.simple_parser import get_data
else:
	raise ValueError("Command line option parser must be one of 'pascal_voc' or 'simple'")

# pass the settings from the command line, and persist them in the config object
#从命令行传递设置,并将它们保留在配置对象中
C = config.Config()

C.use_horizontal_flips = bool(options.horizontal_flips)
C.use_vertical_flips = bool(options.vertical_flips)
C.rot_90 = bool(options.rot_90)

C.model_path = options.output_weight_path
C.num_rois = int(options.num_rois)

if options.network == 'vgg':    #选用的网络有vgg和resnet50两种网络
	C.network = 'vgg'
	from keras_frcnn import vgg as nn
elif options.network == 'resnet50':
	from keras_frcnn import resnet as nn
	C.network = 'resnet50'
else:
	print('Not a valid model')
	raise ValueError


# check if weight path was passed via command line 检查权重路径是否通过命令行传递
if options.input_weight_path:
	C.base_net_weights = options.input_weight_path
else:
	# set the path to weights based on backend and model 根据后端和模型设置权重路径
	C.base_net_weights = nn.get_weight_path()

all_imgs, classes_count, class_mapping = get_data(options.train_path) #所有的图片,分类数量,位置信息都通过训练路径进行导入

if 'bg' not in classes_count: #如果分类中没有背景这个类别,则加入背景这个类别
	classes_count['bg'] = 0
	class_mapping['bg'] = len(class_mapping)

C.class_mapping = class_mapping

inv_map = {v: k for k, v in class_mapping.items()} #将class_mapping中的key和value对调

print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))

config_output_filename = options.config_filename

with open(config_output_filename, 'wb') as config_f:
	pickle.dump(C,config_f)
	print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(config_output_filename))

random.shuffle(all_imgs) #Shuffle描述着数据从map task输出到reduce task输入的这段过程。shuffle是连接Map和Reduce之间的桥梁

num_imgs = len(all_imgs)
# 将all_imgs分为训练集和测试集
train_imgs = [s for s in all_imgs if s['imageset'] == 'trainval']
val_imgs = [s for s in all_imgs if s['imageset'] == 'test']

print('Num train samples {}'.format(len(train_imgs)))
print('Num val samples {}'.format(len(val_imgs)))

# 生成anchor
data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='train')
data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, C, nn.get_img_output_length,K.image_dim_ordering(), mode='val')
#查看后端是th还是tf,纠正输入方式
if K.image_dim_ordering() == 'th':
	input_shape_img = (3, None, None)
else:
	input_shape_img = (None, None, 3)

img_input = Input(shape=input_shape_img)
roi_input = Input(shape=(None, 4))

# define the base network (resnet here, can be VGG, Inception, etc
#定义基础网络(在这是resnet,也可以是VGG,inception等)
shared_layers = nn.nn_base(img_input, trainable=True)

# define the RPN, built on the base layers基于基础层构建RPN网络
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios) #获取anchor的个数,3种面积大小,3种比例的框,(1:1  1:2   2:1)
rpn = nn.rpn(shared_layers, num_anchors)#定义RPN层,return [x_class, x_regr, base_layers](rpn层做一次分类,区分是前景还是背景。也作回归,选择合适的框)

classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count), trainable=True)
#定义rpn模型的输入和输出一个框2分类(最后使用的sigmod而不是softmax)和框的回归
model_rpn = Model(img_input, rpn[:2])
model_classifier = Model([img_input, roi_input], classifier) ##定义classifier的输入和输出

# this is a model that holds both the RPN and the classifier, used to load/save weights for the models
#这是一个包含RPN和分类器的模型,用于加载/保存模型的权重
model_all = Model([img_input, roi_input], rpn[:2] + classifier)
####加载预训练模型参数####
try:
	print('loading weights from {}'.format(C.base_net_weights))
	model_rpn.load_weights(C.base_net_weights, by_name=True)
	model_classifier.load_weights(C.base_net_weights, by_name=True)
except:
	print('Could not load pretrained model weights. Weights can be found in the keras application folder \
		https://github.com/fchollet/keras/tree/master/keras/applications')
###编译模型###
optimizer = Adam(lr=1e-5)
optimizer_classifier = Adam(lr=1e-5)
model_rpn.compile(optimizer=optimizer, loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)]) #输出RPN网络对应的两个loss值
model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls, losses.class_loss_regr(len(classes_count)-1)], metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
model_all.compile(optimizer='sgd', loss='mae') #model_all的建立和编译是为了最后方便保存整体的权重
####训练#######
epoch_length = 1000    #1000张图片
num_epochs = int(options.num_epochs)    #原函数定义为2000轮。
iter_num = 0

losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()

best_loss = np.Inf

class_mapping_inv = {v: k for k, v in class_mapping.items()}
print('Starting training')

vis = True

for epoch_num in range(num_epochs): #部分训练参数配置
    #每一个epoch(默认共2000个epochs)都进行下面的操作。
	progbar = generic_utils.Progbar(epoch_length)
	print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

	while True:
		try:
            #生成当前epoch的进度条,输出提示信息
			if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
				mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
				rpn_accuracy_rpn_monitor = []
				print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
				if mean_overlapping_bboxes == 0:
					print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
            #利用生成器生成数据。X、Y、img_data分别对应yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug中的三个值
			X, Y, img_data = next(data_gen_train) 

			loss_rpn = model_rpn.train_on_batch(X, Y)

			P_rpn = model_rpn.predict_on_batch(X)

			R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
			# note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
			#X2是存selected predict boxes 坐标信息[x1, y1, w, h]。shape: (1, selected_boxes_num, 4).
			#Y1是对应selected box的物体类别标签
			#Y2的shape: (1, selected_boxes_num, 160)。 (selected_boxes_num, :80)表示每个物体类的回归参数
			X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
            #calc_iou()未筛选出满足条件的predicted box,就会返回None。将当前记录RPN网络的准确度记为0(epoch length最大设为1000,rpn_accuracy_rpn_monitor列表的最大长度也是1000),continue进行下一张图片的训练。
			if X2 is None:
				rpn_accuracy_rpn_monitor.append(0)
				rpn_accuracy_for_epoch.append(0)
				continue

			neg_samples = np.where(Y1[0, :, -1] == 1)  #Y1[0, :, -1]表示取最后一列(表示背景)
			pos_samples = np.where(Y1[0, :, -1] == 0)

			if len(neg_samples) > 0:
				neg_samples = neg_samples[0]
			else:
				neg_samples = []

			if len(pos_samples) > 0:
				pos_samples = pos_samples[0]
			else:
				pos_samples = []
			
			rpn_accuracy_rpn_monitor.append(len(pos_samples))
			rpn_accuracy_for_epoch.append((len(pos_samples)))

			if C.num_rois > 1:
				if len(pos_samples) < C.num_rois//2:
					selected_pos_samples = pos_samples.tolist()
				else:
					selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
				try:
					selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
				except:
					selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()

				sel_samples = selected_pos_samples + selected_neg_samples
			else:
				# in the extreme case where num_rois = 1, we pick a random pos or neg sample
				selected_pos_samples = pos_samples.tolist()
				selected_neg_samples = neg_samples.tolist()
				if np.random.randint(0, 2):
					sel_samples = random.choice(neg_samples)
				else:
					sel_samples = random.choice(pos_samples)

			loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

			losses[iter_num, 0] = loss_rpn[1]
			losses[iter_num, 1] = loss_rpn[2]

			losses[iter_num, 2] = loss_class[1]
			losses[iter_num, 3] = loss_class[2]
			losses[iter_num, 4] = loss_class[3]

			iter_num += 1

			progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
									  ('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])
              #现在RPN网络和Classifier网络都各训练完一次,这里统计他们的loss并更新进度条显示。
			if iter_num == epoch_length:
				loss_rpn_cls = np.mean(losses[:, 0])
				loss_rpn_regr = np.mean(losses[:, 1])
				loss_class_cls = np.mean(losses[:, 2])
				loss_class_regr = np.mean(losses[:, 3])
				class_acc = np.mean(losses[:, 4])

				mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
				rpn_accuracy_for_epoch = []

				if C.verbose:
					print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
					print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
					print('Loss RPN classifier: {}'.format(loss_rpn_cls))
					print('Loss RPN regression: {}'.format(loss_rpn_regr))
					print('Loss Detector classifier: {}'.format(loss_class_cls))
					print('Loss Detector regression: {}'.format(loss_class_regr))
					print('Elapsed time: {}'.format(time.time() - start_time))

				curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
				iter_num = 0
				start_time = time.time()

				if curr_loss < best_loss:
					if C.verbose:
						print('Total loss decreased from {} to {}, saving weights'.format(best_loss,curr_loss))
					best_loss = curr_loss
					model_all.save_weights(C.model_path) #保存模型权重到指定路径,类型是HDF5(.h5)

				break

		except Exception as e:
			print('Exception: {}'.format(e))
			continue

print('Training complete, exiting.')
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小俊俊的博客

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值