MXNet SSD之multibox_target

本文深入探讨MXNet SSD中multibox_target的实现,包括`multibox_target.h`和`multibox_target.cc`两个关键文件,揭示目标检测算法的核心步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

multibox_target.h

namespace mxnet {
namespace op {

namespace mshadow_op {
/*!
 * \brief Binary division functor that never divides by zero.
 *
 * Map(a, b) yields a / b, except that a divisor comparing equal to zero
 * produces zero instead of performing the division.
 */
struct safe_divide {
  /*!
   * \brief Divide a by b, substituting zero for the result when b is zero.
   * \param a dividend
   * \param b divisor
   * \return DType(0) if b == DType(0), otherwise a / b converted to DType
   */
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    const DType zero = DType(0.0f);
    return (b == zero) ? zero : DType(a / b);
  }
};  // struct safe_divide
}  // namespace mshadow_op

namespace mboxtarget_enum {
/*! \brief Positions of the operator inputs; used to index in_data in Forward. */
enum MultiBoxTargetOpInputs {
  kAnchor = 0,
  kLabel = 1,
  kClsPred = 2
};
/*! \brief Positions of the operator outputs; used to index out_data in Forward. */
enum MultiBoxTargetOpOutputs {
  kLoc = 0,
  kLocMask = 1,
  kCls = 2
};
/*! \brief Positions of the requested resources; used to index ctx.requested in Forward. */
enum MultiBoxTargetOpResource {
  kTempSpace = 0
};
}  // namespace mboxtarget_enum

/*!
 * \brief Parameters of the MultiBoxTarget operator.
 *
 * Each field is declared through DMLC_DECLARE_FIELD below, which also
 * fixes its default value and its user-facing description.
 */
struct MultiBoxTargetParam : public dmlc::Parameter<MultiBoxTargetParam> {
  /*! \brief Anchor-GT overlap threshold for a positive match (default 0.5). */
  float overlap_threshold;
  /*! \brief Label assigned to ignored anchors (default -1). */
  float ignore_label;
  /*! \brief Max negative-to-positive sample ratio; -1 disables mining (default -1). */
  float negative_mining_ratio;
  /*! \brief Threshold used for negative mining (default 0.5). */
  float negative_mining_thresh;
  /*! \brief Minimum number of negative samples (default 0). */
  int minimum_negative_samples;
  /*! \brief Variances encoded in the box regression target (default 0.1, 0.1, 0.2, 0.2). */
  nnvm::Tuple<float> variances;

  DMLC_DECLARE_PARAMETER(MultiBoxTargetParam) {
    DMLC_DECLARE_FIELD(overlap_threshold)
      .set_default(0.5f)
      .describe("Anchor-GT overlap threshold to be regarded as a positive match.");
    DMLC_DECLARE_FIELD(ignore_label)
      .set_default(-1.0f)
      .describe("Label for ignored anchors.");
    DMLC_DECLARE_FIELD(negative_mining_ratio)
      .set_default(-1.0f)
      .describe("Max negative to positive samples ratio, use -1 to disable mining");
    DMLC_DECLARE_FIELD(negative_mining_thresh)
      .set_default(0.5f)
      .describe("Threshold used for negative mining.");
    DMLC_DECLARE_FIELD(minimum_negative_samples)
      .set_default(0)
      .describe("Minimum number of negative samples.");
    DMLC_DECLARE_FIELD(variances)
      .set_default({0.1f, 0.1f, 0.2f, 0.2f})
      .describe("Variances to be encoded in box regression target.");
  }
};  // struct MultiBoxTargetParam

template<typename xpu, typename DType>
class MultiBoxTargetOp : public Operator {
 public:
  explicit MultiBoxTargetOp(MultiBoxTargetParam param) {
    this->param_ = param;
  }

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow_op;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 3);
    CHECK_EQ(out_data.size(), 3);
    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 2, DType> anchors = in_data[mboxtarget_enum::kAnchor]
      .get_with_shape<xpu, 2, DType>(
      Shape2(in_data[mboxtarget_enum::kAnchor].size(1), 4), s);
    Tensor<xpu, 3, DType> labels = in_data[mboxtarget_enum::kLabel]
      .get<xpu, 3, DType>(s);
    Tensor<xpu, 3, DType> cls_preds = in_data[mboxtarget_enum::kClsPred]
      .get<xpu, 3, DType>(s);
    Tensor<xpu, 2, DType> loc_target = out_data[mboxtarget_enum::kLoc]
      .get<xpu, 2, DType>(s);
    Tensor<xpu, 2, DType> loc_mask = out_data[mboxtarget_enum::kLocMask]
      .get<xpu, 2, DType>(s);
    Tensor<xpu, 2, DType> cls_target = out_data[mboxtarget_enum::kCls]
      .get<xpu, 2, DType>(s);

