任务:输入序列为 hello,目标输出序列为 ohlol。
第一步:定义初始数据
import torch
# Sizes shared by the data tensors and the model below.
batch_size = 1   # second dimension of `inputs`; rows of the initial hidden state
input_size = 4   # width of each one-hot input vector
hidden_size = 4  # width of the RNN hidden state
进行数据准备
# Vocabulary: a character's position in this list is its integer id.
idx2char = ['e', 'h', 'l', 'o']

# Input sequence "hello" and target sequence "ohlol", written as vocabulary ids.
x_data = [1, 0, 2, 2, 3]
y_data = [3, 1, 2, 3, 2]

# Lookup table: row i is the one-hot vector for vocabulary id i.
one_hot_lookup = [
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
]

# One-hot encode every id of the input sequence.
x_one_hot = [one_hot_lookup[char_id] for char_id in x_data]

# `inputs` is reshaped to (seq_len, batch_size, input_size).
inputs = torch.Tensor(x_one_hot)
inputs = inputs.view(-1, batch_size, input_size)

# `labels` is reshaped to (seq_len, 1): one target id per time step.
labels = torch.LongTensor(y_data)
labels = labels.view(-1, 1)
建立模型
class Model(torch.nn.Module):
    """Single-step recurrent model built around one ``torch.nn.RNNCell``.

    Each call advances the hidden state by exactly one time step; the caller
    loops over the sequence and feeds the returned state back in.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
        batch_size: Number of rows in the initial hidden state produced by
            ``init_hidden``.
    """

    def __init__(self, input_size: int, hidden_size: int, batch_size: int):
        super().__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        # RNNCell uses tanh as its default non-linearity.
        self.rnncell = torch.nn.RNNCell(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
        )

    # NOTE: the parameter name `input` shadows the builtin, but it is kept
    # because it is part of the method's public signature.
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        """Run one RNN step and return the new hidden state.

        Args:
            input: Input for the current time step, passed straight to the
                wrapped ``RNNCell``.
            hidden: Hidden state from the previous time step.

        Returns:
            The hidden state produced by the cell for this time step.
        """
        return self.rnncell(input, hidden)

    def init_hidden(self) -> torch.Tensor:
        """Return the all-zero initial hidden state ``h_0``.

        The tensor has shape ``(batch_size, hidden_size)`` and is created
        with the same dtype and device as the cell's weights, so it stays
        usable after the module is cast or moved (``.double()``, ``.to(...)``).
        """
        weight = self.rnncell.weight_hh
        return torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=weight.dtype,
            device=weight.device,
        )
# Build the network from the sizes defined at the top of the script.
net = Model(
    input_size=input_size,
    hidden_size=hidden_size,
    batch_size=batch_size,
)
损失函数和优化器
# Cross-entropy between the per-step hidden state (used as class scores)
# and the target character id.
criterion = torch.nn.CrossEntropyLoss()
# Adam over all parameters of the network, learning rate 0.1.
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.1)
开始训练
# Train on the single sequence. Each epoch unrolls the RNN cell over every
# time step, sums the per-step losses, and then does one optimizer update.
num_epochs = 15  # single source of truth for the loop bound and the log line
for epoch in range(num_epochs):
    loss = 0
    optimizer.zero_grad()
    hidden = net.init_hidden()
    print('pre string:', end='')
    # Iterating `inputs`/`labels` walks their first (time-step) dimension.
    for step_input, label in zip(inputs, labels):
        hidden = net(step_input, hidden)
        # The hidden state itself serves as the class scores for this step.
        loss += criterion(hidden, label)
        # .max(dim=1) returns (values, indices); the index is the predicted id.
        _, idx = hidden.max(dim=1)
        print(idx2char[idx.item()], end='')
    # Back-propagate through all time steps at once, then update the weights.
    loss.backward()
    optimizer.step()
    print(',Epoch [%d/%d] loss=%.4f' % (epoch + 1, num_epochs, loss.item()))
结果:
pre string:lheee,Epoch [1/15] loss=7.3677
pre string:lheel,Epoch [2/15] loss=6.0452
pre string:ohool,Epoch [3/15] loss=5.1959
pre string:ohool,Epoch [4/15] loss=4.6100
pre string:ohool,Epoch [5/15] loss=4.1878
pre string:ohohl,Epoch [6/15] loss=3.8787
pre string:ohlhl,Epoch [7/15] loss=3.6197
pre string:ohlol,Epoch [8/15] loss=3.3840
pre string:ohlol,Epoch [9/15] loss=3.1649
pre string:ohlol,Epoch [10/15] loss=2.9615
pre string:ohlol,Epoch [11/15] loss=2.7729
pre string:ohlol,Epoch [12/15] loss=2.6003
pre string:ohlol,Epoch [13/15] loss=2.4457
pre string:ohlol,Epoch [14/15] loss=2.3062
pre string:ohlol,Epoch [15/15] loss=2.1779