import mindspore as ms
import mindspore.nn as nn
import gymnasium as gym
import numpy as np
import collections
import matplotlib.pyplot as plt
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)
        self.count = 0

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
        self.count += 1

    def sample(self, batch_size):
        samples = np.random.randint(0, len(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[i] for i in samples])
        transition_dict = {
            "states": states,
            "actions": actions,
            "rewards": rewards,
            "next_states": next_states,
            "dones": dones
        }
        return transition_dict
class DQNNet(nn.Cell):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(DQNNet, self).__init__()
        self.fc1 = nn.Dense(state_dim, hidden_dim)
        self.fc2 = nn.Dense(hidden_dim, action_dim)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
class DQNAgent:
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, discount_factor,
                 buffer_capacity, epsilon_min, epsilon_decay):
        self.q_net = DQNNet(state_dim, hidden_dim, action_dim)
        self.target_net = DQNNet(state_dim, hidden_dim, action_dim)
        # Sync the target network with the online network
        # (MindSpore's documented API for this is ms.load_param_into_net)
        ms.load_param_into_net(self.target_net, self.q_net.parameters_dict())
        self.loss_fn = nn.MSELoss()
        self.optimizer = nn.Adam(self.q_net.trainable_params(), learning_rate)
        self.grad_fn = ms.value_and_grad(
            self.forward_fn,
            None, self.optimizer.parameters
        )
        self.action_dim = action_dim
        self.discount_factor = discount_factor
        self.epsilon_max = 1 - epsilon_min
        self.epsilon_min = epsilon_min
        self.epsilon = 1
        self.epsilon_decay = epsilon_decay
        self.count = 0
        self.update_count = 0
        self.buffer = ReplayBuffer(buffer_capacity)
    def forward_fn(self, states, actions, td_targets):
        # Build row indices with mindspore.numpy.arange
        indices = ms.numpy.arange(actions.shape[0])  # [0, 1, ..., batch_size-1]
        actions_flat = actions.squeeze(-1)           # [B]
        q_all = self.q_net(states)                   # [B, A]
        q_values = q_all[indices, actions_flat]      # [B], Q-value of the chosen action for each sample
        td_targets_flat = td_targets.squeeze(-1)     # [B]
        loss = self.loss_fn(q_values, td_targets_flat)
        return loss
    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(0, self.action_dim)
        else:
            state = ms.tensor(state, dtype=ms.float32)
            action = self.q_net(state)
            action = action.argmax().item()
        self.count += 1
        self.epsilon = self.epsilon_min + self.epsilon_max * np.exp(-self.count / self.epsilon_decay)
        return action

    def predict_action(self, state):
        state = ms.tensor([state], dtype=ms.float32)
        action = self.q_net(state)
        action = action.argmax().item()
        return action
    def update(self, transition_dict):
        states = ms.tensor(transition_dict["states"], dtype=ms.float32)
        rewards = ms.tensor(transition_dict["rewards"], dtype=ms.float32).view(-1, 1)
        actions = ms.tensor(transition_dict["actions"], dtype=ms.int32).view(-1, 1)
        next_states = ms.tensor(transition_dict["next_states"], dtype=ms.float32)
        dones = ms.tensor(transition_dict["dones"], dtype=ms.float32).view(-1, 1)
        # MindSpore's Tensor.max(axis=1) returns the max values directly (no (values, indices)
        # tuple as in PyTorch), so indexing the result with [0] would keep only one sample's value.
        max_next_q = self.target_net(next_states).max(axis=1).view(-1, 1)
        td_targets = rewards + self.discount_factor * max_next_q * (1 - dones)
        loss, grads = self.grad_fn(states, actions, td_targets)
        self.optimizer(grads)
        self.update_count += 1
        if self.update_count % 30 == 0:
            # Periodically sync the target network with the online network
            ms.load_param_into_net(self.target_net, self.q_net.parameters_dict())
def train(env, agent, batch_size, num_episodes):
    return_list = []
    return_mean_list = []
    for episode in range(num_episodes):
        state, _ = env.reset()
        terminated = truncated = False
        episode_reward = 0
        while not (terminated or truncated):
            action = agent.take_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            agent.buffer.push(state, action, reward, next_state, terminated or truncated)
            episode_reward += reward
            if agent.buffer.count > batch_size and agent.buffer.count % 3 == 0:
                transition_dict = agent.buffer.sample(batch_size)
                agent.update(transition_dict)
            # Advance to the next state (the original loop never updated state)
            state = next_state
        return_list.append(episode_reward)
        if (episode + 1) % 100 == 0:
            return_mean_list.append(np.mean(return_list[-100:]))
            print(f"Episode: {episode + 1}/{num_episodes}, Mean reward (last 100): {return_mean_list[-1]:.1f}")
    env.close()
    return return_mean_list
def display(env_name, agent):
    env = gym.make(env_name, render_mode="human")
    state, _ = env.reset()
    terminated = truncated = False
    total_reward = 0
    total_steps = 0
    try:
        while not (terminated or truncated):
            action = agent.predict_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            state = next_state
            total_reward += reward
            total_steps += 1
        print(f"Total steps: {total_steps}, Total reward: {total_reward}")
    finally:
        env.close()
env_name = "CartPole-v1"
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = int(env.action_space.n)
hidden_dim = 128
learning_rate = 1e-3
discount_factor = 0.999
buffer_capacity = 10000
epsilon_min = 0.01
epsilon_decay = 3000
batch_size = 64
num_episodes = 10000
agent = DQNAgent(state_dim, hidden_dim, action_dim, learning_rate, discount_factor,
                 buffer_capacity, epsilon_min, epsilon_decay)
return_mean_list = train(env, agent, batch_size, num_episodes)
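The matplotlib import at the top is otherwise unused; a minimal sketch that plots the 100-episode mean returns produced by train (axis labels are illustrative choices, not from the original script):

# Sketch: plot the 100-episode mean returns; train appends one mean every 100 episodes
episodes = np.arange(1, len(return_mean_list) + 1) * 100
plt.plot(episodes, return_mean_list)
plt.xlabel("Episode")
plt.ylabel("Mean return over last 100 episodes")
plt.show()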
Why does the reward in my code keep decreasing as training goes on?
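One quick thing to check when returns keep falling is how fast the epsilon schedule in take_action actually decays. Below is a standalone sketch of that curve, using the same formula and hyperparameters as the script above (illustrative only, not a confirmed cause):

# Sketch: epsilon schedule from take_action with epsilon_min=0.01, epsilon_decay=3000
steps = np.arange(0, 30000)
eps_curve = 0.01 + (1 - 0.01) * np.exp(-steps / 3000)
plt.plot(steps, eps_curve)
plt.xlabel("Action-selection steps")
plt.ylabel("epsilon")
plt.show()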