Reference paper and links:
Paper: Trusted Multi-View Classification
Walkthrough link
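1 Parameters
The model needs the number of classes. Because this is multi-view (multimodal) classification, it also needs the number of views and the layer dimensions of each view's classifier.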
import torch
import torch.nn as nn

class TMC(nn.Module):
    def __init__(self, classes, views, classifier_dims, lambda_epochs=1):
        """
        :param classes: Number of classification categories
        :param views: Number of views
        :param classifier_dims: Per-view layer dimensions; classifier_dims[i] is passed to the Classifier of view i
        :param lambda_epochs: KL divergence annealing epochs during training
        """
        super(TMC, self).__init__()
        self.views = views
        self.classes = classes
        self.lambda_epochs = lambda_epochs
        # One evidence network per view
        self.Classifiers = nn.ModuleList([Classifier(classifier_dims[i], self.classes) for i in range(self.views)])
The Classifier class used above is defined as follows:
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, classifier_dims, classes):
        super(Classifier, self).__init__()
        self.num_layers = len(classifier_dims)
        self.fc = nn.ModuleList()
        # Hidden layers: classifier_dims[i] -> classifier_dims[i+1]
        for i in range(self.num_layers-1):
            self.fc.append(nn.Linear(classifier_dims[i], classifier_dims[i+1]))
        # Output layer (one score per class), then Softplus so the evidence is non-negative
        self.fc.append(nn.Linear(classifier_dims[self.num_layers-1], classes))
        self.fc.append(nn.Softplus())

    def forward(self, x):
        h = self.fc[0](x)
        for i in range(1, len(self.fc)):
            h = self.fc[i](h)
        return h
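The following compact sketch shows the layer stack Classifier builds for a given classifier_dims. It reproduces the same layout with nn.Sequential. The helper name build_evidence_net and the sizes are hypothetical: a 64-dimensional view, one 32-unit hidden layer, and 10 classes.

```python
import torch
import torch.nn as nn

def build_evidence_net(classifier_dims, classes):
    """Hypothetical helper with the same layer layout as Classifier.

    classifier_dims = [d0, ..., dn] gives Linear(d0, d1), ..., Linear(d_{n-1}, dn),
    followed by Linear(dn, classes) and a final Softplus.
    """
    layers = [nn.Linear(classifier_dims[i], classifier_dims[i + 1])
              for i in range(len(classifier_dims) - 1)]
    layers.append(nn.Linear(classifier_dims[-1], classes))
    layers.append(nn.Softplus())  # keeps the evidence non-negative
    return nn.Sequential(*layers)

torch.manual_seed(0)
net = build_evidence_net([64, 32], classes=10)  # assumed sizes, for illustration only
x = torch.randn(8, 64)                           # batch of 8 samples from one view
evidence = net(x)
print([type(m).__name__ for m in net])  # ['Linear', 'Linear', 'Softplus']
print(evidence.shape)                   # torch.Size([8, 10])
print(bool((evidence >= 0).all()))      # True
```

Classifier places no activation between the hidden Linear layers, and the sketch keeps that layout. As a result, the stack before Softplus reduces to a single affine map.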
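2 Forward pass
The forward pass works as follows. Each view's input goes through that view's Classifier network, whose final Softplus layer produces non-negative evidence. The evidence of each view is stored in a dictionary keyed by view index.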
def infer(self, input):
    """
    :param input: Multi-view data
    :return: evidence of every view
    """
    evidence = dict()
    for v_num in range(self.views):
        evidence[v_num] = self.Classifiers[v_num](input[v_num])
    return evidence
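Next, each view's evidence e is turned into a subjective opinion. The Dirichlet parameters are α_k = e_k + 1 and the strength is S = Σ_k α_k. The belief masses are b_k = e_k / S, and the uncertainty is u = K / S, where K is the number of classes. Dempster-Shafer (DS) combination is then used to fuse the per-view opinions.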
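Here is a minimal sketch of computing S, b and u for one view. It assumes evidence of shape (N, K) and α = e + 1, matching E = alpha - 1 in DS_Combin_two. The function name opinion_from_evidence and the example values are hypothetical.

```python
import torch

def opinion_from_evidence(evidence):
    """Hypothetical helper: opinion (b, u) and strength S from evidence of shape (N, K)."""
    K = evidence.shape[1]
    alpha = evidence + 1                   # Dirichlet parameters
    S = alpha.sum(dim=1, keepdim=True)     # Dirichlet strength, shape (N, 1)
    b = evidence / S                       # belief mass per class
    u = K / S                              # uncertainty mass, shape (N, 1)
    return b, u, S

e = torch.tensor([[4.0, 1.0, 0.0],   # a sample with evidence for class 0
                  [0.0, 0.0, 0.0]])  # a sample with no evidence at all
b, u, S = opinion_from_evidence(e)
print(S.squeeze(1))                   # tensor([8., 3.])
print(u.squeeze(1))                   # tensor([0.3750, 1.0000])
print(b.sum(dim=1) + u.squeeze(1))    # tensor([1., 1.])
```

The belief masses and the uncertainty always sum to 1. With no evidence, all mass goes to u.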

The code is as follows:
2.1 Evaluating the formulas
The code below also computes the loss.
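The loss that ce_loss below computes for one sample can be read off the code as follows. Here y is the one-hot label, α holds the Dirichlet parameters, S = Σ_j α_j, ψ is the digamma function, t is global_step and T is annealing_step. This assumes the KL helper measures the divergence from Dir(α̃) to the uniform Dirichlet Dir(1).

```latex
\mathcal{L}(\boldsymbol{\alpha}) =
\underbrace{\sum_{j=1}^{K} y_j\bigl(\psi(S) - \psi(\alpha_j)\bigr)}_{A}
+ \underbrace{\lambda_t\,\mathrm{KL}\!\left[D(\mathbf{p}\mid\tilde{\boldsymbol{\alpha}})\,\middle\|\,D(\mathbf{p}\mid\mathbf{1})\right]}_{B},
\qquad
\tilde{\boldsymbol{\alpha}} = \mathbf{y} + (\mathbf{1}-\mathbf{y})\odot\boldsymbol{\alpha},
\qquad
\lambda_t = \min\!\left(1, \tfrac{t}{T}\right)
```

The term α̃ corresponds to alp = E * (1 - label) + 1: it resets the true class's parameter to 1 and keeps the others. The KL term therefore only penalizes evidence for the wrong classes.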
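import torch
import torch.nn.functional as F

def ce_loss(p, alpha, c, global_step, annealing_step):
    # p: integer labels of shape (N,); alpha: Dirichlet parameters of shape (N, c)
    S = torch.sum(alpha, dim=1, keepdim=True)
    E = alpha - 1
    label = F.one_hot(p, num_classes=c)
    # Expected cross-entropy under Dir(alpha): psi(S) - psi(alpha_y)
    A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
    # KL weight grows linearly from 0 to 1 over annealing_step steps
    annealing_coef = min(1, global_step / annealing_step)
    # Remove the evidence of the true class before applying the KL penalty
    alp = E * (1 - label) + 1
    B = annealing_coef * KL(alp, c)  # the KL helper is not shown in this post
    return (A + B)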
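ce_loss calls a KL helper that this post does not show. Below is a sketch consistent with how it is called, assuming KL(alp, c) returns, per row, the KL divergence from Dir(alp) to the uniform Dir(1, ..., 1). That is the standard evidential-learning regularizer. This version is written for illustration and may differ from the repository's, for example in device handling.

```python
import torch

def KL(alpha, c):
    """KL[Dir(alpha) || Dir(1, ..., 1)] per row; alpha has shape (N, c).

    Illustrative sketch: builds the uniform Dirichlet on alpha's device and dtype.
    """
    beta = torch.ones((1, c), dtype=alpha.dtype, device=alpha.device)  # uniform Dirichlet
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    # log normalizers of Dir(alpha) and Dir(beta)
    lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)
    return torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni

uniform = torch.ones(2, 3)
print(torch.allclose(KL(uniform, 3), torch.zeros(2, 1), atol=1e-6))  # True
peaked = torch.tensor([[10.0, 1.0, 1.0]])
print(KL(peaked, 3).item() > 0)                                      # True
```

The divergence is zero when α̃ is already uniform. In that case the sample carries no misleading evidence, so the KL term adds nothing.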
2.2 DS combination rule
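For two views, DS_Combin_two below implements a reduced Dempster's rule. The views have belief masses b^1, b^2 and uncertainties u^1, u^2, which are b[0], b[1], u[0], u[1] in the code. The rule, read off the code, is:

```latex
C = \sum_{i \neq j} b_i^{1}\, b_j^{2}, \qquad
b_k = \frac{b_k^{1} b_k^{2} + b_k^{1} u^{2} + b_k^{2} u^{1}}{1 - C}, \qquad
u = \frac{u^{1} u^{2}}{1 - C}, \qquad
S = \frac{K}{u}, \qquad
\alpha_k = b_k S + 1
```

C measures the conflict between the views, i.e. the mass they put on different classes. Dividing by 1 - C renormalizes the remaining mass.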
def DS_Combin(self, alpha):
    """
    :param alpha: All Dirichlet distribution parameters.
    :return: Combined Dirichlet distribution parameters.
    """
    def DS_Combin_two(alpha1, alpha2):
        """
        :param alpha1: Dirichlet distribution parameters of view 1
        :param alpha2: Dirichlet distribution parameters of view 2
        :return: Combined Dirichlet distribution parameters
        """
        alpha = dict()
        alpha[0], alpha[1] = alpha1, alpha2
        b, S, E, u = dict(), dict(), dict(), dict()
        for v in range(2):
            S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
            E[v] = alpha[v]-1
            b[v] = E[v]/(S[v].expand(E[v].shape))
            u[v] = self.classes/S[v]

        # outer product b^0 (b^1)^T, shape (N, K, K)
        bb = torch.bmm(b[0].view(-1, self.classes, 1), b[1].view(-1, 1, self.classes))
        # b^0 * u^1
        uv1_expand = u[1].expand(b[0].shape)
        bu = torch.mul(b[0], uv1_expand)
        # b^1 * u^0
        uv_expand = u[0].expand(b[0].shape)
        ub = torch.mul(b[1], uv_expand)
        # calculate C: conflict = off-diagonal sum of bb, i.e. sum over i != j of b^0_i * b^1_j
        bb_sum = torch.sum(bb, dim=(1, 2))
        bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
        C = bb_sum - bb_diag

        # calculate b^a
        b_a = (torch.mul(b[0], b[1]) + bu + ub)/((1-C).view(-1, 1).expand(b[0].shape))
        # calculate u^a
        u_a = torch.mul(u[0], u[1])/((1-C).view(-1, 1).expand(u[0].shape))

        # calculate new S
        S_a = self.classes / u_a
        # calculate new e_k
        e_a = torch.mul(b_a, S_a.expand(b_a.shape))
        alpha_a = e_a + 1
        return alpha_a

    # Fold the two-view rule over all views from left to right
    for v in range(len(alpha)-1):
        if v==0:
            alpha_a = DS_Combin_two(alpha[0], alpha[1])
        else:
            alpha_a = DS_Combin_two(alpha_a, alpha[v+1])
    return alpha_a
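To check the combination numerically, here is a standalone re-statement of DS_Combin_two without self. It computes the conflict C as (Σb¹)(Σb²) − Σ b¹_k b²_k instead of through torch.bmm. This is the same off-diagonal sum written differently. The names ds_combine_two and fuse_views and the example values are hypothetical.

```python
import functools
import torch

def ds_combine_two(alpha1, alpha2):
    """Reduced Dempster's rule for two Dirichlet opinions, alpha of shape (N, K)."""
    K = alpha1.shape[1]
    S1, S2 = alpha1.sum(1, keepdim=True), alpha2.sum(1, keepdim=True)
    b1, b2 = (alpha1 - 1) / S1, (alpha2 - 1) / S2
    u1, u2 = K / S1, K / S2
    # conflict: total product mass on pairs of different classes (i != j)
    C = b1.sum(1, keepdim=True) * b2.sum(1, keepdim=True) - (b1 * b2).sum(1, keepdim=True)
    b = (b1 * b2 + b1 * u2 + b2 * u1) / (1 - C)
    u = u1 * u2 / (1 - C)
    S = K / u
    return b * S + 1

def fuse_views(alphas):
    """Fold the two-view rule over a list of views left to right, like DS_Combin."""
    return functools.reduce(ds_combine_two, alphas)

a1 = torch.tensor([[5.0, 2.0, 1.0]])   # view 1: evidence mostly for class 0
a2 = torch.tensor([[4.0, 1.0, 1.0]])   # view 2: evidence only for class 0
vacuous = torch.ones(1, 3)             # a view with no evidence (u = 1)

print(torch.allclose(ds_combine_two(a1, vacuous), a1))  # True: a vacuous view changes nothing
fused = ds_combine_two(a1, a2)
print(fused)                           # approximately tensor([[12., 2., 1.]])
print(torch.allclose(fuse_views([a1, a2, vacuous]), fused, atol=1e-5))  # True
```

In this example the views have uncertainties 3/8 and 1/2. Because they agree, the fused uncertainty drops to 3/15 = 0.2, lower than either view's.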
3 Obtaining M through learning

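Understanding and implementing multi-view classification with DS combination
This post walks through the Trusted Multi-View Classification model: its input parameters, the initialization of the Classifier class, and the forward pass. The model classifies multi-view (multimodal) data and fuses the information from the different views with Dempster-Shafer combination. During training, the KL divergence term of the loss is annealed. The post also explains how the loss is computed and how the DS combination is implemented.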