Faster R-CNN 读源码(tensorflow版)

随手记录一下,还有很多看不明白的地方,并且把源码中的vgg改成了resnet。如果有错误还请指正。

GitHub地址:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3

 

直接从训练开始看起,train.py

# Script entry point: build the Train object and run its training routine.
if __name__ == '__main__':
    train = Train()
    train.train()

然后跳转到Train类的__init__

class Train:
    """Training driver: creates the network and loads the training data."""

    def __init__(self):
        # Only the 'resnet_v1' network is supported; reject anything else
        # before any further setup happens.
        if cfg.FLAGS.network != 'resnet_v1':
            raise NotImplementedError

        # Create network. The original note lists 50 and 101 as the
        # num_layers choices.
        self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch, num_layers=101)

        # Image database and the matching region-of-interest database.
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

首先是

self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch, num_layers=101)

跳到lib/nets/resnet_v1.py

class resnetv1(Network):
  """Network subclass parameterised by the ResNet-v1 layer count."""

  def __init__(self, batch_size=1, num_layers=101):
    # Initialise the state kept by the Network base class first.
    Network.__init__(self, batch_size=batch_size)
    self._num_layers = num_layers
    # The scope name embeds the layer count, e.g. 'resnet_v1_101' by default.
    self._resnet_scope = 'resnet_v1_%d' % num_layers

其中的

Network.__init__(self, batch_size=batch_size)

跳转到lib/nets/network.py

class Network(object):
    """Base class that sets up the bookkeeping containers a network uses."""

    def __init__(self, batch_size=1):
        self._batch_size = batch_size

        # Feature stride and its reciprocal, each kept as a one-element list.
        self._feat_stride = [16]
        self._feat_compress = [1. / stride for stride in self._feat_stride]

        # Name-keyed registries; all of them start out empty.
        self._layers = {}
        self._predictions = {}
        self._losses = {}
        self._anchor_targets = {}
        self._proposal_targets = {}
        self._variables_to_fix = {}

        # Summary collections: lists for activation/train summaries,
        # dicts for score/event summaries.
        self._act_summaries = []
        self._train_summaries = []
        self._score_summaries = {}
        self._event_summaries = {}

回到train.py,下一行是

self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

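跳转到combined_roidb函数。这里的imdb_names是图像数据库(image database)的名字,也就是在factory.py的__sets中注册的数据集名(例如"voc_2007_trainval");多个数据集的名字可以用"+"连接,函数内部会按"+"拆分后逐个加载。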

def combined_roidb(imdb_names):
    """
    Combine multiple roidbs.

    imdb_names is a single dataset name or several names joined by '+'.
    Returns a tuple (imdb, roidb): roidb is the first dataset's roidb
    extended in place with the roidbs of the remaining datasets.
    """

    def get_roidb(imdb_name):
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        # The proposal method is fixed to "gt" for training here.
        imdb.set_proposal_method("gt")
        print('Set proposal method: {:s}'.format("gt"))
        return get_training_roidb(imdb)

    names = imdb_names.split('+')
    per_dataset = [get_roidb(name) for name in names]

    # Fold every additional roidb into the first one (no-op for one dataset).
    roidb = per_dataset[0]
    for extra in per_dataset[1:]:
        roidb.extend(extra)

    if len(per_dataset) > 1:
        # The combined imdb takes its class list from the second dataset.
        tmp = get_imdb(names[1])
        imdb = imdb2(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb

其中get_roidb(imdb_name)的第一行get_imdb(imdb_name)可以跳转到lib/datasets/factory.py

def get_imdb(name):
  """Get an imdb (image database) by name.

  Looks the name up in the module-level registry and calls the registered
  factory. Raises KeyError when the name is not registered.
  """
  if name in __sets:
    return __sets[name]()
  raise KeyError('Unknown dataset: {}'.format(name))

这里会输出两句

Loaded dataset `voc_2007_trainval` for training
Set proposal method: gt

下一行是

roidb = get_training_roidb(imdb)

跳转到get_training_roidb,它先添加水平翻转的图片,再准备训练数据,最后返回roidb

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training.

    The imdb is first augmented with horizontally flipped images, then its
    roidb is enriched by rdl_roidb.prepare_roidb, and finally returned.
    """
    # Flipping is unconditional in this version of the function.
    print('Appending horizontally-flipped training examples...')
    imdb.append_flipped_images()
    print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb

跳到imdb.py的append_flipped_images()中

    def append_flipped_images(self):
        """Append a horizontally mirrored copy of every existing roidb entry.

        Each new entry has mirrored x-coordinates, shares 'gt_overlaps' and
        'gt_classes' with its source entry, and is marked 'flipped': True.
        The image index list is repeated so that it is twice as long.
        """
        original_count = self.num_images  # number of training images
        image_widths = self._get_widths()
        for idx in range(original_count):
            source = self.roidb[idx]
            mirrored = source['boxes'].copy()
            # Mirror around the vertical axis: x' = width - x - 1, so the old
            # x2 becomes the new x1 and vice versa.
            mirrored[:, 0] = image_widths[idx] - source['boxes'][:, 2] - 1
            mirrored[:, 2] = image_widths[idx] - source['boxes'][:, 0] - 1
            # Per the original note, the subtraction can wrap around to 65535
            # (presumably an unsigned 16-bit box dtype), leaving x1 > x2;
            # reset x1 to 0 on those rows.
            wrapped = mirrored[:, 2] < mirrored[:, 0]
            mirrored[wrapped, 0] = 0

            assert (mirrored[:, 2] >= mirrored[:, 0]).all()
            self.roidb.append({'boxes': mirrored,
                               'gt_overlaps': source['gt_overlaps'],
                               'gt_classes': source['gt_classes'],
                               'flipped': True})
        # Repeating the index list doubles its length, matching the entries
        # appended above (assuming roidb held one entry per image before).
        self._image_index = self._image_index * 2

到了不明白的地方

def prepare_roidb(imdb):
  """Enrich the imdb's roidb by adding some derived quantities that
  are useful for training.

  For every entry this records the image path ('image'), the image size
  ('width'/'height', skipped for datasets whose name starts with 'coco'),
  and, for each ROI, the maximum overlap taken over classes
  ('max_overlaps') together with the class that attains it ('max_classes').

  Args:
    imdb: dataset object exposing name, roidb, image_index, num_images and
      image_path_at(i). The roidb entries are modified in place.

  Raises:
    AssertionError: if a zero-overlap ROI is not labelled class 0, or a
      positive-overlap ROI is labelled class 0.
  """
  roidb = imdb.roidb
  needs_sizes = not imdb.name.startswith('coco')
  if needs_sizes:
    sizes = []
    for i in range(imdb.num_images):
      # PIL opens images lazily and keeps the file handle open; close it as
      # soon as the size has been read instead of relying on garbage
      # collection across the whole dataset.
      with PIL.Image.open(imdb.image_path_at(i)) as img:
        sizes.append(img.size)
  # NOTE(review): image_index presumably already contains the flipped copies
  # added by append_flipped_images, so it covers every roidb entry - confirm.
  for i in range(len(imdb.image_index)):
    roidb[i]['image'] = imdb.image_path_at(i)
    if needs_sizes:
      roidb[i]['width'] = sizes[i][0]
      roidb[i]['height'] = sizes[i][1]
    # need gt_overlaps as a dense array for argmax
    gt_overlaps = roidb[i]['gt_overlaps'].toarray()
    # max overlap with gt over classes (columns)
    max_overlaps = gt_overlaps.max(axis=1)
    # gt class that had the max overlap
    max_classes = gt_overlaps.argmax(axis=1)
    roidb[i]['max_classes'] = max_classes
    roidb[i]['max_overlaps'] = max_overlaps
    # sanity checks
    # max overlap of 0 => class should be zero (background)
    zero_inds = np.where(max_overlaps == 0)[0]
    assert all(max_classes[zero_inds] == 0)
    # max overlap > 0 => class should not be zero (must be a fg class)
    nonzero_inds = np.where(max_overlaps > 0)[0]
    assert all(max_classes[nonzero_inds] != 0)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值