import torch.nn as nn
import math
import torch
class GateConcMechanism(nn.Module):
def __init__(self, hidden_size=None):
super(GateConcMechanism, self).__init__()
self.hidden_size = hidden_size
self.w1 = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
self.w2 = nn.Parameter(torch.Tensor(self.hidden_size, self.hidden_size))
self.bias = nn.Parameter(torch.Tensor(self.hidden_size))
self.reset_parameters()
def reset_parameters(self): # 作用
stdv1 = 1. / math.sqrt(self.w1.size(1))
stdv2 = 1. / math.sqrt(self.w2.size(1))
stdv = (stdv1 + stdv2) / 2.
self.w1.data.uniform_(-stdv1, stdv1)
self.w2.data.uniform_(-stdv2, stdv2)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input, hidden):
# input: hidden state from encoder;
# hidden: hidden state from key value memory network
# output = [gate * input; (1 - gate) * hidden]
gated = input.matmul(self.w1.t()) + hidden.matmul(self.w2.t()) + self.bias # input*w1 + hidden*w2 + bias
gate = torch.sigmoid(gated)
# output = torch.add(input.mul(gate), hidden.mul(1 - gate))
output = torch.cat([input.mul(gate), hidden.mul(1 - gate)],dim=-1)
return output
gate mechanism
于 2021-02-08 10:47:20 首次发布
该博客介绍了一个名为GateConcMechanism的门控融合机制,它是在深度学习模型中使用的一种结构。该机制通过将隐藏状态从编码器和键值内存网络的隐藏状态相结合,利用权重矩阵和偏置进行加权,并通过sigmoid激活函数生成门控信号。门控融合后的输出是输入和隐藏状态经过门控信号调整后的组合。这个机制有助于信息的选择性融合,提高模型的学习能力。

4708

被折叠的 条评论
为什么被折叠?



