参考博客:
博客一:Address class imbalance easily with Pytorch | by Mastafa Foufa | Analytics Vidhya | Medium
播客二:Address class imbalance easily with Pytorch Part 2 | by Mastafa Foufa | Towards Data Science
类不平衡
如论文所给出的结论,处理类不平衡的主要方法是过采样。过采样应被应用至完全消除类不平衡,而优化的欠采样系数取决于不平衡的程度。与一些经典的机器学习模型不同,过采样不会导致CNN网络过拟合。
假设数据集中包含两类:
c
l
a
s
s
1
class_1
class1和
c
l
a
s
s
2
class_2
class2,基于均匀分布,那么从
c
l
a
s
s
1
class_1
class1中随机采样得到的概率为
p
(
x
∈
c
l
a
s
s
i
)
=
#
{
c
l
a
s
s
i
}
#
{
t
r
a
i
n
}
=
N
c
l
a
s
s
i
N
t
r
a
i
n
p(x\in class_i)=\frac{\#\{class_i\}}{\#\{train\}}=\frac{N_{class_i}}{N_{train}}
p(x∈classi)=#{train}#{classi}=NtrainNclassi
但是,实际可能二分类中,某一类数量远大于另一类
N
c
l
a
s
s
1
≫
N
c
l
a
s
s
2
N_{class_1} \gg N_{class_2}
Nclass1≫Nclass2
也就是说
p
(
x
∈
c
l
a
s
s
1
)
≫
p
(
x
∈
c
l
a
s
s
2
)
p(x\in class_1) \gg p(x\in class_2)
p(x∈class1)≫p(x∈class2)
如果使用该数据集来训练模型,那么模型看到
c
l
a
s
s
1
class_1
class1的机会远大于
c
l
a
s
s
2
class_2
class2,导致模型无法从
c
l
a
s
s
2
class_2
class2中学到有用的特征。
因此,我们应首先进行人工增强数据,即增强小类数据,使得
p
(
x
∈
c
l
a
s
s
1
)
≈
p
(
x
∈
c
l
a
s
s
2
)
p(x\in class_1) \approx p(x\in class_2)
p(x∈class1)≈p(x∈class2)
使用 WeightedRandomSampler
博客一以二分类(比例为9:1),给出的源代码为
处理前和处理后的每个batch中的类分布
由该函数的Pytorch源代码可以看出,关键思想为,由控制参数的多项式分布中进行样本采样。
Pytorch使用多项式分布,其参数为weights, number of samples,以及采样是否放回的replacement.
Pytorch中引入的关键思想为基于多项式分布来从一组点中进行采样。每个样本被赋予采样的概率。该概率由其类权重参数来定义。
一个简单的例子
假设数据集具有以下形式,左边为100个样本,中间为类分布,右边为WeightedRandomSampler赋予的权重参数。蓝色为大类,红色为小类。
我们可以控制权重,对小类给予更大的权重:
W
1
≫
W
0
W_1 \gg W_0
W1≫W0
权重参数设置如下:
W
0
=
N
N
0
=
100
90
≈
1.11
W_0=\frac{N}{N_0}=\frac{100}{90} \approx 1.11
W0=N0N=90100≈1.11
W 1 = N N 1 = 100 10 = 10 W_1 = \frac{N}{N_1}=\frac{100}{10}=10 W1=N1N=10100=10
使用类似softmax函数的方法来正规化权重矢量来得到采样概率
p
(
c
0
)
=
90
W
0
(
10
W
0
+
90
W
1
)
≈
0.0056
p(c_0)=\frac{90W_0}{(10W_0+90W_1)}\approx0.0056
p(c0)=(10W0+90W1)90W0≈0.0056
p ( c 1 ) = 10 W 1 ( 10 W 0 + 90 W 1 ) = 0.05 p(c_1)=\frac{10W_1}{(10W_0+90W_1)}=0.05 p(c1)=(10W0+90W1)10W1=0.05
注意:我们必须在整个数据集上进行正规化。目的是为了矢量的元素和等于1
接下来,我们从数学上来证明,100次随机采样,可以从两个类中分别采样到50和50样本。
从类
c
1
c_1
c1中采样到的样本数为
E
[
c
1
]
=
∑
i
=
91
100
m
∗
p
(
c
1
)
=
∑
i
=
91
100
100
∗
0.05
=
50
E[c_1]=\sum_{i=91}^{100}{m*p(c_1)}=\sum_{i=91}^{100}{100*0.05}=50
E[c1]=i=91∑100m∗p(c1)=i=91∑100100∗0.05=50
从类
c
0
c_0
c0中采样到的样本数为
E
[
c
0
]
=
∑
i
=
1
90
m
∗
p
(
c
0
)
=
∑
i
=
1
90
100
∗
0.0056
≈
50.4
E[c_0]=\sum_{i=1}^{90}{m*p(c_0)}=\sum_{i=1}^{90}{100*0.0056} \approx 50.4
E[c0]=i=1∑90m∗p(c0)=i=1∑90100∗0.0056≈50.4