1. Concept
RNNs process data that carries sequence information (a sequence of length seqlen).
torch.nn.RNNCell computes a single time step: it is essentially a linear layer executed repeatedly in a loop, passing the hidden state from one step to the next.
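A minimal sketch of the RNNCell shape contract (the sizes are illustrative, chosen to match the example that follows):

import torch

cell = torch.nn.RNNCell(input_size=4, hidden_size=4)   # one step of a plain RNN
x_t = torch.randn(1, 4)       # input at a single time step: (batch_size, input_size)
h_prev = torch.zeros(1, 4)    # previous hidden state:       (batch_size, hidden_size)
h_t = cell(x_t, h_prev)       # next hidden state:           (batch_size, hidden_size)
print(h_t.shape)              # torch.Size([1, 4])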
2. Code Implementation
# author:ZhuYuYing
# date: 2021/7/20
# projectName:tor-start
import torch
batch_size = 1
input_size = 4    # one-hot vector dimension (vocabulary size)
hidden_size = 4   # hidden state dimension
# prepare dataset
idx2char = ['e', 'h', 'l', 'o']   # vocabulary: index -> character
x_data = [1, 0, 2, 2, 3]          # input sequence:  "hello"
y_data = [1, 0, 2, 2, 3]          # target sequence: "hello"
one_hot_Lookup = [[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]]
x_one_hot = [one_hot_Lookup[x] for x in x_data]                      # one-hot encode the input indices
inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)    # (seqlen, batch_size, input_size)
labels = torch.LongTensor(y_data).view(-1, 1)                        # (seqlen, 1)
# design model using class
class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=self.input_size,
                                        hidden_size=self.hidden_size)

    def forward(self, input, hidden):
        hidden = self.rnncell(input, hidden)   # one RNN step: (batch, input_size) -> (batch, hidden_size)
        return hidden

    def init_hidden(self):   # initial hidden state h0
        return torch.zeros(self.batch_size, self.hidden_size)
net = Model(input_size, hidden_size, batch_size)
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()                  # cross-entropy loss
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)   # Adam optimizer
# training cycle: forward, backward, update
for epoch in range(15):
    loss = 0
    optimizer.zero_grad()        # zero the optimizer's gradients
    hidden = net.init_hidden()   # initial hidden state h0
    print('Predicted string: ', end='')
    for input, label in zip(inputs, labels):   # inputs: (seqlen, batch_size, input_size), so input: (batch_size, input_size)
        hidden = net(input, hidden)
        loss += criterion(hidden, label)   # accumulate the loss tensor (not .item()) so the computation graph spans the whole sequence
        _, idx = hidden.max(dim=1)         # index of the largest activation = predicted character
        print(idx2char[idx.item()], end='')
    loss.backward()    # backpropagate through the whole sequence
    optimizer.step()   # update parameters
    print(', Epoch [%d/15] loss=%.4f' % (epoch + 1, loss.item()))
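As an aside, the explicit per-step loop above is what torch.nn.RNN performs internally over the whole sequence. The sketch below is an assumed equivalent, not part of the original script; it reuses inputs, y_data, criterion and the size constants defined above:

rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=1)
h0 = torch.zeros(1, batch_size, hidden_size)   # initial hidden state: (num_layers, batch_size, hidden_size)
out, hn = rnn(inputs, h0)                      # out: (seqlen, batch_size, hidden_size)
loss = criterion(out.view(-1, hidden_size),    # flatten to (seqlen * batch_size, hidden_size)
                 torch.LongTensor(y_data))     # targets: (seqlen,)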
3. Results
Predicted string: oehho, Epoch [1/15] loss=6.2174
Predicted string: oelll, Epoch [2/15] loss=5.1659
Predicted string: lllll, Epoch [3/15] loss=4.7856
Predicted string: llllo, Epoch [4/15] loss=4.3970
Predicted string: hello, Epoch [5/15] loss=3.9766
Predicted string: hello, Epoch [6/15] loss=3.6266
Predicted string: hello, Epoch [7/15] loss=3.3810
Predicted string: hello, Epoch [8/15] loss=3.2153
Predicted string: hello, Epoch [9/15] loss=3.0909
Predicted string: hello, Epoch [10/15] loss=2.9726
Predicted string: hello, Epoch [11/15] loss=2.8575
Predicted string: hello, Epoch [12/15] loss=2.7557
Predicted string: hello, Epoch [13/15] loss=2.6544
Predicted string: hello, Epoch [14/15] loss=2.5534
Predicted string: hello, Epoch [15/15] loss=2.4699
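For completeness, a minimal inference sketch with the trained model (the torch.no_grad() block is an addition, not part of the original script; it reuses net, inputs and idx2char from above and applies the same greedy decoding as the training loop):

with torch.no_grad():               # no gradients needed for inference
    hidden = net.init_hidden()
    predicted = ''
    for input in inputs:            # one time step at a time: (batch_size, input_size)
        hidden = net(input, hidden)
        _, idx = hidden.max(dim=1)  # most active unit -> character index
        predicted += idx2char[idx.item()]
print('Predicted string:', predicted)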