rnncell简单实现

先输入hello目标输出ohlol

第一步:定义初始数据

import torch
batch_size=1

input_size=4
hidden_size=4

进行数据准备 

idx2char=['e','h','l','o']
x_data=[1,0,2,2,3]
y_data=[3,1,2,3,2]
#查询字典
one_hot_lookup=[[1,0,0,0],
                [0,1,0,0],
                [0,0,1,0],
                [0,0,0,1]]
#x_data转化x_one_hot编码
x_one_hot=[one_hot_lookup[x] for x in x_data]
#input_size维度为(seqlen,batchsize,inputsize)
inputs=torch.Tensor(x_one_hot).view(-1,batch_size,input_size)
#labels维度为(seqlen,1)
labels=torch.LongTensor(y_data).view(-1,1)

建立模型

class Model(torch.nn.Module):
    def __init__(self,input_size,hidden_size,batch_size):
        super(Model,self).__init__()
        self.batch_size=batch_size
        self.input_size=input_size
        self.hidden_size=hidden_size
        #rnncell内部激活函数为tanh
        self.rnncell=torch.nn.RNNCell(input_size=self.input_size,
                                      hidden_size=self.hidden_size)
    def forward(self,input,hidden):
        hidden=self.rnncell(input,hidden)
        return hidden
    def init_hidden(self):
        #hidden_0
        return torch.zeros(self.batch_size,self.hidden_size)
net=Model(input_size,hidden_size,batch_size)

损失函数和优化器

criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.1)

 开始训练

for epoch in range(15):
    loss=0
    optimizer.zero_grad()
    hidden=net.init_hidden()
    print('pre string:',end='')
    for input,label in zip(inputs,labels):
        hidden=net(input,hidden)
        loss+=criterion(hidden,label)
        #.max()第一个返回值是内容,第二个是位置
        _,idx=hidden.max(dim=1)
        print(idx2char[idx.item()],end='')
    loss.backward()
    optimizer.step()
    print(',Epoch [%d/15] loss=%.4f' % (epoch+1,loss.item()))

结果:

pre string:lheee,Epoch [1/15] loss=7.3677
pre string:lheel,Epoch [2/15] loss=6.0452
pre string:ohool,Epoch [3/15] loss=5.1959
pre string:ohool,Epoch [4/15] loss=4.6100
pre string:ohool,Epoch [5/15] loss=4.1878
pre string:ohohl,Epoch [6/15] loss=3.8787
pre string:ohlhl,Epoch [7/15] loss=3.6197
pre string:ohlol,Epoch [8/15] loss=3.3840
pre string:ohlol,Epoch [9/15] loss=3.1649
pre string:ohlol,Epoch [10/15] loss=2.9615
pre string:ohlol,Epoch [11/15] loss=2.7729
pre string:ohlol,Epoch [12/15] loss=2.6003
pre string:ohlol,Epoch [13/15] loss=2.4457
pre string:ohlol,Epoch [14/15] loss=2.3062
pre string:ohlol,Epoch [15/15] loss=2.1779

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值