
This post gives a top-down analysis of DSANet from the source-code perspective:
Overall structure
Three parallel branches: a global self-attention module, a local self-attention module, and a linear autoregressive (AR) module. The outputs of the two attention branches are concatenated and linearly projected, then summed with the AR forecast.
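Written out (the symbols below are mine for readability, not the paper's notation), the forward pass computes

$$
\hat{y} \;=\; W_o\,\big[\,h^{G}\,;\,h^{L}\,\big] \;+\; \hat{y}^{AR},
$$

where $h^{G}$ and $h^{L}$ are the per-series features from the global and local attention branches, $[\cdot\,;\,\cdot]$ denotes concatenation, $W_o$ is the final linear layer (`W_output1` in the code), and $\hat{y}^{AR}$ is the linear autoregressive forecast.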
Parameter meanings
| Parameter | Meaning |
|---|---|
| window (int) | length of the input window |
| n_multiv (int) | number of univariate time series |
| n_kernels (int) | number of convolution channels |
| w_kernel (int) | initial number of input channels (default: 1) |
| local (int) | width of the 1-D convolution kernel in the local module |
| d_k (int) | dimension of the attention keys, d_model / n_head |
| d_v (int) | dimension of the attention values, d_model / n_head |
| d_model (int) | output dimension of each attention layer |
| d_inner (int) | inner-layer dimension of the position-wise feed-forward network |
| n_layers (int) | number of layers in the encoder |
| n_head (int) | number of attention heads |
| drop_prob (float) | dropout probability |
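For concreteness, here is a hypothetical set of hyperparameter values shaped the way the model consumes them; the numbers are illustrative and not taken from the repo's argparser:

```python
from argparse import Namespace

# Illustrative values only; the real defaults live in the repo's argparser.
hparams = Namespace(
    window=64,      # input window length
    n_multiv=32,    # number of univariate series
    n_kernels=32,   # convolution channels
    w_kernel=1,     # initial number of input channels
    local=3,        # 1-D conv kernel width in the local module
    n_head=8,       # attention heads
    d_model=512,
    d_k=64,         # = d_model / n_head
    d_v=64,         # = d_model / n_head
    d_inner=2048,   # FFN inner dimension
    n_layers=6,     # encoder layers
    drop_prob=0.1,
)
```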
```python
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule

# Single_Global_SelfAttn_Module, Single_Local_SelfAttn_Module and AR
# are the three branch modules, defined elsewhere in the repo.


class DSANet(LightningModule):

    def __init__(self, hparams):
        """
        Pass in the parsed HyperOptArgumentParser to the model.
        """
        super(DSANet, self).__init__()
        ...  # init variables (copy window, n_multiv, ... from hparams onto self)

        # build model
        self.__build_model()

    def __build_model(self):
        """
        Lay out the model.
        """
        self.sgsf = Single_Global_SelfAttn_Module(
            window=self.window, n_multiv=self.n_multiv, n_kernels=self.n_kernels,
            w_kernel=self.w_kernel, d_k=self.d_k, d_v=self.d_v, d_model=self.d_model,
            d_inner=self.d_inner, n_layers=self.n_layers, n_head=self.n_head,
            drop_prob=self.drop_prob)

        self.slsf = Single_Local_SelfAttn_Module(
            window=self.window, local=self.local, n_multiv=self.n_multiv,
            n_kernels=self.n_kernels, w_kernel=self.w_kernel, d_k=self.d_k,
            d_v=self.d_v, d_model=self.d_model, d_inner=self.d_inner,
            n_layers=self.n_layers, n_head=self.n_head, drop_prob=self.drop_prob)

        self.ar = AR(window=self.window)
        self.W_output1 = nn.Linear(2 * self.n_kernels, 1)
        self.dropout = nn.Dropout(p=self.drop_prob)
        self.active_func = nn.Tanh()  # defined but never applied in forward()

    def forward(self, x):
        """
        :param x: [batch, window, n_multiv]
        :return: output: [batch, 1, n_multiv]
        """
        sgsf_output, *_ = self.sgsf(x)  # [batch, n_multiv, n_kernels]
        slsf_output, *_ = self.slsf(x)  # [batch, n_multiv, n_kernels]
        sf_output = torch.cat((sgsf_output, slsf_output), 2)  # [batch, n_multiv, 2*n_kernels]
        sf_output = self.dropout(sf_output)
        sf_output = self.W_output1(sf_output)  # [batch, n_multiv, 1]
        sf_output = torch.transpose(sf_output, 1, 2)  # [batch, 1, n_multiv]

        ar_output = self.ar(x)  # [batch, 1, n_multiv]

        # final forecast: attention branches plus the linear AR branch
        output = sf_output + ar_output
        return output
```
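The `AR` branch is only constructed above, not shown. Given its call signature (`AR(window=self.window)`) and the shape comments (`[batch, window, n_multiv]` in, `[batch, 1, n_multiv]` out), a minimal sketch is a single linear map over the time axis, shared across series; the body below is an assumption, not the repo's code:

```python
import torch
import torch.nn as nn


class AR(nn.Module):
    """Sketch of the linear autoregressive branch: one weight per time
    step, shared across all n_multiv series (assumed implementation)."""

    def __init__(self, window):
        super(AR, self).__init__()
        self.linear = nn.Linear(window, 1)

    def forward(self, x):
        # x: [batch, window, n_multiv]
        x = torch.transpose(x, 1, 2)   # [batch, n_multiv, window]
        x = self.linear(x)             # [batch, n_multiv, 1]
        x = torch.transpose(x, 1, 2)   # [batch, 1, n_multiv]
        return x
```

As in LSTNet, this linear path keeps the final forecast sensitive to the scale of the raw inputs, which the nonlinear attention branches alone tend to lose.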
Global self-attention module
```python
class Single_Global_SelfAttn_Module(nn.Module):

    def __init__(
            self,
            window, n_multiv, n_kernels, w_kernel,
            d_k, d_v, d_model, d_inner, n_layers, n_head, drop_prob):
        ...
```
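The rest of this module is not included in the excerpt. Below is a minimal sketch of what it plausibly does, following the paper's description: a Conv2d whose kernel spans the full window extracts one global temporal feature per filter and per series, and a stack of self-attention encoder layers then models dependencies across the series. Everything beyond the constructor signature is an assumption, with `nn.TransformerEncoderLayer` standing in for the paper's encoder layer:

```python
import torch.nn as nn
import torch.nn.functional as F


class GlobalSelfAttnSketch(nn.Module):
    """Hypothetical stand-in for Single_Global_SelfAttn_Module."""

    def __init__(self, window, n_multiv, n_kernels, w_kernel,
                 d_k, d_v, d_model, d_inner, n_layers, n_head, drop_prob):
        super().__init__()
        # A kernel of height `window` collapses the whole input window of
        # each series into a single value per filter: a *global* feature.
        self.conv = nn.Conv2d(w_kernel, n_kernels, kernel_size=(window, 1))
        self.in_linear = nn.Linear(n_kernels, d_model)
        self.out_linear = nn.Linear(d_model, n_kernels)
        # Stand-in encoder layer (multi-head self-attention + position-wise
        # FFN); it derives d_k = d_v = d_model // n_head internally, so the
        # explicit d_k / d_v arguments go unused in this sketch.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_head, dim_feedforward=d_inner,
                                       dropout=drop_prob, batch_first=True)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # x: [batch, window, n_multiv]; assumes w_kernel == 1 (the default)
        feats = F.relu(self.conv(x.unsqueeze(1)))  # [batch, n_kernels, 1, n_multiv]
        feats = self.dropout(feats.squeeze(2))     # [batch, n_kernels, n_multiv]
        feats = feats.transpose(1, 2)              # [batch, n_multiv, n_kernels]
        enc = self.in_linear(feats)                # [batch, n_multiv, d_model]
        # Attention runs over the series axis, i.e. it models dependencies
        # *between* the n_multiv series, not between time steps.
        for layer in self.layers:
            enc = layer(enc)
        # Returned as a 1-tuple so `out, *_ = module(x)` unpacking works.
        return (self.out_linear(enc),)             # [batch, n_multiv, n_kernels]
```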
