《深入解析 MMDetection 中的高斯目标生成函数》
一、生成二维高斯核函数:gaussian2D
-
参数设置:
- 设定
radius = 1
,这个参数决定了生成的二维高斯核的范围大小。
- 设定
-
生成横坐标和纵坐标序列:
x = torch.arange(-radius, radius + 1).view(1, -1)
生成了横坐标序列,输出为tensor([[-1, 0, 1]])
。这表示在[-1, 1]
范围内的整数序列作为横坐标,形状为(1, 3)
。y = torch.arange(-radius, radius + 1).view(-1, 1)
生成了纵坐标序列,输出为tensor([[-1],[0],[1]])
。这表示同样在[-1, 1]
范围内的整数序列作为纵坐标,形状为(3, 1)
。
-
计算直径和标准差:
diameter = 2 * radius + 1
计算出直径为 3。sigma = diameter / 6
,所以标准差为 0.5。
-
生成二维高斯核:
h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
使用二维高斯分布的公式变体形式计算二维高斯核。- 以坐标
(0, 0)
为例,计算过程为(e{-\frac{02 + 0^2}{2 * 0.5 * 0.5}} = 1.0000)。 - 对所有的坐标组合进行这样的计算,得到最终的二维高斯核输出为
tensor([[0.0183, 0.1353, 0.0183],[0.1353, 1.0000, 0.1353],[0.0183, 0.1353, 0.0183]])
,形状为(3, 3)
。
- 以坐标
数据模拟:
radius = 1
x = torch.arange(
-radius, radius + 1).view(1, -1)
print(x)
y = torch.arange(
-radius, radius + 1).view(-1, 1)
print(y)
diameter = 2 * radius + 1
sigma = diameter / 6
h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
print(h)
tensor([[-1, 0, 1]])
tensor([[-1],
[ 0],
[ 1]])
tensor([[0.0183, 0.1353, 0.0183],
[0.1353, 1.0000, 0.1353],
[0.0183, 0.1353, 0.0183]])
二、生成二维高斯热图函数:gen_gaussian_target
该函数用于在给定的热图上生成以指定中心为中心的二维高斯热图。
-
参数解释:
heatmap
:输入的热图,高斯核将覆盖在这个热图上并保持最大值。center
:高斯核中心的坐标列表。radius
:高斯核的半径。k
:高斯核的系数,默认为 1。
-
执行过程:
- 首先,根据半径计算出直径,并调用
gaussian2D
函数生成高斯核。 - 然后,根据中心坐标和热图的形状,确定需要覆盖的热图区域以及对应的高斯核区域。
- 最后,通过
torch.max
函数将覆盖区域的热图与高斯核进行比较,取最大值更新热图。
- 首先,根据半径计算出直径,并调用
3.1 函数功能
这个函数的目的是在给定的输入热图上,以指定的中心坐标为中心生成一个二维高斯热图,并确保在覆盖区域内热图的值保持为最大值。
3.2 参数解释
heatmap
:输入的热图张量,形状通常为[batch, channels, height, width]
,二维高斯核将覆盖在这个热图上,并且在覆盖区域内热图的值将被更新为最大值。center
:一个包含两个整数的列表,表示高斯核中心的坐标,例如[x, y]
。radius
:高斯核的半径,决定了高斯分布的范围。k
:高斯核的系数,默认为 1,用于调整高斯核的强度。
3.3 函数执行过程
-
计算高斯核的直径:
diameter = 2 * radius + 1
:根据给定的半径计算高斯核的直径。
-
生成二维高斯核:
gaussian_kernel = gaussian2D(radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device)
:调用gaussian2D
函数生成二维高斯核,其中sigma
设置为直径的六分之一,数据类型和设备与输入热图保持一致。
-
提取中心坐标和热图尺寸:
x, y = center
:从中心坐标列表中提取横坐标和纵坐标。height, width = heatmap.shape[:2]
:获取输入热图的高度和宽度。
-
确定覆盖区域的边界:
left, right = min(x, radius), min(width - x, radius + 1)
:计算覆盖区域在水平方向上的左右边界,确保不超出热图的范围。top, bottom = min(y, radius), min(height - y, radius + 1)
:计算覆盖区域在垂直方向上的上下边界。
-
提取覆盖区域的热图和高斯核:
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
:从输入热图中提取覆盖区域的热图。masked_gaussian = gaussian_kernel[radius - top:radius + bottom, radius - left:radius + right]
:从生成的高斯核中提取对应的覆盖区域。
-
更新热图:
out_heatmap = heatmap
:首先将输出热图初始化为输入热图。torch.max(masked_heatmap, masked_gaussian * k, out=out_heatmap[y - top:y + bottom, x - left:x + right])
:在覆盖区域内,比较原始热图和高斯核乘以系数k
的值,取最大值更新输出热图的对应区域。
-
返回更新后的热图:
return out_heatmap
:返回更新后的热图,其中覆盖区域被二维高斯热图更新为最大值。
模拟数据:
# 1. gaussian2D()
radius = 1
x = torch.arange(
-radius, radius + 1).view(1, -1)
print(x)
y = torch.arange(
-radius, radius + 1).view(-1, 1)
print(y)
diameter = 2 * radius + 1
sigma = diameter / 6
h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
print(h)
h[h < torch.finfo(h.dtype).eps * h.max()] = 0
print(h)
# 2. gen_gaussian_target()
# 初始化数据
gaussian_kernel = h
center = [2, 2]
heatmap = torch.tensor(torch.rand((5, 5))) # 这里手动获取后两个唯独:宽度和高度。 因为之前做了[left_idx, top_idx]
print("heatmap:",heatmap)
# 开始
x, y = center
height, width = heatmap.shape[:2]
print("height, width:",height, width)
print("x:",x)
print("y",y)
print("radius:",radius)
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
print("left, right, top, bottom:",left, right, top, bottom)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
print("masked_heatmap:",masked_heatmap)
masked_gaussian = gaussian_kernel[radius - top:radius + bottom,
radius - left:radius + right]
print("masked_gaussian:",masked_gaussian)
k=2
out_heatmap = heatmap
print("out_heatmap:",out_heatmap)
print("masked_gaussian * k:",masked_gaussian * k)
torch.max(
masked_heatmap,
masked_gaussian * k,
out=out_heatmap[y - top:y + bottom, x - left:x + right])
print(out_heatmap)
tensor([[-1, 0, 1]])
tensor([[-1],
[ 0],
[ 1]])
tensor([[0.0183, 0.1353, 0.0183],
[0.1353, 1.0000, 0.1353],
[0.0183, 0.1353, 0.0183]])
tensor([[0.0183, 0.1353, 0.0183],
[0.1353, 1.0000, 0.1353],
[0.0183, 0.1353, 0.0183]])
heatmap: tensor([[0.6840, 0.0176, 0.3079, 0.2870, 0.3737],
[0.3751, 0.2407, 0.3384, 0.3107, 0.3177],
[0.3313, 0.2069, 0.6418, 0.0436, 0.5058],
[0.2378, 0.3770, 0.2741, 0.0111, 0.4656],
[0.5823, 0.7769, 0.0964, 0.5200, 0.0987]])
height, width: 5 5
x: 2
y 2
radius: 1
left, right, top, bottom: 1 2 1 2
masked_heatmap: tensor([[0.2407, 0.3384, 0.3107],
[0.2069, 0.6418, 0.0436],
[0.3770, 0.2741, 0.0111]])
masked_gaussian: tensor([[0.0183, 0.1353, 0.0183],
[0.1353, 1.0000, 0.1353],
[0.0183, 0.1353, 0.0183]])
out_heatmap: tensor([[0.6840, 0.0176, 0.3079, 0.2870, 0.3737],
[0.3751, 0.2407, 0.3384, 0.3107, 0.3177],
[0.3313, 0.2069, 0.6418, 0.0436, 0.5058],
[0.2378, 0.3770, 0.2741, 0.0111, 0.4656],
[0.5823, 0.7769, 0.0964, 0.5200, 0.0987]])
masked_gaussian * k: tensor([[0.0366, 0.2707, 0.0366],
[0.2707, 2.0000, 0.2707],
[0.0366, 0.2707, 0.0366]])
tensor([[0.6840, 0.0176, 0.3079, 0.2870, 0.3737],
[0.3751, 0.2407, 0.3384, 0.3107, 0.3177],
[0.3313, 0.2707, 2.0000, 0.2707, 0.5058],
[0.2378, 0.3770, 0.2741, 0.0366, 0.4656],
[0.5823, 0.7769, 0.0964, 0.5200, 0.0987]])
解析:
-
gaussian_kernel:
tensor([[0.0183, 0.1353, 0.0183],
[0.1353, 1.0000, 0.1353],
[0.0183, 0.1353, 0.0183]]) -
根据left, right, top, bottom: 1 2 1 2
左上角坐标[1,1],
右下角坐标[2,2] -
heatmap:
tensor([[0.6840, 0.0176, 0.3079, 0.2870, 0.3737],
[0.3751, 0.2407, 0.3384, 0.3107, 0.3177],
[0.3313, 0.2069, 0.6418, 0.0436, 0.5058],
[0.2378, 0.3770, 0.2741, 0.0111, 0.4656],
[0.5823, 0.7769, 0.0964, 0.5200, 0.0987]]) -
masked_heatmap:
tensor(
[[0.2407, 0.3384, 0.3107],
[0.2069, 0.6418, 0.0436],
[0.3770, 0.2741, 0.0111]]) -
masked_gaussian
tensor(
[[0.0183, 0.1353, 0.0183],
[0.1353, 1.0000, 0.1353],
[0.0183, 0.1353, 0.0183]]) -
masked_gaussian * k:
tensor(
[[0.0366, 0.2707, 0.0366],
[0.2707, 2.0000, 0.2707],
[0.0366, 0.2707, 0.0366]])
out_heatmap: (两者取最大)
tensor([[0.6840, 0.0176, 0.3079, 0.2870, 0.3737],
[0.3751, 0.2407, 0.3384, 0.3107, 0.3177],
[0.3313, 0.2707, 2.0000, 0.2707, 0.5058],
[0.2378, 0.3770, 0.2741, 0.0366, 0.4656],
[0.5823, 0.7769, 0.0964, 0.5200, 0.0987]])
三、计算二维高斯半径函数:gaussian_radius
这个函数用于根据目标的尺寸和最小重叠要求计算二维高斯半径。
-
参数解释:
det_size
:目标的尺寸,通常是一个包含高度和宽度的列表。min_overlap
:最小重叠比例,用于确定生成的框与真实框之间的最小交并比(IoU)要求。
-
执行过程:
- 考虑了三种情况来计算高斯半径:
- 情况 1:一个角点在真实框内,另一个角点在框外。
- 情况 2:两个角点都在真实框内。
- 情况 3:两个角点都在真实框外。
- 对于每种情况,通过推导得出二次方程的形式,根据 Vieta 公式计算半径。
- 最后,返回三种情况下计算出的半径中的最小值作为最终的高斯半径。
- 考虑了三种情况来计算高斯半径:
3.1 注释翻译
“Generate 2D gaussian radius.(生成二维高斯半径。)”
“This function is modified from the `official github repo(这个函数是从官方 GitHub 仓库修改而来,https://github.com/princeton-vl/CornerNet-Lite/blob/master/core/sample/utils.py#L65。)”
“Given min_overlap
, radius could computed by a quadratic equation according to Vieta’s formulas.(给定最小重叠度min_overlap
,可以根据韦达公式通过二次方程计算半径。)”
“There are 3 cases for computing gaussian radius, details are following:(有三种情况用于计算高斯半径,细节如下:)”
“- Explanation of figure: lt
and br
indicates the left-top and bottom-right corner of ground truth box. x
indicates the generated corner at the limited position when radius=r
.(图形解释:lt
和br
表示真实框的左上角和右下角。x
表示当半径为r
时在受限位置生成的角点。)”
“- Case1: one corner is inside the gt box and the other is outside.(情况 1:一个角点在真实框内,另一个角点在框外。)”
“To ensure IoU of generated box and gt box is larger than min_overlap
:(为确保生成的框和真实框的交并比大于min_overlap
:)”
“- Case2: both two corners are inside the gt box.(情况 2:两个角点都在真实框内。)”
“To ensure IoU of generated box and gt box is larger than min_overlap
:(为确保生成的框和真实框的交并比大于min_overlap
:)”
“- Case3: both two corners are outside the gt box.(情况 3:两个角点都在真实框外。)”
“To ensure IoU of generated box and gt box is larger than min_overlap
:(为确保生成的框和真实框的交并比大于min_overlap
:)”
3.2 公式
3.2.1 情况 1:一个角点在真实框内,另一个角点在框外
为确保生成的框和真实框的交并比大于 min_overlap
,推导过程如下:
- 交并比(IoU)公式:
- I o U = ( w − r ) ⋅ ( h − r ) w ⋅ h + ( w + h ) ⋅ r − r 2 ≥ m i n _ o v e r l a p IoU = \frac{(w - r) \cdot (h - r)}{w \cdot h + (w + h) \cdot r - r^2} \geq min\_overlap IoU=w⋅h+(w+h)⋅r−r2(w−r)⋅(h−r)≥min_overlap,其中 w w w 是宽度, h h h 是高度, r r r 是高斯半径。
- 推导:
- ( w − r ) ⋅ ( h − r ) w ⋅ h + ( w + h ) ⋅ r − r 2 ≥ m i n _ o v e r l a p \frac{(w - r) \cdot (h - r)}{w \cdot h + (w + h) \cdot r - r^2} \geq min\_overlap w⋅h+(w+h)⋅r−r2(w−r)⋅(h−r)≥min_overlap,变形得到 r 2 − ( w + h ) ⋅ r + 1 − m i n _ o v e r l a p 1 + m i n _ o v e r l a p ⋅ w ⋅ h ≥ 0 r^2 - (w + h) \cdot r + \frac{1 - min\_overlap}{1 + min\_overlap} \cdot w \cdot h \geq 0 r2−(w+h)⋅r+1+min_overlap1−min_overlap⋅w⋅h≥0。
- 令 a = 1 a = 1 a=1, b = − ( w + h ) b = -(w + h) b=−(w+h), c = 1 − m i n _ o v e r l a p 1 + m i n _ o v e r l a p ⋅ w ⋅ h c = \frac{1 - min\_overlap}{1 + min\_overlap} \cdot w \cdot h c=1+min_overlap1−min_overlap⋅w⋅h。
- 最后得到半径 r r r 的计算公式为 r ≤ − b − b 2 − 4 ⋅ a ⋅ c 2 ⋅ a r \leq \frac{-b - \sqrt{b^2 - 4 \cdot a \cdot c}}{2 \cdot a} r≤2⋅a−b−b2−4⋅a⋅c。
3.2.2 情况 2:两个角点都在真实框内
为确保生成的框和真实框的交并比大于 min_overlap
,推导过程如下:
- 交并比(IoU)公式:
- I o U = ( w − 2 ⋅ r ) ⋅ ( h − 2 ⋅ r ) w ⋅ h ≥ m i n _ o v e r l a p IoU = \frac{(w - 2 \cdot r) \cdot (h - 2 \cdot r)}{w \cdot h} \geq min\_overlap IoU=w⋅h(w−2⋅r)⋅(h−2⋅r)≥min_overlap。
- 推导:
- ( w − 2 ⋅ r ) ⋅ ( h − 2 ⋅ r ) w ⋅ h ≥ m i n _ o v e r l a p \frac{(w - 2 \cdot r) \cdot (h - 2 \cdot r)}{w \cdot h} \geq min\_overlap w⋅h(w−2⋅r)⋅(h−2⋅r)≥min_overlap,变形得到 4 ⋅ r 2 − 2 ⋅ ( w + h ) ⋅ r + ( 1 − m i n _ o v e r l a p ) ⋅ w ⋅ h ≥ 0 4 \cdot r^2 - 2 \cdot (w + h) \cdot r + (1 - min\_overlap) \cdot w \cdot h \geq 0 4⋅r2−2⋅(w+h)⋅r+(1−min_overlap)⋅w⋅h≥0。
- 令 a = 4 a = 4 a=4, b = − 2 ⋅ ( w + h ) b = -2 \cdot (w + h) b=−2⋅(w+h), c = ( 1 − m i n _ o v e r l a p ) ⋅ w ⋅ h c = (1 - min\_overlap) \cdot w \cdot h c=(1−min_overlap)⋅w⋅h。
- 半径 r r r 的计算公式为 r ≤ − b − b 2 − 4 ⋅ a ⋅ c 2 ⋅ a r \leq \frac{-b - \sqrt{b^2 - 4 \cdot a \cdot c}}{2 \cdot a} r≤2⋅a−b−b2−4⋅a⋅c。
3.2.3 情况 3:两个角点都在真实框外
为确保生成的框和真实框的交并比大于 min_overlap
,推导过程如下:
- 交并比(IoU)公式:
- I o U = w ⋅ h ( w + 2 ⋅ r ) ⋅ ( h + 2 ⋅ r ) ≥ m i n _ o v e r l a p IoU = \frac{w \cdot h}{(w + 2 \cdot r) \cdot (h + 2 \cdot r)} \geq min\_overlap IoU=(w+2⋅r)⋅(h+2⋅r)w⋅h≥min_overlap,变形得到 4 ⋅ m i n _ o v e r l a p ⋅ r 2 + 2 ⋅ m i n _ o v e r l a p ⋅ ( w + h ) ⋅ r + ( m i n _ o v e r l a p − 1 ) ⋅ w ⋅ h ≤ 0 4 \cdot min\_overlap \cdot r^2 + 2 \cdot min\_overlap \cdot (w + h) \cdot r + (min\_overlap - 1) \cdot w \cdot h \leq 0 4⋅min_overlap⋅r2+2⋅min_overlap⋅(w+h)⋅r+(min_overlap−1)⋅w⋅h≤0。
- 令:
- a = 4 ⋅ m i n _ o v e r l a p a = 4 \cdot min\_overlap a=4⋅min_overlap, b = 2 ⋅ m i n _ o v e r l a p ⋅ ( w + h ) b = 2 \cdot min\_overlap \cdot (w + h) b=2⋅min_overlap⋅(w+h), c = ( m i n _ o v e r l a p − 1 ) ⋅ w ⋅ h c = (min\_overlap - 1) \cdot w \cdot h c=(min_overlap−1)⋅w⋅h。
- 半径
r
r
r 的计算公式:
- r ≤ − b + b 2 − 4 ⋅ a ⋅ c 2 ⋅ a r \leq \frac{-b + \sqrt{b^2 - 4 \cdot a \cdot c}}{2 \cdot a} r≤2⋅a−b+b2−4⋅a⋅c。
3.3 公式解析
结合 get_targets
方法解释 gaussian_radius
方法中的公式元素意义
在 gaussian_radius
方法中,主要是为了计算二维高斯半径,这个半径的计算是基于三种不同的情况,每种情况对应一个特定的公式。以下分别解释这些公式中元素的意义:
3.3.1 情况 1:一个角点在真实框内,另一个角点在框外
-
交并比(IoU)公式:
- I o U = ( w − r ) ⋅ ( h − r ) w ⋅ h + ( w + h ) ⋅ r − r 2 ≥ m i n _ o v e r l a p IoU = \frac{(w - r) \cdot (h - r)}{w \cdot h + (w + h) \cdot r - r^2} \geq min\_overlap IoU=w⋅h+(w+h)⋅r−r2(w−r)⋅(h−r)≥min_overlap
- 元素意义:
-
w
w
w 和
h
h
h:在
get_targets
方法中,这通常对应目标在特征图上的宽度和高度。例如,当计算高斯热图和偏移等目标时,会用到目标在特征图上的尺寸信息。 - r r r:待求的高斯半径。
-
m
i
n
_
o
v
e
r
l
a
p
min\_overlap
min_overlap:在
get_targets
方法中,这个参数通常表示生成的高斯热图所对应的框与真实框之间的最小交并比要求。在生成角点目标时,为了确保生成的热图能够准确地反映目标的位置,需要保证一定的交并比。
-
w
w
w 和
h
h
h:在
-
变形后的公式:
- r 2 − ( w + h ) ⋅ r + 1 − m i n _ o v e r l a p 1 + m i n _ o v e r l a p ⋅ w ⋅ h ≥ 0 r^2 - (w + h) \cdot r + \frac{1 - min\_overlap}{1 + min\_overlap} \cdot w \cdot h \geq 0 r2−(w+h)⋅r+1+min_overlap1−min_overlap⋅w⋅h≥0
- 元素意义:
- a = 1 a = 1 a=1:二次方程的二次项系数。
- b = − ( w + h ) b = -(w + h) b=−(w+h):一次项系数,这里的 w + h w + h w+h 表示目标的宽度和高度之和,反映了目标的整体尺寸对半径计算的影响。
- c = 1 − m i n _ o v e r l a p 1 + m i n _ o v e r l a p ⋅ w ⋅ h c = \frac{1 - min\_overlap}{1 + min\_overlap} \cdot w \cdot h c=1+min_overlap1−min_overlap⋅w⋅h:常数项,其中 1 − m i n _ o v e r l a p 1 + m i n _ o v e r l a p \frac{1 - min\_overlap}{1 + min\_overlap} 1+min_overlap1−min_overlap 是根据最小交并比计算得到的系数, w ⋅ h w \cdot h w⋅h 是目标的面积,综合反映了交并比和目标尺寸对半径计算的影响。
3.3.2 情况 2:两个角点都在真实框内
-
交并比(IoU)公式:
- I o U = ( w − 2 ⋅ r ) ⋅ ( h − 2 ⋅ r ) w ⋅ h ≥ m i n _ o v e r l a p IoU = \frac{(w - 2 \cdot r) \cdot (h - 2 \cdot r)}{w \cdot h} \geq min\_overlap IoU=w⋅h(w−2⋅r)⋅(h−2⋅r)≥min_overlap
- 元素意义:
- w w w 和 h h h:同样,这是目标在特征图上的宽度和高度。
- r r r:高斯半径。
- m i n _ o v e r l a p min\_overlap min_overlap:最小交并比要求。
-
变形后的公式:
- 4 ⋅ r 2 − 2 ⋅ ( w + h ) ⋅ r + ( 1 − m i n _ o v e r l a p ) ⋅ w ⋅ h ≥ 0 4 \cdot r^2 - 2 \cdot (w + h) \cdot r + (1 - min\_overlap) \cdot w \cdot h \geq 0 4⋅r2−2⋅(w+h)⋅r+(1−min_overlap)⋅w⋅h≥0
- 元素意义:
- a = 4 a = 4 a=4:二次方程的二次项系数。
- b = − 2 ⋅ ( w + h ) b = -2 \cdot (w + h) b=−2⋅(w+h):一次项系数,这里的 − 2 ⋅ ( w + h ) -2 \cdot (w + h) −2⋅(w+h) 表示当两个角点都在框内时,宽度和高度分别减去两倍半径后对半径计算的影响。
- c = ( 1 − m i n _ o v e r l a p ) ⋅ w ⋅ h c = (1 - min\_overlap) \cdot w \cdot h c=(1−min_overlap)⋅w⋅h:常数项,其中 ( 1 − m i n _ o v e r l a p ) (1 - min\_overlap) (1−min_overlap) 是根据最小交并比计算得到的系数, w ⋅ h w \cdot h w⋅h 是目标的面积,反映了交并比和目标尺寸在这种情况下对半径计算的影响。
3.3.3 情况 3:两个角点都在真实框外
-
交并比(IoU)公式:
- I o U = w ⋅ h ( w + 2 ⋅ r ) ⋅ ( h + 2 ⋅ r ) ≥ m i n _ o v e r l a p IoU = \frac{w \cdot h}{(w + 2 \cdot r) \cdot (h + 2 \cdot r)} \geq min\_overlap IoU=(w+2⋅r)⋅(h+2⋅r)w⋅h≥min_overlap
- 元素意义:
- w w w 和 h h h:这是目标在特征图上的宽度和高度。
- r r r:高斯半径。
- m i n _ o v e r l a p min\_overlap min_overlap:最小交并比要求。
-
变形后的公式:
- 4 ⋅ m i n _ o v e r l a p ⋅ r 2 + 2 ⋅ m i n _ o v e r l a p ⋅ ( w + h ) ⋅ r + ( m i n _ o v e r l a p − 1 ) ⋅ w ⋅ h ≤ 0 4 \cdot min\_overlap \cdot r^2 + 2 \cdot min\_overlap \cdot (w + h) \cdot r + (min\_overlap - 1) \cdot w \cdot h \leq 0 4⋅min_overlap⋅r2+2⋅min_overlap⋅(w+h)⋅r+(min_overlap−1)⋅w⋅h≤0
- 元素意义:
- a = 4 ⋅ m i n _ o v e r l a p a = 4 \cdot min\_overlap a=4⋅min_overlap:二次方程的二次项系数,这里的 4 ⋅ m i n _ o v e r l a p 4 \cdot min\_overlap 4⋅min_overlap 表示在两个角点都在框外的情况下,交并比对二次项系数的影响。
- b = 2 ⋅ m i n _ o v e r l a p ⋅ ( w + h ) b = 2 \cdot min\_overlap \cdot (w + h) b=2⋅min_overlap⋅(w+h):一次项系数,反映了交并比和目标尺寸在这种情况下对半径计算的影响。
- c = ( m i n _ o v e r l a p − 1 ) ⋅ w ⋅ h c = (min\_overlap - 1) \cdot w \cdot h c=(min_overlap−1)⋅w⋅h:常数项,其中 ( m i n _ o v e r l a p − 1 ) (min\_overlap - 1) (min_overlap−1) 是根据最小交并比计算得到的系数, w ⋅ h w \cdot h w⋅h 是目标的面积,综合体现了交并比和目标尺寸在两个角点都在框外时对半径计算的影响。
3.4 处理步骤解析
-
参数获取:
- 这个函数接受两个参数,
det_size
是一个包含目标形状的列表,通常是目标的高度和宽度;min_overlap
是由高斯核内的关键点生成的框与真实框的最小交并比(Intersection over Union,IoU)。
- 这个函数接受两个参数,
-
提取高度和宽度:
height, width = det_size
:从输入的目标形状列表中提取出高度和宽度。
-
计算情况 1 的系数和半径:
- 根据公式计算情况 1 的系数
a1
、b1
和c1
。 a1 = 1
:对应二次方程的二次项系数。b1 = (height + width)
:一次项系数。c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
:常数项。- 计算判别式
sq1 = sqrt(b1**2 - 4 * a1 * c1)
。 - 计算半径
r1 = (b1 - sq1) / (2 * a1)
。
- 根据公式计算情况 1 的系数
-
计算情况 2 的系数和半径:
- 同样根据公式计算情况 2 的系数
a2
、b2
和c2
。 a2 = 4
。b2 = 2 * (height + width)
。c2 = (1 - min_overlap) * width * height
。- 计算判别式
sq2 = sqrt(b2**2 - 4 * a2 * c2)
。 - 计算半径
r2 = (b2 - sq2) / (2 * a2)
。
- 同样根据公式计算情况 2 的系数
-
计算情况 3 的系数和半径:
- 按照公式计算情况 3 的系数
a3
、b3
和c3
。 a3 = 4 * min_overlap
。b3 = -2 * min_overlap * (height + width)
。c3 = (min_overlap - 1) * width * height
。- 计算判别式
sq3 = sqrt(b3**2 - 4 * a3 * c3)
。 - 计算半径
r3 = (b3 + sq3) / (2 * a3)
。
- 按照公式计算情况 3 的系数
-
返回最小半径:
return min(r1, r2, r3)
:最后,函数返回三种情况下计算出的半径中的最小值,作为最终的高斯核半径。
3.5 举例解析计算过程
假设我们有一个目标,其在特征图上的尺寸为det_size = (height=10, width=8)
,并且我们设定最小重叠比例min_overlap = 0.5
。
-
首先计算情况 1 的半径:
- 根据公式计算系数:
a1 = 1
。b1 = height + width = 10 + 8 = 18
。c1 = width * height * (1 - min_overlap) / (1 + min_overlap) = 8 * 10 * (1 - 0.5) / (1 + 0.5) = 80 * 0.5 / 1.5 = 26.67
(近似值)。
- 计算判别式:
sq1 = sqrt(b1**2 - 4 * a1 * c1) = sqrt(18**2 - 4 * 1 * 26.67) = sqrt(324 - 106.68) = sqrt(217.32) = 14.74
(近似值)。
- 计算半径:
r1 = (b1 - sq1) / (2 * a1) = (18 - 14.74) / 2 = 1.63
(近似值)。
- 根据公式计算系数:
-
接着计算情况 2 的半径:
- 计算系数:
a2 = 4
。b2 = 2 * (height + width) = 2 * 18 = 36
。c2 = (1 - min_overlap) * width * height = (1 - 0.5) * 8 * 10 = 40
。
- 计算判别式:
sq2 = sqrt(b2**2 - 4 * a2 * c2) = sqrt(36**2 - 4 * 4 * 40) = sqrt(1296 - 640) = sqrt(656) = 25.61
(近似值)。
- 计算半径:
r2 = (b2 - sq2) / (2 * a2) = (36 - 25.61) / 8 = 1.3
(近似值)。
- 计算系数:
-
最后计算情况 3 的半径:
- 计算系数:
a3 = 4 * min_overlap = 4 * 0.5 = 2
。b3 = -2 * min_overlap * (height + width) = -2 * 0.5 * 18 = -18
。c3 = (min_overlap - 1) * width * height = (0.5 - 1) * 8 * 10 = -40
。
- 计算判别式:
sq3 = sqrt(b3**2 - 4 * a3 * c3) = sqrt((-18)**2 - 4 * 2 * (-40)) = sqrt(324 + 320) = sqrt(644) = 25.38
(近似值)。
- 计算半径:
r3 = (b3 + sq3) / (2 * a3) = (-18 + 25.38) / 4 = 1.845
(近似值)。
- 计算系数:
-
最后返回三个半径中的最小值:
return min(r1, r2, r3)
,在这个例子中,返回值为 1.3。
综上所述,对于给定的目标尺寸和最小重叠比例,通过上述计算过程得到了高斯半径的值。这个半径将用于生成二维高斯热图等操作,以更好地表示目标在特征图上的位置和重要性。
四、提取局部最大值函数:get_local_maximum
该函数用于从热图中提取局部最大值像素。
-
参数解释:
heat
:目标热图。kernel
:最大池化的核大小,默认为 3。
-
执行过程:
- 首先,根据核大小计算填充值。
- 然后,使用
F.max_pool2d
函数进行最大池化操作。 - 最后,通过比较最大池化后的结果与原始热图,保留局部最大值像素,其他位置置为 0。
五、从热图中获取前 k 个最大值函数:get_topk_from_heatmap
这个函数用于从热图中获取前k
个最大值的位置和相关信息。
-
参数解释:
scores
:目标热图,形状为[batch, num_classes, height, width]
。k
:目标数量,默认为 20。
-
执行过程:
- 首先,将热图展平为二维张量。
- 然后,使用
torch.topk
函数获取前k
个最大值及其索引。 - 根据索引计算出类别、纵坐标和横坐标。
六、根据索引收集特征函数:gather_feat
该函数用于根据索引从特征图中收集特征。
-
参数解释:
feat
:目标特征图。ind
:目标坐标索引。mask
:特征图的掩码,默认为 None。
-
执行过程:
- 首先,根据特征图的维度扩展索引。
- 然后,使用
feat.gather
函数根据索引收集特征。 - 如果有掩码,则根据掩码进一步处理收集到的特征。
七、转置并根据索引收集特征函数:transpose_and_gather_feat
这个函数先对特征图进行转置,然后根据索引收集特征。
-
参数解释:
feat
:目标特征图。ind
:目标坐标索引。
-
执行过程:
- 首先,对特征图进行转置和形状调整,使其更易于根据索引收集特征。
- 然后,调用
gather_feat
函数进行特征收集。