具体代码如下
import torch
# 准备数据
index_chart = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]
y_data = [1, 0, 0, 3, 2]
one_hot_lookup = [[1, 0, 0, 0], # 设置一个索引表
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]
input_size = 4
batch_size = 1
inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)
labels = torch.LongTensor(y_data).view(-1, 1) # 增加维度方便计算loss
# 设计网络模型
class LSTM(torch.nn.Module):
# 进行基础设置
def __init__(self):
super(LSTM, self).__init__()
self.lineari = torch.nn.Linear(4, 4)
self.linearf = torch.nn.Linear(4, 4)
self.linearc = torch.nn.Linear(4, 4)
self.linearo = torch.nn.Linear(4, 4)
self.sigmoid = torch.nn.Sigmoid()
self.tanh = torch.nn.Tanh()
# 设置前向传播函数
def forward(sel