随手记录一下,还有很多看不明白的地方,并且把源码中的vgg改成了resnet。如果有错误还请指正。
GitHub地址:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3
直接从训练开始看起,train.py
if __name__ == "__main__":
    # Script entry point: build the trainer, then start its training run.
    train = Train()
    train.train()
然后跳转到
class Train:
    """Training driver; only the constructor appears in this excerpt."""

    def __init__(self):
        """Build the network, load the dataset and resolve the output directory.

        Raises:
            NotImplementedError: if ``cfg.FLAGS.network`` is not 'resnet_v1'.
        """
        # Only the ResNet v1 backbone is wired up here; reject anything else
        # before any further work is done.
        if cfg.FLAGS.network != 'resnet_v1':
            raise NotImplementedError

        # Create network. The author's note lists 50 and 101 as the
        # num_layers values to choose from.
        self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch, num_layers=101)

        # Image database plus its ROI database for "voc_2007_trainval".
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
首先是
self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch, num_layers=101)
跳到lib/nets/resnet_v1.py
class resnetv1(Network):
    """ResNet v1 variant of Network; only the constructor appears in this excerpt."""

    def __init__(self, batch_size=1, num_layers=101):
        """Initialise the base Network state and record the requested depth.

        Args:
            batch_size: forwarded unchanged to ``Network.__init__``.
            num_layers: depth of the ResNet; also embedded in the scope name.
        """
        super().__init__(batch_size=batch_size)
        self._num_layers = num_layers
        # The scope name carries the depth, e.g. 'resnet_v1_101'.
        self._resnet_scope = 'resnet_v1_%d' % num_layers
其中的
Network.__init__(self, batch_size=batch_size)
跳转到lib/nets/network.py
class Network(object):
    """Base network; this excerpt shows only the state set up by the constructor."""

    def __init__(self, batch_size=1):
        """Store the batch size and create empty bookkeeping containers.

        Args:
            batch_size: value stored as ``self._batch_size``.
        """
        # Single-element lists holding a stride of 16 and its reciprocal.
        self._feat_stride = [16]
        self._feat_compress = [1.0 / 16.0]
        self._batch_size = batch_size

        # Every container below starts empty; each instance gets its own
        # fresh objects.
        self._predictions = dict()
        self._losses = dict()
        self._anchor_targets = dict()
        self._proposal_targets = dict()
        self._layers = dict()
        self._act_summaries = list()
        self._score_summaries = dict()
        self._train_summaries = list()
        self._event_summaries = dict()
        self._variables_to_fix = dict()
