本篇博客主要介绍如何使用 RNN（循环神经网络）完成回归任务。
示例代码：
import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
TIME_STEP: int = 10   # sequence length in time steps (presumably per training sample; confirm in training loop)
INPUT_SIZE: int = 1   # features per time step; passed to nn.RNN(input_size=...)
LR: float = 0.02      # learning rate (presumably consumed by an optimizer defined later)

# Optional preview of the regression data: sin(t) is plotted alongside the
# cos(t) target over one period [0, 2*pi]. Uncomment to display it.
# NOTE(review): x_np (sin) looks like the model input, so its label would
# probably read better as 'input (sin)' — confirm against the training code.
# steps = np.linspace(0, np.pi * 2, 100, dtype=np.float32)
# x_np = np.sin(steps)
# y_np = np.cos(steps)
# plt.plot(steps, y_np, 'r-', label='target (cos)')
# plt.plot(steps, x_np, 'b-', label='target (sin)')
# plt.show()
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=32, # hidden state的神经元数目