Faster Rcnn 源码解析(三)—— bbox_transform.py

简介:

这个文件主要包含 anchor_target_layer.py 和 proposal_layer.py 中用到的一些函数,比较简单,主要是为了帮助理解以上两个文件。

源码:

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

import numpy as np
#计算与anchor有最大IOU的GT的偏移量
#ex_rois:表示anchor;gt_rois:表示GT
def bbox_transform(ex_rois, gt_rois):
    """Compute the regression targets that move ``ex_rois`` onto ``gt_rois``.

    Both arguments are read column-wise as ``(x1, y1, x2, y2)``; row ``i`` of
    ``ex_rois`` is paired with row ``i`` of ``gt_rois``. Widths and heights
    are measured as ``x2 - x1 + 1`` / ``y2 - y1 + 1``.

    Returns:
        For 2-D inputs, an array with one row per box and the four columns
        ``(dx, dy, dw, dh)``: the center shift divided by the ``ex_rois``
        size, and the natural log of the size ratio ``gt / ex``.
    """

    def _center_form(rois):
        # Corner boxes -> (width, height, center-x, center-y).
        w = rois[:, 2] - rois[:, 0] + 1.0
        h = rois[:, 3] - rois[:, 1] + 1.0
        return w, h, rois[:, 0] + 0.5 * w, rois[:, 1] + 0.5 * h

    anchor_w, anchor_h, anchor_cx, anchor_cy = _center_form(ex_rois)
    truth_w, truth_h, truth_cx, truth_cy = _center_form(gt_rois)

    offsets = (
        # Center displacement, normalized by the reference box size.
        (truth_cx - anchor_cx) / anchor_w,
        (truth_cy - anchor_cy) / anchor_h,
        # Size change expressed in log-space.
        np.log(truth_w / anchor_w),
        np.log(truth_h / anchor_h),
    )
    # Stack as 4 rows, then flip so each box owns one row.
    return np.vstack(offsets).T
#根据anchor和偏移量计算proposals
def bbox_transform_inv(boxes, deltas):
    """Apply regression deltas to ``boxes`` and return the predicted boxes.

    ``boxes`` is read column-wise as ``(x1, y1, x2, y2)``. ``deltas`` is read
    in strided groups of four columns ``(dx, dy, dw, dh)``, so it may carry
    several groups per row; every group of a row is applied to that row's box.

    Returns:
        An array with the shape and dtype of ``deltas`` whose columns, in the
        same groups of four, hold ``(x1, y1, x2, y2)`` of the predicted boxes.
        When ``boxes`` has no rows, an empty ``(0, deltas.shape[1])`` array.
    """
    if not boxes.shape[0]:
        return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)

    # Work in the dtype of the deltas so the arithmetic below does not mix types.
    boxes = boxes.astype(deltas.dtype, copy=False)

    # Reference boxes in (width, height, center) form.
    box_w = boxes[:, 2] - boxes[:, 0] + 1.0
    box_h = boxes[:, 3] - boxes[:, 1] + 1.0

    # Column vectors broadcast against the (rows, groups) delta slices.
    w_col = box_w[:, np.newaxis]
    h_col = box_h[:, np.newaxis]
    cx_col = (boxes[:, 0] + 0.5 * box_w)[:, np.newaxis]
    cy_col = (boxes[:, 1] + 0.5 * box_h)[:, np.newaxis]

    # Shift the centers and rescale the sizes (dw/dh are log-scale factors).
    new_cx = deltas[:, 0::4] * w_col + cx_col
    new_cy = deltas[:, 1::4] * h_col + cy_col
    half_w = 0.5 * (np.exp(deltas[:, 2::4]) * w_col)
    half_h = 0.5 * (np.exp(deltas[:, 3::4]) * h_col)

    # Back to corner form: top-left and bottom-right of each predicted box.
    pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
    pred_boxes[:, 0::4] = new_cx - half_w  # x1
    pred_boxes[:, 1::4] = new_cy - half_h  # y1
    pred_boxes[:, 2::4] = new_cx + half_w  # x2
    pred_boxes[:, 3::4] = new_cy + half_h  # y2
    return pred_boxes
