Background
Segmentation algorithms are frequently validated on the ADE20K dataset. This dataset is exhaustively annotated, i.e. every pixel in an image is labeled with a class, so background accounts for only a small fraction of the pixels, and training is therefore configured with ignore_index=255. In SenseTime's framework mmsegmentation, the reduce_zero_label=True parameter in ade20k.py exists precisely to implement this ignore index and skip the background:
```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0))]
```
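For intuition, here is a minimal sketch of the label remapping that reduce_zero_label=True is understood to perform (an illustrative toy example, not the mmsegmentation source): background class 0 is mapped to the ignore value 255, and every remaining class id is shifted down by one.

```python
import numpy as np

# Toy ground-truth label map; 0 is background, 1..3 are object classes.
gt_seg = np.array([[0, 1, 2],
                   [3, 0, 1]], dtype=np.uint8)

gt_seg[gt_seg == 0] = 255    # background becomes the ignore value
gt_seg = gt_seg - 1          # remaining class ids shift down by one
gt_seg[gt_seg == 254] = 255  # restore ignored pixels (255 - 1 = 254)
print(gt_seg)
# [[255   0   1]
#  [  2 255   0]]
```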
Analysis
The standard cross entropy is computed as:

$$l_j = -y_j\log\left(\frac{e^{f_{y_j}}}{\sum\limits_{i=1}^{c}e^{f_i}}\right)$$

where $y_j = 1$ for the target class (the minus sign makes the loss positive).
ignore_index is the label id to be ignored: when computing the cross entropy, samples whose target equals ignore_index are excluded.
For the following data:
```python
import torch

y_pre = torch.tensor([[-0.7795, -1.7373,  0.3856],
                      [ 0.4031, -0.5632, -0.2389],
                      [-2.3908,  0.7452,  0.7748]])
y_real = torch.tensor([0, 1, 2])
criterion = torch.nn.CrossEntropyLoss()
print(criterion(y_pre, y_real))
# tensor(1.2784)
```
This can be viewed as an example with batch=3 and C=3 classes. The standard cross entropy is computed as:
$$l_1 = -1 \cdot \log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right)$$
$$l_2 = -1 \cdot \log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right)$$
$$l_3 = -1 \cdot \log\left(\frac{e^{0.7748}}{e^{-2.3908}+e^{0.7452}+e^{0.7748}}\right)$$
$$l = \frac{l_1+l_2+l_3}{3} = 1.2784$$
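These three terms are easy to check with log_softmax. The snippet below is an added sanity check (reusing y_pre and y_real from above) that reproduces the per-sample losses and their mean:

```python
import torch.nn.functional as F

# log(e^{f_y} / sum_i e^{f_i}) for every class, then pick the target column
logp = F.log_softmax(y_pre, dim=1)
per_sample = -logp[torch.arange(3), y_real]
print(per_sample)         # approx. tensor([1.5239, 1.6117, 0.6996])
print(per_sample.mean())  # tensor(1.2784)
```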
- When the ignore_index parameter is used,
$$l_j = -y_j\log\left(\frac{e^{f_{y_j}}}{\sum\limits_{i=1}^{c}e^{f_i}}\right) \cdot \mathbb{1}\{y_j \neq ignore\_index\}$$
For example, with ignore_index=2 the computation skips every sample whose label id is 2, so only the samples labeled 0 and 1 contribute to the average; note that the softmax denominator still sums over all three classes:
$$l_1 = -1 \cdot \log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right)$$
$$l_2 = -1 \cdot \log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right)$$
$$l = \frac{l_1+l_2}{2} = 1.5678$$
```python
criterion = torch.nn.CrossEntropyLoss(ignore_index=2)
print(criterion(y_pre, y_real))
# tensor(1.5678)
```
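The same value can be reproduced by hand with a mask; this added check (reusing the tensors above) shows that only the averaging set shrinks while the softmax still covers all classes:

```python
import torch.nn.functional as F

mask = y_real != 2  # keep samples whose label is not the ignore index
per_sample = -F.log_softmax(y_pre, dim=1)[torch.arange(3), y_real]
print(per_sample[mask].mean())  # tensor(1.5678)
```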
- With a weight w,
$$l_j = -y_j\log\left(\frac{e^{f_{y_j}}}{\sum\limits_{i=1}^{c}e^{f_i}}\right) \cdot w[labels[j]]$$
When computing the average loss, the effective sample count is also weighted: each sample counts as w[labels[j]] instead of 1.
For example, with w=[0.3, 0.6, 0.1], the weighted cross entropy is computed as:
With $y_j = 1$:

$$l_1 = -1 \cdot \log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right) \cdot 0.3$$
$$l_2 = -1 \cdot \log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right) \cdot 0.6$$
$$l_3 = -1 \cdot \log\left(\frac{e^{0.7748}}{e^{-2.3908}+e^{0.7452}+e^{0.7748}}\right) \cdot 0.1$$
$$l = \frac{l_1+l_2+l_3}{1 \cdot 0.3+1 \cdot 0.6+1 \cdot 0.1} = 1.4941$$
```python
w = torch.tensor([0.3, 0.6, 0.1])
criterion = torch.nn.CrossEntropyLoss(weight=w)
print(criterion(y_pre, y_real))
# tensor(1.4941)
```
The weighted cross-entropy loss is mainly used to handle class imbalance.
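As an added sanity check (reusing y_pre, y_real, and w from above), the weighted mean can be reproduced by hand:

```python
import torch.nn.functional as F

per_sample = -F.log_softmax(y_pre, dim=1)[torch.arange(3), y_real]
wt = w[y_real]                             # per-sample weight, looked up by label
print((per_sample * wt).sum() / wt.sum())  # tensor(1.4941)
```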
- When ignore_index and the weight w are used together, the two cases above combine. For example, with ignore_index=2 and w=[0.3, 0.6, 0.1]:
$$l_1 = -1 \cdot \log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right) \cdot 0.3$$
$$l_2 = -1 \cdot \log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right) \cdot 0.6$$
$$l = \frac{l_1+l_2}{1 \cdot 0.3+1 \cdot 0.6} = 1.5824$$
```python
criterion = torch.nn.CrossEntropyLoss(ignore_index=2, weight=w)
print(criterion(y_pre, y_real))
# tensor(1.5824)
```
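By hand, the combination simply zeroes the weight of ignored samples in both the numerator and the denominator (an added check, reusing the tensors above):

```python
import torch.nn.functional as F

per_sample = -F.log_softmax(y_pre, dim=1)[torch.arange(3), y_real]
wt = w[y_real] * (y_real != 2)             # ignored samples get weight 0
print((per_sample * wt).sum() / wt.sum())  # tensor(1.5824)
```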
A hand-written cross-entropy that supports the weight and ignore_index parameters is shown below. Testing confirms it produces the same result as torch.nn.CrossEntropyLoss, which helps in understanding how the loss works:
```python
import torch

def cross_entropy(inp, tar, ig=None, w=None):
    # numerator of the average: sum of per-sample loss terms
    ss = 0
    for i, s in enumerate(tar):
        # log softmax of the target class of sample i (negated at the end)
        ts = torch.log(torch.exp(inp[i][s]) / torch.sum(inp[i].exp()))
        # a sample whose label equals the ignore index contributes nothing
        ts = 0 if s.item() == ig else ts
        if w is not None:
            ts = ts * w[s]  # scale the term by its class weight
        ss += ts
    # denominator of the average: each kept sample counts as w[label] (or 1)
    bi = torch.bincount(tar)
    ns = 0
    for i in range(torch.max(tar) + 1):
        if i == ig:
            continue
        if w is not None:
            ns += bi[i] * w[i]
        else:
            ns += bi[i]
    return -ss / ns

w = torch.tensor([0.3, 0.6, 0.1])
y_pre = torch.tensor([[-0.7795, -1.7373,  0.3856],
                      [ 0.4031, -0.5632, -0.2389],
                      [-2.3908,  0.7452,  0.7748]])
y_real = torch.tensor([0, 1, 2])
criterion = torch.nn.CrossEntropyLoss(ignore_index=2, weight=w)
print(cross_entropy(y_pre, y_real, ig=2, w=w))
print(criterion(y_pre, y_real))
"""
tensor(1.5824)
tensor(1.5824)
"""
```