# Straight to the code:
# -*- coding: utf-8 -*-
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import gym
# --- Hyperparameters ---
Batch_size = 32
Lr = 0.01
Epsilon = 0.9               # greedy policy
Gamma = 0.9                 # reward discount
Target_replace_iter = 100   # target update frequency
Memory_capacity = 2000

# --- Environment ---
# Create the CartPole environment and keep its unwrapped form.
env = gym.make('CartPole-v0').unwrapped
N_actions = env.action_space.n
N_states = env.observation_space.shape[0]

# 0 when a sampled action is a plain Python int, otherwise the shape of a
# (second) sampled action.
if isinstance(env.action_space.sample(), int):
    ENV_A_SHAPE = 0
else:
    ENV_A_SHAPE = env.action_space.sample().shape
class Net(nn.Module):
    """Small fully connected network: N_states -> 50 -> N_actions.

    Both linear layers have their weights re-initialised in place from a
    normal distribution with mean 0 and standard deviation 0.1.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Hidden layer: N_states input features -> 50 hidden units.
        self.fc1 = nn.Linear(N_states, 50)
        self.fc1.weight.data.normal_(0, 0.1)
        # Output layer: 50 hidden units -> one value per action.
        self.out = nn.Linear(50, N_actions)
        self.out.weight.data.normal_(0, 0.1)

    def forward(self, x):
        """Return the output-layer values for input ``x``.

        Args:
            x: Input tensor fed to ``fc1``; its last dimension must match
                the ``in_features`` of ``fc1`` (``N_states``).

        Returns:
            The tensor produced by ``fc1`` -> ReLU -> ``out``.
        """
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        # The original ended in a bare `return`, discarding the computed
        # values and handing None back to every caller.
        return actions_value