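Note: this article introduces the FEDformer model for time series forecasting, drawing on both the paper and the code (it covers only the Fourier-based frequency enhancement).
Paper information
- Model: FEDformer (Frequency Enhanced Decomposed Transformer)
- Keywords: frequency-domain enhancement; frequency enhanced attention
- Authors: Tian Zhou, Ziqing Ma, Qingsong Wen, Xue Wang, Liang Sun, Rong Jin
- Affiliation: Alibaba
- Publication: ICML 2022
- Link: FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting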
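Preface
- Time-domain and frequency-domain information
Because time series data is inherently complex and dynamic, the information contained in the time domain is often sparse and scattered. Capturing and exploiting it usually requires complex methods and models, such as the Transformer architecture, which raises challenges for efficiency and scalability.
In contrast, the frequency-domain representation of a time series describes its underlying information more concisely and compactly: most time series tend to have a sparse representation in a well-known basis such as the Fourier basis.
Scalability: by one definition, the ability to scale a system's performance horizontally by adding machines linearly, so that it copes with growth in data volume, traffic, and complexity. Here it can be read as follows: longer-range temporal dependencies cannot be handled simply by linearly increasing model complexity (a larger d_model, more layers, and so on).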
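- Point-wise attention
Self-attention performs point-wise prediction: the prediction at each time step (sequence point) is made independently, so the model may fail to preserve the global properties and statistics of the series as a whole. Forecasting directly with a Transformer treats each time step as a token, as in NLP, and outputs the data point by point (although the computation runs in parallel).
- Compact representation of time series in the frequency domain
Keeping a compact representation of the series with a small number of selected Fourier components makes the Transformer computation efficient, which is essential for modeling long sequences.
Simply keeping all frequency components may give a poor representation, because many high-frequency changes in a time series are caused by noisy inputs. Keeping only the low-frequency components (one specific subset of frequencies) may also be unsuitable for forecasting, because some trend changes in the series correspond to important events. The authors therefore propose representing the time series with a fixed number of randomly selected Fourier components, covering both high and low frequencies.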
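The selection idea above can be sketched in a few lines. This is a minimal illustration, not the repository's helper: `select_modes` and `compress` are hypothetical names, and the seeded generator is an assumption made only so the example is reproducible.

```python
import math
import torch


def select_modes(seq_len, modes, method="random", seed=0):
    """Pick `modes` frequency indices from the bins [0, seq_len // 2)."""
    modes = min(modes, seq_len // 2)
    if method == "random":
        gen = torch.Generator().manual_seed(seed)  # assumption: fixed seed for reproducibility
        index = torch.randperm(seq_len // 2, generator=gen)[:modes]
    else:
        index = torch.arange(modes)  # keep only the lowest frequencies
    return sorted(index.tolist())


def compress(x, index):
    """Keep only the chosen Fourier components of x (last dim is time)."""
    x_ft = torch.fft.rfft(x, dim=-1)
    kept = torch.zeros_like(x_ft)
    kept[..., index] = x_ft[..., index]
    return torch.fft.irfft(kept, n=x.size(-1), dim=-1)


# A series made of two sinusoids occupies exactly two Fourier bins (3 and 20).
t = torch.arange(96, dtype=torch.float32)
x = torch.sin(2 * math.pi * 3 * t / 96) + 0.5 * torch.sin(2 * math.pi * 20 * t / 96)

index = select_modes(seq_len=96, modes=16)  # 16 random bins out of 48
x_hat = compress(x, index)                  # same length as x, built from 16 bins
x_exact = compress(x, [3, 20])              # two bins are enough for this series
```

Whether `x_hat` is close to `x` depends on whether the random draw happens to include bins 3 and 20; the number of retained components stays fixed regardless of the series length.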
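Design philosophy of the paper
To better capture the global properties of a time series and to incorporate seasonal-trend decomposition, the paper proposes a frequency-enhanced Transformer architecture.
Reducing computational cost: by randomly selecting a fixed number of Fourier components, the model achieves linear computational complexity and memory cost, bringing the Transformer's cost down from quadratic to linear in the sequence length.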
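1. Network structure
FEDformer consists of the frequency enhanced block (FEB), the frequency enhanced attention (FEA) that connects the encoder and the decoder, and the mixture of experts decomposition block (MOEDecomp).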
Long-term series forecasting is a seq2seq problem. Let $I$ denote the input length, $O$ the output length, and $D$ the hidden dimension of the series (d_model). The encoder input has size $I\times D$, and the decoder input has size $(I/2+O)\times D$.
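As a quick shape check of the sizes above, the following sketch builds both inputs for one batch. The concrete values of `B`, `I`, `O`, `D` and the use of zeros as placeholders for the $O$ future steps are assumptions of this sketch; the text only fixes the shapes.

```python
import torch

B, I, O, D = 2, 96, 192, 8      # batch size, input length, output length, d_model
x_enc = torch.randn(B, I, D)    # encoder input: I x D per sample

# Decoder input: the last I/2 steps of the encoder input followed by O
# placeholder steps (zeros here, which is an assumption of this sketch).
x_dec = torch.cat([x_enc[:, -(I // 2):, :], torch.zeros(B, O, D)], dim=1)

print(x_enc.shape, x_dec.shape)
```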
Frequency enhanced block with Fourier transform (FEB-f)
Formula:
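The formula itself did not survive in this copy. Read off the code in this section (FFT, selection of modes, multiplication by a learnable complex weight, zero padding, inverse FFT), and written in the same notation as the FEA-f formula, it can be sketched as follows. The symbols $\operatorname{Select}$, $\tilde{\boldsymbol{Q}}$, $\boldsymbol{R}$ and $M$ are assumptions of this sketch: $\operatorname{Select}$ keeps $M$ chosen frequency components, and $\boldsymbol{R}$ is the learnable complex weight with one $D\times D$ matrix per retained component.

$$
\tilde{\boldsymbol{Q}} = \operatorname{Select}\left(\mathcal{F}(\boldsymbol{q})\right), \qquad
\text{FEB-f}(\boldsymbol{q}) = \mathcal{F}^{-1}\left(\operatorname{Padding}\left(\tilde{\boldsymbol{Q}} \odot \boldsymbol{R}\right)\right)
$$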
Code implementation
```python
import torch
import torch.nn as nn

# get_frequency_modes is a helper from the FEDformer code base (not shown here);
# it returns the list of frequency indices to keep.


class FourierBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_heads, seq_len, modes=0, mode_select_method='random'):
        super(FourierBlock, self).__init__()
        print('fourier enhanced block used!')
        """
        1D Fourier block. It performs representation learning on frequency domain,
        it does FFT, linear transform, and Inverse FFT.
        """
        # get modes on frequency domain
        self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
        print('modes={}, index={}'.format(modes, self.index))
        self.n_heads = n_heads
        self.scale = (1 / (in_channels * out_channels))
        # real and imaginary parts of the learnable weight, one matrix per head and per mode
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(self.n_heads, in_channels // self.n_heads, out_channels // self.n_heads,
                                    len(self.index), dtype=torch.float))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(self.n_heads, in_channels // self.n_heads, out_channels // self.n_heads,
                                    len(self.index), dtype=torch.float))

    # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i, applied through einsum
    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        # size = [B, L, H, E]; only q is used, k, v and mask are ignored
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
        # Compute Fourier coefficients
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:
                continue
            # frequency i is transformed with the wi-th weight and stored in slot wi;
            # all remaining slots of out_ft stay zero (the padding step)
            out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i],
                                                   torch.complex(self.weights1, self.weights2)[:, :, :, wi])
        # Return to time domain
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)
```
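To make the tensor shapes easier to follow, here is a compact functional sketch of the same operation. It is not the repository code: `feb_f` is a hypothetical name, the complex weight is passed in directly rather than stored as two real parameters, and the per-mode loop is folded into a single `einsum`. Like the listing above, it writes the transformed modes into the first `len(index)` slots of the output spectrum.

```python
import torch


def feb_f(x, index, weight):
    """x: [B, H, E, L] real, weight: [H, E, E_out, M] complex, index: M frequency bins."""
    B, H, E, L = x.shape
    x_ft = torch.fft.rfft(x, dim=-1)                        # [B, H, E, L // 2 + 1]
    q_sel = x_ft[..., index]                                # [B, H, E, M], selected modes
    mixed = torch.einsum("bhim,hiom->bhom", q_sel, weight)  # one E x E_out matrix per mode
    out_ft = torch.zeros(B, H, weight.size(2), L // 2 + 1,
                         dtype=mixed.dtype, device=x.device)
    out_ft[..., :len(index)] = mixed                        # remaining slots stay zero (padding)
    return torch.fft.irfft(out_ft, n=L, dim=-1)             # back to the time domain


x = torch.randn(2, 4, 8, 96)                                # [B, H, E, L]
index = [0, 3, 7, 20]                                       # illustrative choice of modes
weight = torch.randn(4, 8, 8, len(index), dtype=torch.cfloat)
y = feb_f(x, index, weight)
print(y.shape)
```

The output has the same shape as the input, and the work after the FFT depends on the number of retained modes rather than on the sequence length.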
Frequency enhanced attention (FEA)
The output is:
$$
\text{FEA-f}(\boldsymbol{q}, \boldsymbol{k}, \boldsymbol{v}) = \mathcal{F}^{-1}\left(\operatorname{Padding}\left(\sigma\left(\tilde{\boldsymbol{Q}} \cdot \tilde{\boldsymbol{K}}^{\top}\right) \cdot \tilde{\boldsymbol{V}}\right)\right)
$$
where $\sigma$ is the activation function.
Code implementation
```python
import torch
import torch.nn as nn

# get_frequency_modes is a helper from the FEDformer code base (not shown here);
# it returns the list of frequency indices to keep.


class FourierCrossAttention(nn.Module):
    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',
                 activation='tanh', policy=0, num_heads=8):
        super(FourierCrossAttention, self).__init__()
        print(' fourier enhanced cross attention used!')
        """
        1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.
        """
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # get modes for queries and keys (& values) on frequency domain
        self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)
        self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)
        print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))
        print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))
        self.scale = (1 / (in_channels * out_channels))
        # real and imaginary parts of the learnable weight, one matrix per head and per query mode
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads,
                                    len(self.index_q), dtype=torch.float))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads,
                                    len(self.index_q), dtype=torch.float))

    # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i, applied through einsum
    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))
        if x_flag or w_flag:
            return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),
                                 torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        # size = [B, L, H, E]
        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 2, 3, 1)
        xv = v.permute(0, 2, 3, 1)  # note: xv is not used below in this listing

        # Compute Fourier coefficients and keep only the selected modes
        xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            if j >= xq_ft.shape[3]:
                continue
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
        xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_kv):
            if j >= xk_ft.shape[3]:
                continue
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        # perform attention mechanism on frequency domain
        xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))
        if self.activation == 'tanh':
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == 'softmax':
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception('{} activation function is not implemented'.format(self.activation))
        # the key spectrum xk_ft_ is reused here in the role of the values
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))
        # zero-pad: write each result back to its original frequency position
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
        # Return to time domain
        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
        return (out, None)
```
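The FEA-f formula can also be followed term by term in a shorter functional sketch. Several things here are assumptions and differ from the listing above: `fea_f` is a hypothetical name, the learnable weights and the final scaling are left out, only the tanh activation is shown, and the value spectrum is used for $\tilde{\boldsymbol{V}}$ as the formula states (the listing reuses the key spectrum instead).

```python
import torch


def fea_f(q, k, v, index_q, index_kv):
    """q: [B, H, E, Lq]; k, v: [B, H, E, Lkv]; returns [B, H, E, Lq]."""
    Lq = q.size(-1)
    q_ft = torch.fft.rfft(q, dim=-1)[..., index_q]          # Q~: [B, H, E, Mq]
    k_ft = torch.fft.rfft(k, dim=-1)[..., index_kv]         # K~: [B, H, E, Mkv]
    v_ft = torch.fft.rfft(v, dim=-1)[..., index_kv]         # V~: [B, H, E, Mkv]
    scores = torch.einsum("bhex,bhey->bhxy", q_ft, k_ft)    # Q~ . K~^T: [B, H, Mq, Mkv]
    scores = torch.complex(scores.real.tanh(), scores.imag.tanh())  # sigma = tanh
    mixed = torch.einsum("bhxy,bhey->bhex", scores, v_ft)   # [B, H, E, Mq]
    out_ft = torch.zeros(*q.shape[:3], Lq // 2 + 1,
                         dtype=mixed.dtype, device=q.device)
    out_ft[..., index_q] = mixed                            # Padding: other bins stay zero
    return torch.fft.irfft(out_ft, n=Lq, dim=-1)            # inverse FFT


q = torch.randn(2, 4, 8, 144)                               # decoder side
k = torch.randn(2, 4, 8, 96)                                # encoder side
v = torch.randn(2, 4, 8, 96)
out = fea_f(q, k, v, index_q=[0, 2, 5, 9], index_kv=[1, 3, 4, 8, 10])
print(out.shape)
```

The attention matrix has size Mq x Mkv, so its cost is set by the number of retained modes; the query and key/value sequences may have different lengths.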
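Mixture of experts decomposition block (MOEDecomp)
Real data usually shows complex periodic patterns coupled with the trend component, so extracting the trend with average pooling over a single fixed window is difficult. To overcome this, the paper designs a mixture of experts decomposition block. It contains a set of average filters of different sizes, which extract several trend components from the input signal, and a set of data-dependent weights, which combine them into the final trend. Formally: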
$$
\mathbf{X}_{\text{trend}} = \operatorname{Softmax}(L(x)) * (F(x))
$$
where $F(\cdot)$ is a set of average pooling filters and $\operatorname{Softmax}(L(x))$ gives the weights for mixing the extracted trends.
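The decomposition formula can be sketched as a small PyTorch module. This is an illustration under stated assumptions, not the paper's implementation: the class name `MoEDecompSketch`, the kernel sizes, the edge-replication padding, and the choice of a `Linear(1, K)` layer for $L(x)$ are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoEDecompSketch(nn.Module):
    def __init__(self, kernel_sizes=(5, 13, 25)):
        super().__init__()
        self.kernel_sizes = kernel_sizes              # odd sizes keep the length unchanged
        self.gate = nn.Linear(1, len(kernel_sizes))   # plays the role of L(x)

    def moving_average(self, x, k):
        # x: [B, L, D]; repeat the edge values so the output keeps length L
        pad = (k - 1) // 2
        xt = F.pad(x.transpose(1, 2), (pad, pad), mode="replicate")
        return F.avg_pool1d(xt, kernel_size=k, stride=1).transpose(1, 2)

    def forward(self, x):
        # F(x): one candidate trend per filter size, stacked as [B, L, D, K]
        trends = torch.stack([self.moving_average(x, k) for k in self.kernel_sizes], dim=-1)
        # Softmax(L(x)): data-dependent mixing weights, also [B, L, D, K]
        weights = torch.softmax(self.gate(x.unsqueeze(-1)), dim=-1)
        trend = (trends * weights).sum(dim=-1)
        return x - trend, trend                       # seasonal part, trend part


decomp = MoEDecompSketch()
x = torch.randn(2, 96, 7)
seasonal, trend = decomp(x)
print(seasonal.shape, trend.shape)
```

Because the softmax weights sum to one, the trend is a convex combination of the individual moving averages, and the seasonal and trend parts add back up to the input.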
Encoder and Decoder
Encoder
Decoder
Summary
References
[ICML 2022] Time series forecasting: FEDformer (Frequency Enhanced Decomposed Transformer)