目录
1.RNNCell

import torch
# Problem dimensions: 4-symbol vocabulary, 4 hidden units, one sequence per batch.
input_size, hidden_size, batch_size = 4, 4, 1

# Vocabulary: a character's position in this list is its integer index.
idx2char = ['e', 'h', 'l', 'o']

x_data = [1, 0, 2, 2, 3]  # input sequence: "hello"
y_data = [3, 1, 2, 3, 2]  # expected output sequence: "ohlol"

# One-hot table: row i holds a 1 in column i and 0 everywhere else.
one_hot_lookup = [
    [1 if col == row else 0 for col in range(input_size)]
    for row in range(input_size)
]

# Replace every input index with its one-hot row.
x_one_hot = [one_hot_lookup[char_idx] for char_idx in x_data]

# Reshaped to (seqLen, batchSize, inputSize).
inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)

# Reshaped to (seqLen, 1).
labels = torch.LongTensor(y_data).view(-1, 1)
class Model(torch.nn.Module):
def __init__(self, input_size, hidden_size, batch_size):
super(Model, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.hidden_size = hidden_size
self.rnncell =

最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



