【PyTorch深度学习】 第十一讲:RNN基础

该博客介绍了如何使用PyTorch实现一个简单的RNN模型,处理带有序列信息的数据。通过定义RNNCell并结合交叉熵损失函数进行训练,模型在20个迭代周期后逐渐学习到输入序列模式,输出预测的字符序列‘hello’,同时展示了训练过程中的损失变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.概念原理

处理带有序列信息的数据(seqlen)
RNNCell:线性层,循环执行

2.代码实现

# author:ZhuYuYing
# data:2021/7/20
# projectName:tor-start

import torch

batch_size = 1
input_size = 4
hidden_size = 4

# prepare dataset
idx2char =['e','h','l','o']
x_data = [1,0,2,2,3]
y_data = [1,0,2,2,3]

one_hot_Lookup = [[1,0,0,0],
                  [0,1,0,0],
                  [0,0,1,0],
                  [0,0,0,1]]

x_one_hot = [one_hot_Lookup[x] for x in x_data]

inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)
labels = torch.LongTensor(y_data).view(-1,1)



# design model using class

class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=self.input_size,hidden_size=self.hidden_size)

    def forward(self, input, hidden):
        hidden = self.rnncell(input, hidden)
        return hidden

    def init_hidden(self): #初始隐层
        return torch.zeros(self.batch_size, self.hidden_size)

net = Model(input_size, hidden_size, batch_size)

# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss() #交叉熵
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)#adam优化器


# training cycle forward, backward, update
for epoch in range(20):
    loss = 0
    optimizer.zero_grad()#归零优化器梯度
    hidden = net.init_hidden()  #计算h0
    print('Predicted string: ', end='')
    for input, label in zip(inputs, labels):#input = seglen x batch x input
        hidden = net(input, hidden)
        loss += criterion(hidden, label) #不用item用了+=:构建计算图
        _, idx = hidden.max(dim=1)  #
        print(idx2char[idx.item()], end='')
    loss.backward() #优化
    optimizer.step()
    print(', Epoch [%d/15] loss=%.4f' % (epoch+1, loss.item()))

3.结果展示

Predicted string: oehho, Epoch [1/15] loss=6.2174
Predicted string: oelll, Epoch [2/15] loss=5.1659
Predicted string: lllll, Epoch [3/15] loss=4.7856
Predicted string: llllo, Epoch [4/15] loss=4.3970
Predicted string: hello, Epoch [5/15] loss=3.9766
Predicted string: hello, Epoch [6/15] loss=3.6266
Predicted string: hello, Epoch [7/15] loss=3.3810
Predicted string: hello, Epoch [8/15] loss=3.2153
Predicted string: hello, Epoch [9/15] loss=3.0909
Predicted string: hello, Epoch [10/15] loss=2.9726
Predicted string: hello, Epoch [11/15] loss=2.8575
Predicted string: hello, Epoch [12/15] loss=2.7557
Predicted string: hello, Epoch [13/15] loss=2.6544
Predicted string: hello, Epoch [14/15] loss=2.5534
Predicted string: hello, Epoch [15/15] loss=2.4699



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值