神经网络构建Google Chrome Dinosaur游戏AI
大家是否知道谷歌浏览器在断网状态下,点击空格即可开始游玩的一款小游戏,就叫它谷歌小恐龙(Google Chrome Dinosaur)吧。该款游戏很简单,玩家可以通过点击‘Space’键跳跃,点击‘下’箭头趴下,来躲避障碍物,生存的时间越长,得分相应的越高。前一段时间由于要做一个实训项目,因此学习了神经网络等知识。突然想到,能否利用神经网络来实现该款游戏的AI呢?
正好我们当时使用某一卷积神经网络训练另一款游戏,该算法是在AlexNet基础上进行改进的,以游戏截图和玩家的操作作为模型的输入数据集,来训练模型。于是,想把该算法移植到该款游戏中,最终在博主的电脑上成功运行。下图是游戏AI取得的成绩,很不错吧。
OK,言归正传,让我介绍一下基本实现步骤吧。
前言
该算法的实现需要用到多个Python模块,有些模块的安装可能会因为python版本等原因出现一些问题,请大家自行百度解决。详细python模块请见源代码。有趣的一点是,该款游戏只需要在合适的时机点击空格,即可顺利通过障碍物。为了简化训练过程,模型训练的数据集里的操作只有两种,即:点击空格或者什么也不做。
一、游戏场景实现
前面已经解释了,该算法是以截图和操作为输入的,但是谷歌小恐龙是变速且白天黑夜交替的游戏。为了简化实现,博主在网上找到了该款游戏的完美再现。原代码见:https://github.com/alphabeats7/Dinosaur.git 中的Dinosaur-map文件夹,Dinosaur-last为该AI项目的源代码(训练得到的模型由于太大没有上传)。下面的6个python文件即可实现该项目。
二、截图函数实现
截图函数的作用即是截取特定区域的屏幕,需要用到cv2、numpy、win32gui、win32ui、win32con、win32api等模块。代码如下(ord_grab.py):
import cv2
import numpy as np
import win32gui, win32ui, win32con, win32api
def grab_screen(region=None):
hwin = win32gui.GetDesktopWindow()
if region:
left, top, x2, y2 = region
width = x2 - left + 1
height = y2 - top + 1
else:
width = win32api.GetSystemMetrics(win32con.SM_CXVIRTUALSCREEN)
height = win32api.GetSystemMetrics(win32con.SM_CYVIRTUALSCREEN)
left = win32api.GetSystemMetrics(win32con.SM_XVIRTUALSCREEN)
top = win32api.GetSystemMetrics(win32con.SM_YVIRTUALSCREEN)
hwindc = win32gui.GetWindowDC(hwin)
srcdc = win32ui.CreateDCFromHandle(hwindc)
memdc = srcdc.CreateCompatibleDC()
bmp = win32ui.CreateBitmap()
bmp.CreateCompatibleBitmap(srcdc, width, height)
memdc.SelectObject(bmp)
memdc.BitBlt((0, 0), (width, height), srcdc, (left, top), win32con.SRCCOPY)
signedIntsArray = bmp.GetBitmapBits(True)
img = np.fromstring(signedIntsArray, dtype='uint8')
img.shape = (height, width, 4)
srcdc.DeleteDC()
memdc.DeleteDC()
win32gui.ReleaseDC(hwin, hwindc)
win32gui.DeleteObject(bmp.GetHandle())
return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
三、键盘操作
该Python文件作用是获取键盘的输入。代码如下(get_key.py):
import win32api as wapi
keyList = ["\b"]
for char in "ABCDEFGHIJKLMNOPQRSTUVWXYZ 123456789,.'£$/\\":
keyList.append(char)
def key_check():
keys = []
for key in keyList:
# 转换为对应的ASCII码
if wapi.GetAsyncKeyState(ord(key)):
keys.append(key)
# 0x20为space按键的16进制表示
el