Playing FlappyBird with DQN in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# import gym

from ple.games.flappybird import FlappyBird
from ple import PLE
from pygame.constants import K_w
import time
import random
import collections
import os

# Hyper Parameters
BATCH_SIZE = 32
LR = 0.0001                 # learning rate
EPSILON = 0.9               # probability of acting greedily (exploit); otherwise a random action is taken
GAMMA = 0.999                 # reward discount
TARGET_REPLACE_ITER = 100   # target update frequency
MEMORY_CAPACITY = 20000

game = FlappyBird()
env = PLE(game, fps=30, display_screen=True)

N_ACTIONS = 2   # FlappyBird: flap (K_w) or do nothing
N_STATES = 8    # PLE's getGameState() returns an 8-dimensional state


class Net(nn.Module):
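    # Q-network: a small MLP that maps the 8-dimensional game state to one
    # Q-value per action (flap / do nothing).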
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 128)
        self.fc1.weight.data.normal_(0, 0.1)   # initialization

        self.fc2 = nn.Linear(128, 128)
        self.fc2.weight.data.normal_(0, 0.1)   # initialization

        self.out = nn.Linear(128, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)   # initialization

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        actions_value = self.out(x)
        return actions_value


class DQN(object):
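    # DQN agent: eval_net is trained every step; target_net is a copy that is
    # synced every TARGET_REPLACE_ITER steps and used to compute the TD targets.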
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()

        self.learn_step_counter = 0                                     # for target updating
        self.memory_counter = 0                                         # for indexing the replay memory
        self.f1 = "/home/zhangym/spinningup/rl1/ple/data/dqn_path_evala.pkl"    # eval net checkpoint path
        self.f2 = "/home/zhangym/spinningup/rl1/ple/data/dqn_path_targeta.pkl"  # target net checkpoint path
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memory
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
        self.load_model()

    def __del__(self):
        pass
        # self.save_model()

    def save_model(self):
        torch.save(self.eval_net.state_dict(), self.f1)
        torch.save(self.target_net.state_dict(), self.f2)
        print('save model')

    def load_model(self):
        if os.path.exists(self.f1):
            self.eval_net.load_state_dict(torch.load(self.f1))
            self.target_net.load_state_dict(torch.load(self.f2))
            print('load model')

    def choose_action(self, x, e=EPSILON):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)
        # input only one sample
        if np.random.uniform() < e:   # exploit: pick the action with the highest Q-value
            actions_value = self.eval_net(x)
            action = torch.max(actions_value, 1)[1].item()
        else:                         # explore: pick a random action
            action = np.random.randint(0, N_ACTIONS)

        return action

    def store_transition(self, s, a, r, s_, isdone=False):
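        # Each transition is stored as a flat row [s (8 values), a, r, s_ (8 values)]
        # in a ring buffer of MEMORY_CAPACITY rows. Note that `isdone` is accepted
        # but not stored, so terminal transitions get no special treatment in learn().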
        # print('h',s, a, r, s_)
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
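        # One learning step: sync the target network every TARGET_REPLACE_ITER steps,
        # sample a random mini-batch from the replay memory (with replacement), and
        # take one gradient step on the TD error.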
        # target parameter update
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # sample batch transitions
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]

        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
        # print('b_a',b_a)
        q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagate
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
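        # note: terminal states are not masked here, so the bootstrap term
        # GAMMA * max Q_target(s') is also applied to game-over transitions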
        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

dqn = DQN()
env.init()
reward = env.act(None)
env.reset_game()

print('\nCollecting experience...')

t = 0
ep_r = 0
s = None
s_ = list(env.getGameState().values())
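# Training loop: collect transitions with an epsilon-greedy policy; dqn.learn() is
# only called once the replay memory has been filled (memory_counter > MEMORY_CAPACITY).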
for i_episode in range(30):
    # env.render()
    ep_r = 0
    
    while True:
        s = s_
        a = dqn.choose_action(s, 0.95)   # 0.95 here overrides EPSILON: act greedily 95% of the time
        t = t + 1
        ac = K_w if a else None          # action 1 -> flap (K_w), action 0 -> do nothing

        r = env.act(ac)
        s_ = list(env.getGameState().values())
        done = env.game_over()

        time.sleep(0.01)

        dqn.store_transition(s, a, r, s_, done)
        ep_r += r
        # print(dqn.memory_counter)
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
        if done:
            env.reset_game()
            print('t =', t, '| Ep:', i_episode, '| Ep_r:', round(ep_r, 2))
            break

dqn.save_model()
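
After training, the policy can be run greedily (no exploration) to watch the bird play. The snippet below is a minimal evaluation sketch that reuses the env and agent defined above; the episode count and the sleep interval are arbitrary illustrative choices, and a freshly constructed DQN() would reload the saved checkpoints via load_model().

# Evaluation sketch: always act greedily by passing e=1.0 to choose_action().
for ep in range(5):                              # number of evaluation episodes (arbitrary)
    env.reset_game()
    s = list(env.getGameState().values())
    total_r = 0
    while not env.game_over():
        a = dqn.choose_action(s, e=1.0)          # always pick the argmax-Q action
        total_r += env.act(K_w if a else None)   # 1 -> flap, 0 -> do nothing
        s = list(env.getGameState().values())
        time.sleep(0.01)
    print('eval episode', ep, '| return:', round(total_r, 2))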

 
