word_embed = mx.sym.Embedding(data=seq, input_dim=vocab_size, output_dim=num_embed, name='seq_embed')
改为:
word_embed = mx.sym.Embedding(data=mx.sym.BlockGrad(seq), input_dim=vocab_size, output_dim=num_embed, name='seq_embed')
即把输入数据用mx.sym.BlockGrad
处理一下。
这相当于在反向计算梯度时,在这里不计算梯度,直接把数据传过去。详见这里