Traceback (most recent call last):
File "C:\Users\Kanyun\Desktop\FEDformer\run.py", line 153, in <module>
exp.train(setting, args.root_path, args.data_path)
File "C:\Users\Kanyun\Desktop\FEDformer\exp\exp_main.py", line 161, in train
    outputs = self.model(batch_x, dec_inp)  # changed to pass only batch_x and dec_inp
File "C:\Users\Kanyun\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Kanyun\Desktop\FEDformer\models\Informer.py", line 76, in forward
dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
File "C:\Users\Kanyun\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Kanyun\Desktop\FEDformer\layers\Transformer_EncDec.py", line 124, in forward
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
File "C:\Users\Kanyun\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Kanyun\Desktop\FEDformer\layers\Transformer_EncDec.py", line 99, in forward
attn_mask=x_mask
File "C:\Users\Kanyun\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Kanyun\Desktop\FEDformer\layers\SelfAttention_Family.py", line 177, in forward
attn_mask
File "C:\Users\Kanyun\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\Kanyun\Desktop\FEDformer\layers\SelfAttention_Family.py", line 135, in forward
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
File "C:\Users\Kanyun\Desktop\FEDformer\layers\SelfAttention_Family.py", line 76, in _prob_QK
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
IndexError: max(): Expected reduction dim 2 to have non-zero size.
Process finished with exit code 1
Based on the traceback above, the code below needs to be fixed. The IndexError means max(-1) is reducing over an empty axis: sample_k has come through as 0 (in the reference Informer implementation the caller derives U_part from factor * ceil(ln(L_K)), which evaluates to 0 when L_K == 1), and the bare .squeeze() has additionally collapsed a size-1 batch or head dimension, which is why the empty axis is reported as dim 2 rather than dim 3. The function as posted, with a sketch of a fix after it:
def _prob_QK(self, Q, K, sample_k, n_top):
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    # sample sample_k keys per query and score Q against the sampled keys
    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    index_sample = torch.randint(L_K, (L_Q, sample_k))
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

    # sparsity measurement M; the traceback fires here when the last dim is empty
    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    # compute full attention scores only for the n_top queries
    Q_reduce = Q[torch.arange(B)[:, None, None],
                 torch.arange(H)[None, :, None],
                 M_top, :]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k
    return Q_K, M_top
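
A minimal sketch of a fix, assuming torch is already imported in SelfAttention_Family.py and that the caller derives sample_k (U_part) and n_top (u) from factor * ceil(ln(L)) as in the reference Informer implementation: clamp both into a valid range, and squeeze only the query axis that was just unsqueezed, so a size-1 batch or head dimension survives.

def _prob_QK(self, Q, K, sample_k, n_top):
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    # Guard against degenerate lengths: ceil(ln(1)) == 0, so a key/query
    # length of 1 makes the caller pass sample_k == 0 (and n_top == 0),
    # which is what produces the zero-size reduction dim in the traceback.
    sample_k = max(1, min(sample_k, L_K))
    n_top = max(1, min(n_top, L_Q))

    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
    index_sample = torch.randint(L_K, (L_Q, sample_k))
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    # squeeze(-2) rather than squeeze(): a bare squeeze also drops size-1
    # batch/head dims, shifting the axis that max(-1) reduces over.
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    M_top = M.topk(n_top, sorted=False)[1]

    Q_reduce = Q[torch.arange(B)[:, None, None],
                 torch.arange(H)[None, :, None],
                 M_top, :]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k
    return Q_K, M_top

It is also worth checking why L_K reaches 1 at all: with the call changed to self.model(batch_x, dec_inp), confirm that dec_inp still carries label_len + pred_len time steps, since a decoder sequence of length 1 is what drives ceil(ln(L_K)) to 0. The same clamping could instead be applied at the call site in SelfAttention_Family.py where U_part and u are computed.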