论文:Visual Domain Adaptation with Manifold Embedded Distribution Alignment
地址:https://arxiv.org/abs/1807.07258
一、简介
在解决 unsupervised domain adaption 时主要出现两个问题:
- 恶化的特征转化: 特征对齐通常在原来的特征空间中实施,而且特征扭曲现象难以克服。另一方面,子空间学习不足以减少分布差异。
- 分布对齐没有测量的标准:之前的特征对齐方法仅仅将边缘分布和条件分布以视为同等重要的方式对齐,没有考虑到在真实应用中这两种方式的重要性可能不同。比如Figure1 所示,当两个域是非常不同(Figure 1(a) → 1(b)),此时对齐边缘分布可能更加重要。如果边缘分布是相近的(Figure 1(a) → 1©),此时对齐条件分布可能更加重要。
Manifold Embedded Distribution Alignment (MEDA) 方法解决了这两个问题,主要通过具有结构化危险最小化(structural risk minimization )的拉格拉斯流行(Grassmann manifold) 学习一个域不变的分类器,同时通过考虑条件分布和边缘分布的重要性来进行动态分布对齐。
二、细节
Main Idea
MEDA方法主要有两个基本步骤,分别解决上述提出的两个问题。
- 使用流形特征学习(manifold feature learning) 去解决恶化的特征表示问题。
- 使用动态特征对齐去定量的考虑边缘分布对齐和条件分布对齐的重要性。
根据结构危险最小化原则(SRM),最终可以把学习到域不变分类器
f
f
f 为:
f
=
arg
min
f
∈
∑
i
=
1
n
H
K
l
(
f
(
g
(
x
i
)
)
,
y
i
)
+
η
∣
∣
f
∣
∣
K
2
+
λ
D
f
‾
(
D
s
,
D
t
)
+
ρ
R
f
(
D
s
,
D
t
)
f = {\arg\min}_{f\in\sum_{i=1}^n\mathcal H_K} l(f(g(\mathbf x_i)), \mathbf y_i) + \eta||f||_K^2 + \lambda\overline {D_f}(\mathcal D_s, \mathcal D_t) + \rho R_f(\mathcal D_s, \mathcal D_t)
f=argminf∈∑i=1nHKl(f(g(xi)),yi)+η∣∣f∣∣K2+λDf(Ds,Dt)+ρRf(Ds,Dt)
∣ ∣ f ∣ ∣ K 2 ||f||_K^2 ∣∣f∣∣K2 是 f f f 的二次范式,后面两项分别为动态分布对齐和拉普拉斯正则项。其中 η 、 λ 、 ρ \eta、\lambda、\rho η、λ、ρ 是相应的参数。
Manifold Feature Learning
作者使用测地线流式核方法(GFK)来进行Manifold Feature learning。经过转换后的特征可以表示为
z
=
g
(
x
)
=
Φ
(
t
)
T
x
\mathbf {z = g(x) = \Phi(t)^T x}
z=g(x)=Φ(t)Tx,转换后的特征
z
i
z_i
zi 和
z
j
z_j
zj 的内积形成一个半正定测地线流式核:
<
z
i
,
z
j
>
=
∫
0
1
(
Φ
(
t
)
T
x
i
)
T
(
Φ
(
t
)
x
j
)
d
t
=
x
i
t
G
X
j
<\mathbf z_i, \mathbf z_j> = \mathbf{\int_0^1{(\Phi(t)^Tx_i)^T(\Phi(t)x_j) dt = x_i^tGX_j}}
<zi,zj>=∫01(Φ(t)Txi)T(Φ(t)xj)dt=xitGXj
最后得到
z
=
g
(
x
)
=
G
x
\mathbf{z = g(x) = \sqrt{G}x}
z=g(x)=Gx。
Dynamic Distribution Alignment
这里首先定义
P
P
P为边缘分布,
Q
Q
Q为条件分布。作者提出使用Adaption factor
μ
\mu
μ 来平衡两个分布的重要性,因此动态分布对齐
D
f
‾
\overline{D_f}
Df 可以被定义为:
D
f
‾
(
D
s
,
D
t
)
=
(
1
−
μ
)
D
f
(
P
s
,
P
t
)
+
μ
∑
c
=
1
C
D
f
(
c
)
(
Q
s
,
Q
t
)
\overline{D_f}(\mathcal D_s, \mathcal D_t) = (1 - \mu)D_f(P_s, P_t) + \mu\sum_{c=1}^CD_f^{(c)}(Q_s, Q_t)
Df(Ds,Dt)=(1−μ)Df(Ps,Pt)+μc=1∑CDf(c)(Qs,Qt)
其中
μ
∈
[
0
,
1
]
\mu\in[0, 1]
μ∈[0,1]就是 adaptive factor ,
c
∈
{
1
,
⋅
⋅
⋅
,
C
}
c\in\{1, ···,C\}
c∈{1,⋅⋅⋅,C}是类别指示符号。
D
f
D_f
Df可以用最大化均值差异(maximum mean discrepancy,MMD)来度量,因此,改写
D
f
‾
\overline{D_f}
Df 为:
D
f
‾
(
D
s
,
D
t
)
=
(
1
−
μ
)
∣
∣
E
[
f
(
z
s
)
]
−
E
[
f
(
z
t
)
]
∣
∣
H
K
2
+
μ
∑
c
=
1
C
∣
∣
E
[
f
(
z
s
(
c
)
)
]
−
E
[
f
(
z
t
(
c
)
)
]
∣
∣
H
K
2
\overline{D_f}(\mathcal D_s, \mathcal D_t) = (1-\mu){||\mathbb E[f(\mathbf z_s)] - \mathbb E[f(\mathbf z_t)] ||}_{\mathcal H_K}^2 + \mu\sum_{c=1}^C{||\mathbb E[f(\mathbf z_s^{(c)})] - \mathbb E[f(\mathbf z_t^{(c)})] ||}_{\mathcal H_K}^2
Df(Ds,Dt)=(1−μ)∣∣E[f(zs)]−E[f(zt)]∣∣HK2+μc=1∑C∣∣E[f(zs(c))]−E[f(zt(c))]∣∣HK2
值得注意的是,这里的
D
t
\mathcal D_t
Dt 是没有label的,所以无法直接评估条件分布
Q
t
=
Q
t
(
y
t
∣
z
t
)
Q_t = Q_t(\mathbf y_t|\mathbf z_t)
Qt=Qt(yt∣zt),使用类条件分布
Q
t
(
z
t
∣
y
t
)
Q_t(\mathbf z_t|\mathbf y_t)
Qt(zt∣yt)去近似
Q
t
Q_t
Qt,为了得到
Q
t
(
z
t
∣
y
t
)
Q_t(\mathbf z_t|\mathbf y_t)
Qt(zt∣yt),作者首先使用一个在
D
s
\mathcal D_s
Ds上训练的基本分类器来预测
D
t
\mathcal D_t
Dt的软标签。尽管这可能不值得信赖,但是可以迭代的去从新修正这个分类器。而且仅在第一次迭代的时候使用这个基础分类器,之后都使用MEDA来自动重新修订
D
t
D_t
Dt的标签。
为了定量的得到适应参数
μ
\mu
μ ,作者使用
A
\mathcal A
A-distance 去作为基本的测量方式,
A
\mathcal A
A-distance 被定义为构建一个线性分类器去判断两个域的错误率。
ϵ
(
h
)
\epsilon(h)
ϵ(h) 定义为一个线性分类器 h 判别两个域
D
s
\mathcal D_s
Ds 和
D
t
\mathcal D_t
Dt 的错误率。
A
\mathcal A
A-distance 公式化定义如下:
d
A
(
D
s
,
D
t
)
=
2
(
1
−
2
ϵ
(
h
)
)
d_A(\mathcal D_s, \mathcal D_t) = 2(1 - 2\epsilon(h))
dA(Ds,Dt)=2(1−2ϵ(h))
所以
μ
\mu
μ的值可以被估计为:
μ
^
≈
1
−
d
M
d
M
+
∑
c
=
1
C
d
c
\hat \mu \ \approx 1 - {d_M\over{d_M + \sum_{c=1}^Cd_c}}
μ^ ≈1−dM+∑c=1CdcdM
Learning Classifier f f f
引入平方loss
l
2
l_2
l2,
f
f
f可以重新写成:
f
=
arg
min
f
∈
H
K
∑
i
=
1
n
(
y
i
−
f
(
z
i
)
)
2
+
η
∣
∣
f
∣
∣
K
2
+
λ
D
f
‾
(
D
s
,
D
t
)
+
ρ
R
f
(
D
s
,
D
t
)
f = {\arg\min}_{f\in\mathcal H_K}\sum_{i=1}^n(y_i - f(\mathbf z_i))^2 + \eta||f||_K^2 + \lambda\overline {D_f}(\mathcal D_s, \mathcal D_t) + \rho R_f(\mathcal D_s, \mathcal D_t)
f=argminf∈HKi=1∑n(yi−f(zi))2+η∣∣f∣∣K2+λDf(Ds,Dt)+ρRf(Ds,Dt)
接下来,作者详细的讨论了
f
f
f 中每一项的细节,包括
f
、
D
f
‾
(
D
s
,
D
t
)
、
R
f
(
D
s
,
D
t
)
f、\overline {D_f}(\mathcal D_s, \mathcal D_t)、R_f(\mathcal D_s, \mathcal D_t)
f、Df(Ds,Dt)、Rf(Ds,Dt) 的准确表达形式:
f
=
arg
min
f
∈
H
K
∣
∣
(
Y
−
β
T
K
)
A
∣
∣
F
2
−
η
t
r
(
β
T
K
β
)
+
t
r
(
β
T
K
(
λ
M
+
ρ
L
)
K
β
)
f = {\arg\min}_{f\in\mathcal H_K}||\mathbf {(Y-\beta^TK)A}||_F^2 - \eta tr(\beta^TK\beta) + tr(\mathbf{\beta^TK(\lambda M + \rho L)K\beta})
f=argminf∈HK∣∣(Y−βTK)A∣∣F2−ηtr(βTKβ)+tr(βTK(λM+ρL)Kβ)
然后设置
∂
f
/
∂
β
=
0
\partial f/\partial \beta = 0
∂f/∂β=0, 最后获得求解答案:
β
⋆
=
(
(
A
+
λ
M
+
ρ
L
)
K
+
η
I
)
−
1
A
Y
T
\beta^\star = \mathbf{((A + \lambda M + \rho L)K + \eta I)^{-1}AY^T}
β⋆=((A+λM+ρL)K+ηI)−1AYT
三、实验
作者使用了七个公开数据集:Office+Caltech10, USPS + MNIST, ImageNet + VOC2007, and Office-31,这些数据集都可以在这里看到 transferlearning dataset。(数据准备的详细细节见paper)。作者与当前几个state-of-the-art 传统和深度域适应方法。相关方法在论文中有简单的描述。为了公平比较,作者使用同样的准则去得到特征。准确度:
A
c
c
u
r
a
c
y
=
∣
x
:
x
∈
D
t
∧
y
^
(
x
)
=
y
(
x
)
∣
∣
x
:
x
∈
D
s
∣
Accuracy = {{|\mathbf {x:x \in \mathcal D_t \land \hat y(x)=y(x)}|} \over \mathbf {|x:x \in \mathcal D_s|}}
Accuracy=∣x:x∈Ds∣∣x:x∈Dt∧y^(x)=y(x)∣
其中
y
(
x
)
y(\mathbf x)
y(x) 和
y
^
(
x
)
\hat y(\mathbf x)
y^(x) 是分别是对于目标域的ground truth label 和 predicted labels。
所有实验分类结果准确度分别展示在Tables 2,3,4。可以得到如下结果:
- MEDA的性能超过大多数传统或深度域适应方法(21/28个任务)。MEDA在28项任务中的平均分类准确率为73.2%。 与最佳基线方法JGSA(69.7%)相比,平均性能改善为3.5%,显示出显着的平均误差降低11.6%。 请注意,由于空间限制,Office-31数据集上的结果位于补充文件2中,并且观察结果相同。由于这些结果是从广泛的图像数据集中获得的,因此它表明MEDA能够显着降低域适应问题中的分布差异。
- 可以看到所有的分布对齐方法(TCA, JDA, ARTL, TJM, JGSA, and DMM) 和 子空间学习方法(GFK, CORAL, and SCA) 性能比MEDA方法效果差。每一种方法都有它们的限制(恶化的特征转化 或 分布对齐没有测量的标准),MEDA解决了这两个问题。
- MEDA的性能也超过了深度方法(AlexNet, DDC, DAN, DCORAL, and DUCDA)。深度学习方法需要调整很多的超参数,而MEDA仅仅涉及几个参数。
除了MEDA与state-of-the-art方法的性能,还分析了MEDA中流行特征学习、动态分布对齐、每一个组件评估,最后还分析了参数敏感性、收敛和时间复杂度等,限于篇幅这里不在赘述。
四、总结
这不愧为ACMMM oral的论文,内容十分充足。论文提出了一个新奇的方法,解决unsupervised domain adaption中之前工作没有解决的两个核心问题1)恶化的特征转化。2)分布对齐没有测量的标准。分别引入流形特征学习和动态分布对齐方法分布解决这两个问题。在实验阶段,作者进行了全面而丰富的实验,使用了7个公共数据集,比较了很多 state-of-the-art 传统的和深度的域适应方法,证实了MEDA方法卓越的性能。之后还对MEDA每一个组件进行了实验分析,分析参数敏感性和MEDA时间复杂度。