强化学习DQN实践——CartPole-v0完整代码分析+详细注释

介绍

使用PyTorch从OpenAI Gym中的 CartPole-v0 任务上训练一个Deep Q Learning

Agent 必须在两个动作之间做出决定 - 向左或向右移动推车 - 以使连接到它的杆保持直立。

分析过程

https://pytorch123.com/SeventhSection/ReinforcementLearning/

实验结果

完整代码+详细注释

"""
1. 需要的包
"""

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

import warnings
warnings.filterwarnings("ignore", category=UserWarning) # 用于过滤掉一种警告

env = gym.make('CartPole-v0').unwrapped # 定义使用gym库中的某一个环境,'CartPole-v0'可以改为其它环境 # 据说不做这个动作会有很多限制,unwrapped是打开限制的意思
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()   # get_backend() Return the name of the current backend.
if is_ipython:
    from IPython import display

plt.ion()   # plt.ion()这个函数,使matplotlib的显示模式转换为交互(interactive)模式。即使在脚本中遇到plt.show(),代码还是会继续执行。
# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""
2. 复现记忆
"""

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


# * ReplayMemory :有界大小的循环缓冲区,用于保存最近观察到的过渡。
class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity    # 10000
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)  # 参数前面加上* 号 ,意味着参数的个数不止一个,另外带一个星号(*)参数的函数传入的参数存储为一个元组(tuple),带两个(*)号则是表示字典(dict)
        self
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值