    index_t num_batches = labels.size(0);
    index_t num_anchors = anchors.size(0);
    index_t num_labels = labels.size(1);
    // TODO(zhreshold): use maximum valid ground-truth in batch rather than # in dataset
    Shape<4> temp_shape = Shape4(11, num_batches, num_anchors, num_labels);
    Tensor<xpu, 4, DType> temp_space = ctx.requested[mboxtarget_enum::kTempSpace]
      .get_space_typed<xpu, 4, DType>(temp_shape, s);
    loc_target = 0.f;
    loc_mask = 0.0f;
    cls_target = param_.ignore_label;
    temp_space = -1.0f;
    CHECK_EQ(anchors.CheckContiguous(), true);
    CHECK_EQ(labels.CheckContiguous(), true);
    CHECK_EQ(cls_preds.CheckContiguous(), true);
    CHECK_EQ(loc_target.CheckContiguous(), true);
    CHECK_EQ(loc_mask.CheckContiguous(), true);
    CHECK_EQ(cls_target.CheckContiguous(), true);
    CHECK_EQ(temp_space.CheckContiguous(), true);

    // compute overlaps
    // TODO(zhreshold): squeeze temporary memory space
    // temp_space, 0:out, 1:l1, 2:t1, 3:r1, 4:b1, 5:l2, 6:t2, 7:r2, 8:b2
    // 9: intersection, 10:union
    temp_space[1] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 0, 1), -1,
      num_batches), 2, num_labels);
    temp_space[2] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 1, 2), -1,
      num_batches), 2, num_labels);
    temp_space[3] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 2, 3), -1,
      num_batches), 2, num_labels);
    temp_space[4] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 3, 4), -1,
      num_batches), 2, num_labels);
    Shape<3> t
`torch.gather` 和 MXNet 的 `gather_nd`(在 Gluon 的 `hybrid_forward` 中写作 `F.gather_nd`)都用于按索引从张量中选取元素,但二者的索引方式和输出形状不同。

1. **PyTorch (`torch.gather`)**:沿给定维度 `dim` 按索引取元素。它接受源张量 `input`、维度 `dim` 和整型索引张量 `index`。`index` 的维数必须与 `input` 相同,并且在除 `dim` 以外的每个维度上大小不超过 `input`;输出张量的形状与 `index` 相同(而不是与 `input` 相同)。当输入为三维且 `dim=1` 时,`out[i][j][k] = input[i][index[i][j][k]][k]`。

```python
import torch
input = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # 形状 (2, 2, 2)
index = torch.tensor([[[0, 1]], [[0, 1]]])                  # 形状 (2, 1, 2)
output = torch.gather(input, dim=1, index=index)
# 输出: tensor([[[1, 4]],
#               [[5, 8]]])
```

2. **MXNet (`gather_nd`)**:通过多维坐标从张量中取出元素或切片。它只接受数据张量 `data` 和索引张量 `indices` 两个参数,没有 `axis` 参数。`indices` 的第 0 维长度 `M` 表示每个坐标包含的分量个数(`M` 不超过 `data` 的维数),其余维度 `(Y_0, ..., Y_{K-1})` 决定取出多少个位置;输出形状为 `indices.shape[1:] + data.shape[M:]`,即 `output[y_0, ..., y_{K-1}] = data[indices[0, y_0, ..., y_{K-1}], ..., indices[M-1, y_0, ..., y_{K-1}]]`。

```python
import mxnet as mx
data = mx.nd.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # 形状 (2, 2, 2)
indices = mx.nd.array([[0, 1], [1, 0]])                   # 形状 (2, 2):两个坐标 (0, 1) 和 (1, 0)
output = mx.nd.gather_nd(data, indices)
# 输出: [[3, 4],
#        [5, 6]]   即 data[0, 1] 和 data[1, 0]
```

**区别总结**:
- PyTorch 的 `gather` 只沿单个维度 `dim` 取值,`index` 必须与 `input` 维数相同,输出形状等于 `index` 的形状。
- MXNet 的 `gather_nd` 用 `indices` 的第 0 维组成多维坐标,可一次指定多个维度上的位置,适合提取多个位置的元素或切片。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值