# 将proposals的边界限制在图片内
# 调用格式 proposals = clip_boxes(proposals, im_info[:2])
def clip_boxes(boxes, im_shape):
    """Clamp box coordinates to the image, modifying ``boxes`` in place.

    ``boxes`` is read in strided groups of four columns ``(x1, y1, x2, y2)``.
    ``im_shape[0]`` bounds the y coordinates and ``im_shape[1]`` bounds the x
    coordinates: every x ends up within ``[0, im_shape[1] - 1]`` and every y
    within ``[0, im_shape[0] - 1]``.

    Returns:
        The same ``boxes`` array that was passed in, after clamping.
    """
    x_limit = im_shape[1] - 1
    y_limit = im_shape[0] - 1

    # Column offsets 0..3 within each group are x1, y1, x2, y2 respectively.
    for offset, upper in enumerate((x_limit, y_limit, x_limit, y_limit)):
        capped = np.minimum(boxes[:, offset::4], upper)
        boxes[:, offset::4] = np.maximum(capped, 0)
    return boxes

以下是使用PyTorch实现的YOLOv5和Faster R-CNN的mAP对比图代码:

```python
import torch
import torchvision
import argparse
import utils
import os

# 设置参数
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='coco', help='数据集名称')
parser.add_argument('--weights-yolo', type=str, default='yolov5s.pt', help='YOLOv5模型权重路径')
parser.add_argument('--weights-frcnn', type=str, default='fasterrcnn_resnet50_fpn_coco.pth', help='Faster R-CNN模型权重路径')
parser.add_argument('--iou-thres', type=float, default=0.5, help='IoU阈值')
parser.add_argument('--conf-thres', type=float, default=0.001, help='置信度阈值')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='输入图像尺寸')
args = parser.parse_args()

# 加载数据集
if args.data == 'coco':
    test_set = torchvision.datasets.CocoDetection(root='./data/coco', annFile='./data/coco/annotations/instances_val2017.json')
    num_classes = 80
elif args.data == 'voc':
    test_set = torchvision.datasets.VOCDetection(root='./data/voc', image_set='val', transform=None, target_transform=None, download=True)
    num_classes = 20
else:
    raise ValueError('未知数据集名称')

# 创建YOLOv5模型
yolo_model = torch.hub.load('ultralytics/yolov5', 'custom', path=args.weights_yolo, source='local')
yolo_model.eval()
yolo_model.cuda()

# 创建Faster R-CNN模型
frcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, num_classes=num_classes)
frcnn_model.load_state_dict(torch.load(args.weights_frcnn))
frcnn_model.eval()
frcnn_model.cuda()

# 计算YOLOv5的AP
yolo_results = []
for idx in range(len(test_set)):
    image, target = test_set[idx]
    detections = yolo_model(image.unsqueeze(0).cuda(), img_size=args.img_size, conf_thres=args.conf_thres, iou_thres=args.iou_thres)
    for detection in detections:
        if detection is not None:
            for x1, y1, x2, y2, conf, cls in detection:
                yolo_results.append({'image_id': idx, 'category_id': cls.item(), 'bbox': [x1.item(), y1.item(), (x2-x1).item(), (y2-y1).item()], 'score': conf.item()})
yolo_eval = utils.evaluate(yolo_results, test_set.coco)
print('YOLOv5 mAP: {:.3f}'.format(yolo_eval.stats[0]))

# 计算Faster R-CNN的AP
frcnn_results = []
for idx in range(len(test_set)):
    image, target = test_set[idx]
    detections = frcnn_model([image.cuda()])
    for detection in detections:
        for box, conf, cls in zip(detection['boxes'], detection['scores'], detection['labels']):
            frcnn_results.append({'image_id': idx, 'category_id': cls.item(), 'bbox': [box[0].item(), box[1].item(), (box[2]-box[0]).item(), (box[3]-box[1]).item()], 'score': conf.item()})
frcnn_eval = utils.evaluate(frcnn_results, test_set.coco)
print('Faster R-CNN mAP: {:.3f}'.format(frcnn_eval.stats[0]))

# 画出AP对比图
utils.plot_results([yolo_eval, frcnn_eval], names=['YOLOv5', 'Faster R-CNN'], save_dir=os.path.join('.', args.data+'_map.png'))
```

其中,`utils`是一个自定义的工具函数模块,包含了`evaluate`和`plot_results`函数。`evaluate`函数用于计算AP,`plot_results`函数用于画出AP对比图。这两个函数的实现可以参考[这个GitHub仓库](https://github.com/ultralytics/yolov5/blob/master/utils/general.py)。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值