# 1. The complete code is as follows:
class TransformerModel(nn.Module):
    """Transformer encoder for single-step sequence prediction.

    Maps an input of shape (batch_size, seq_len, input_dim) to an output of
    shape (batch_size, output_dim): embed the features into the model
    dimension, add positional encodings, run a Transformer encoder stack,
    then project the representation of the final time step.

    Args:
        input_dim: number of features per time step.
        output_dim: size of the prediction per sequence.
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        hidden_dim: model (embedding) dimension; feed-forward width is 2x this.
        max_len: maximum sequence length supported by the positional encoding.
    """

    def __init__(self, input_dim, output_dim, num_heads, num_layers, hidden_dim, max_len=5000):
        super(TransformerModel, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Input embedding: project raw features into the model dimension.
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # Positional encoding (PositionEncoding is defined elsewhere in this file).
        self.position_encoding = PositionEncoding(hidden_dim, max_len)
        # Single Transformer encoder layer; note batch_first is not set, so the
        # encoder expects (seq_len, batch_size, hidden_dim) inputs.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim * 2)
        # Stack of num_layers identical encoder layers.
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # Output projection head.
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x: (batch_size, seq_len, input_dim)
        x = self.embedding(x)          # -> (batch_size, seq_len, hidden_dim)
        x = self.position_encoding(x)
        # The encoder (batch_first=False) expects (seq_len, batch_size, hidden_dim).
        x = x.permute(1, 0, 2)
        x = self.transformer_encoder(x)
        # Single-step prediction: take the LAST TIME STEP. After the permute,
        # dim 0 is time — the original code's x[:, -1, :] wrongly selected the
        # last *batch* element (shape (seq_len, hidden_dim)).
        last_output = x[-1, :, :]      # -> (batch_size, hidden_dim)
        output = self.output_layer(last_output)
        return output