介绍
使用PyTorch从OpenAI Gym中的 CartPole-v0 任务上训练一个Deep Q Learning
Agent 必须在两个动作之间做出决定 - 向左或向右移动推车 - 以使连接到它的杆保持直立。
分析过程
https://pytorch123.com/SeventhSection/ReinforcementLearning/
实验结果
完整代码+详细注释
"""
1. 需要的包
"""
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import warnings
warnings.filterwarnings("ignore", category=UserWarning) # 用于过滤掉一种警告
env = gym.make('CartPole-v0').unwrapped # 定义使用gym库中的某一个环境,'CartPole-v0'可以改为其它环境 # 据说不做这个动作会有很多限制,unwrapped是打开限制的意思
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend() # get_backend() Return the name of the current backend.
if is_ipython:
from IPython import display
plt.ion() # plt.ion()这个函数,使matplotlib的显示模式转换为交互(interactive)模式。即使在脚本中遇到plt.show(),代码还是会继续执行。
# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"""
2. 复现记忆
"""
Transition = namedtuple('Transition',
('state', 'action', 'next_state', 'reward'))
# * ReplayMemory :有界大小的循环缓冲区,用于保存最近观察到的过渡。
class ReplayMemory(object):
def __init__(self, capacity):
self.capacity = capacity # 10000
self.memory = []
self.position = 0
def push(self, *args):
"""Saves a transition."""
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args) # 参数前面加上* 号 ,意味着参数的个数不止一个,另外带一个星号(*)参数的函数传入的参数存储为一个元组(tuple),带两个(*)号则是表示字典(dict)
self