回到train.py,下一行是
self.imdb, self.roidb = combined_roidb("voc_2007_trainval")
跳转到combined_roidb函数。这里的imdb_names是图像数据库(数据集)的名字,也就是训练数据的文件夹名字;多个数据集的名字可以用"+"连接,函数内部会按"+"拆分后逐个加载。
def combined_roidb(imdb_names):
    """Load and merge the training roidbs of one or more datasets.

    Args:
        imdb_names: one dataset name, or several names joined with '+'.

    Returns:
        A tuple ``(imdb, roidb)``. With a single name, ``imdb`` is
        ``get_imdb(imdb_names)``. With several names, ``imdb`` is built by
        ``imdb2`` from the full name string and the classes of the second
        listed dataset, and ``roidb`` is the first dataset's roidb extended
        in place with the others.
    """

    def _load_training_roidb(dataset_name):
        # Load one dataset, select the "gt" proposal method, and return its
        # training roidb.
        dataset = get_imdb(dataset_name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method("gt")
        print('Set proposal method: {:s}'.format("gt"))
        return get_training_roidb(dataset)

    dataset_names = imdb_names.split('+')
    per_dataset = [_load_training_roidb(n) for n in dataset_names]
    merged = per_dataset[0]

    # Single dataset: nothing to merge.
    if len(per_dataset) == 1:
        return get_imdb(imdb_names), merged

    for extra in per_dataset[1:]:
        merged.extend(extra)
    # The class list for the combined imdb comes from the second dataset.
    second = get_imdb(dataset_names[1])
    return imdb2(imdb_names, second.classes), merged
其中get_roidb(imdb_name)的第一行get_imdb(imdb_name)可以跳转到lib/datasets/factory.py
def get_imdb(name):
    """Return a freshly constructed imdb (image database) for *name*.

    Looks *name* up in the module-level ``__sets`` registry and calls the
    stored factory.

    Raises:
        KeyError: if *name* is not registered.
    """
    if name in __sets:
        return __sets[name]()
    raise KeyError('Unknown dataset: {}'.format(name))
回到get_roidb,其中的两条print语句会输出两句
Loaded dataset `voc_2007_trainval` for training
Set proposal method: gt
下一行是
roidb = get_training_roidb(imdb)
跳转到get_training_roidb函数,它先添加水平翻转的图片,再准备训练数据,最后返回roidb
def get_training_roidb(imdb):
    """Return the imdb's roidb (Region of Interest database) ready for training.

    Appends horizontally flipped examples to *imdb*, runs
    ``rdl_roidb.prepare_roidb`` on it, and returns ``imdb.roidb``.
    Progress messages are printed along the way.
    """
    # Flipping always happens: the original guard was a literal `if True`.
    print('Appending horizontally-flipped training examples...')
    imdb.append_flipped_images()
    print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb
跳到imdb.py的append_flipped_images()中
def append_flipped_images(self):
    """Append a horizontally mirrored copy of every existing roidb entry.

    For each of the first ``self.num_images`` entries, the x-coordinates of
    its boxes (columns 0 and 2) are mirrored using the image width from
    ``self._get_widths()``. A new entry marked ``'flipped': True`` is
    appended to ``self.roidb``; it shares the source entry's 'gt_overlaps'
    and 'gt_classes' objects. Finally ``self._image_index`` is doubled.
    """
    original_count = self.num_images  # number of training images before flipping
    image_widths = self._get_widths()
    for idx in range(original_count):
        source = self.roidb[idx]
        mirrored = source['boxes'].copy()
        left = mirrored[:, 0].copy()
        right = mirrored[:, 2].copy()
        # Mirror about the vertical axis: old x2 yields new x1, old x1 new x2.
        mirrored[:, 0] = image_widths[idx] - right - 1
        mirrored[:, 2] = image_widths[idx] - left - 1
        # Author's workaround: guards against a top-left x or y of 0 that
        # became 65535 after 1 was subtracted from it. Any row whose x2 ended
        # up left of its x1 gets its x1 reset to 0.
        inverted = mirrored[:, 2] < mirrored[:, 0]
        mirrored[inverted, 0] = 0
        assert (mirrored[:, 2] >= mirrored[:, 0]).all()
        self.roidb.append({'boxes': mirrored,
                           'gt_overlaps': source['gt_overlaps'],
                           'gt_classes': source['gt_classes'],
                           'flipped': True})
    # Author's open question: "doesn't this number conflict with the code
    # below?" The list is doubled here, matching the doubled roidb length.
    self._image_index = self._image_index * 2
接下来是rdl_roidb.prepare_roidb(imdb),到了我不太明白的地方
def prepare_roidb(imdb):
    """Enrich the imdb's roidb with derived quantities useful for training.

    For every index in ``imdb.image_index`` the matching roidb entry gains:

    * 'image': the path returned by ``imdb.image_path_at(i)``;
    * 'width' / 'height': the image size read with PIL, skipped when the
      imdb's name starts with 'coco';
    * 'max_overlaps': the per-row maximum of the dense 'gt_overlaps';
    * 'max_classes': the per-row argmax (column index) of 'gt_overlaps'.

    The entries of ``imdb.roidb`` are modified in place; nothing is returned.

    Raises:
        AssertionError: if a row with zero max overlap is not class 0, or a
            row with positive max overlap is class 0.
    """
    roidb = imdb.roidb
    needs_sizes = not imdb.name.startswith('coco')
    if needs_sizes:
        # Read each image's (width, height) up front. The context manager
        # closes each file as soon as its size is known, so handles are not
        # left for the garbage collector.
        sizes = []
        for i in range(imdb.num_images):
            with PIL.Image.open(imdb.image_path_at(i)) as img:
                sizes.append(img.size)
    for i in range(len(imdb.image_index)):
        entry = roidb[i]
        # Author's open question: "the 'image' entry is added here, but isn't
        # image_index too short?" This loop covers len(imdb.image_index)
        # entries; presumably that is the list doubled by
        # append_flipped_images -- TODO confirm.
        entry['image'] = imdb.image_path_at(i)
        if needs_sizes:
            entry['width'] = sizes[i][0]
            entry['height'] = sizes[i][1]
        # need gt_overlaps as a dense array for argmax
        gt_overlaps = entry['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        entry['max_classes'] = max_classes
        entry['max_overlaps'] = max_overlaps
        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)