A Reinforcement Learning Toolbox for Training Simple Mini-Games

Source code:
http://www.demodashi.com/demo/14072.html


Screenshots first:

  • Splash screen

  • Main window

  • Settings window

  • Server page (a Highcharts template plots the score of every episode)

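The two bundled mini-games and their training results:

  • Greedy Snake

  • A modified version of “是男人就下一百层” ("Go Down 100 Floors"), called Jump Man in this project

*The original images were too large and had to be resized.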

Usage:

【Settings window】

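→Click the inverted-triangle button on the main window above, and a black settings window pops up. In it you can set the model parameters in two ways: drag a slider, or type an exact value into the box next to it. Each slider and its edit box are kept in sync.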
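Each slider holds an integer while its edit box shows the real value. In setting.py the Gamma slider runs from 0 to 100 and the value is divided by 100; the Initial and Final sliders run from 0 to 1000 and are divided by 1000. The sketch below shows that two-way mapping as plain functions. The names `slider_to_text` and `text_to_slider` are illustrative and do not come from the project. Going from text to slider it uses `round()` rather than `int()`, because in binary floating point `0.29 * 100` evaluates to 28.999999999999996, which `int()` would truncate to 28.

```python
def slider_to_text(value: int, scale: int) -> str:
    """Integer slider position -> text for the edit box (e.g. 99 with scale 100 -> '0.99')."""
    return str(value / scale)


def text_to_slider(text: str, scale: int, minimum: int, maximum: int) -> int:
    """Edit-box text -> integer slider position, clamped the way QSlider.setValue clamps.

    Raises ValueError for text that is not a number (e.g. while the user is still typing).
    """
    # round() instead of int(): 0.29 * 100 == 28.999999999999996 in floating point
    position = round(float(text) * scale)
    return max(minimum, min(maximum, position))


if __name__ == "__main__":
    print(text_to_slider("0.29", 100, 0, 100))   # 29
    print(int(float("0.29") * 100))              # 28, the truncation that round() avoids
    print(slider_to_text(99, 100))               # 0.99
```

In the GUI, each `valueChanged` slot would call `slider_to_text` and each `textChanged` slot would call `text_to_slider`, catching `ValueError` for partial input.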

【Viewing training results on the server】

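→Clicking the minimize button copies the server address to the clipboard. Paste it into a browser to monitor training in real time. Every five seconds, the line chart on that page fetches new data from the temp.db database and appends it, so the visualization stays live.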
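The chart only needs the rows it has not drawn yet. Below is a minimal sketch of such a query against the `scores (time integer, score integer)` table that DQL.py creates. The function name `fetch_new_scores`, the idea of passing the timestamp of the last drawn point, and the assumption that `time` holds Unix seconds are all illustrative; this is not the project's flask_tk.py. Times are multiplied by 1000 because Highcharts datetime axes expect JavaScript millisecond timestamps, and the main window applies the same factor when it saves results.

```python
import sqlite3


def fetch_new_scores(conn: sqlite3.Connection, after: int):
    """Return [[time_ms, score], ...] for rows with time > after (time assumed in seconds)."""
    rows = conn.execute(
        "SELECT time, score FROM scores WHERE time > ? ORDER BY time", (after,)
    ).fetchall()
    # Highcharts datetime axes expect JavaScript timestamps in milliseconds
    return [[t * 1000, s] for t, s in rows]


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE scores (time integer, score integer)")
    conn.executemany("INSERT INTO scores VALUES (?, ?)", [(100, 3), (105, 7), (110, 2)])
    print(fetch_new_scores(conn, 0))    # [[100000, 3], [105000, 7], [110000, 2]]
    print(fetch_new_scores(conn, 105))  # [[110000, 2]]
```

A Flask route could return this list with `jsonify`, with the page sending the time of the last point it drew as `after`.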

【Close button】


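→When you click the close button and training has run for more than 1000 timesteps (frames), a dialog asks whether to save the project. Otherwise the program exits at once without saving, because so little training is not worth keeping.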

→Click Confirm

→Saved successfully

【Training in new mode】

→Choose the game to train on

→Start training (click the play button)

→Hover the mouse over the progress bar to see the exact values
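The progress bar's value is the current timestep as a percentage of Explore, capped at 100 (see `updateState` in mainwindow.py), and the tooltip also shows epsilon. DQLBrain.py is not shown in this article, so the epsilon part below is an assumption. It uses simple linear annealing from Initial to Final over Explore steps, a common DQN choice. `progress_percent` mirrors the main window's formula; `epsilon_at` is a hypothetical helper that illustrates the assumed schedule.

```python
def progress_percent(timestep: int, explore: int) -> float:
    """Progress-bar value: timestep as a percentage of Explore, capped at 100."""
    return min(float(timestep) / float(explore) * 100, 100)


def epsilon_at(timestep: int, initial: float, final: float, explore: int) -> float:
    """Hypothetical linear schedule: epsilon moves from initial to final over explore steps."""
    if timestep >= explore:
        return final
    return initial + (final - initial) * timestep / explore


if __name__ == "__main__":
    print(progress_percent(50000, 200000))   # 25.0
    print(progress_percent(300000, 200000))  # 100
    print(epsilon_at(100, 1.0, 0.0, 200))    # 0.5
```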

【Training in load mode】

→Click the mode toggle button



→Now click the play button; a dialog pops up for choosing the model to load

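→Click the start button to begin training. While training runs, the settings button and the mode toggle button are disabled so that nothing interferes with training.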
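Loading works because a project is a shelve database holding the keys setting, game, epsilon, result, replay and timestep (see `saveProgram` in mainwindow.py). The bundled greedy_snake.db.bak/.dat/.dir files show that shelve used the `dbm.dumb` backend here. Python falls back to that backend on Windows when no other dbm module is available; Python 3.13 and later may pick `dbm.sqlite3` and write a single file instead. This is why the load dialog filters on `*.dat` and the code strips the 7-character `.db.dat` suffix. The sketch below forces `dbm.dumb` so the file names are the same on every platform. `save_project`, `load_project` and the sample values are illustrative and not part of the project.

```python
import dbm.dumb
import os
import shelve
import tempfile


def save_project(db_path, **fields):
    """Write project fields to a shelve explicitly backed by dbm.dumb (.dat/.dir files)."""
    with shelve.Shelf(dbm.dumb.open(db_path, "c")) as f:
        for key, value in fields.items():
            f[key] = value


def load_project(dat_path):
    """Open a project from the .db.dat file the user picked, the way loadGame does."""
    program_name = dat_path[:-len(".db.dat")]  # strip the 7-character suffix
    with shelve.Shelf(dbm.dumb.open(program_name + ".db", "r")) as f:
        return program_name, dict(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        save_project(os.path.join(d, "greedy_snake.db"), game="Greedy Snake", timestep=1500)
        name, data = load_project(os.path.join(d, "greedy_snake.db.dat"))
        print(os.path.basename(name), data["timestep"])  # greedy_snake 1500
```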

1. Requirements

  • Python 3
  • TensorFlow-gpu
  • pygame
  • OpenCV-Python
  • PyQt5
  • pyperclip
  • flask
  • numpy
  • pandas
  • Standard library (no installation needed): sys, threading, multiprocessing, shelve, os, sqlite3, socket, glob, shutil, time, importlib

2. Directory layout

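|————MyLibrary.py classes for the game characters and other game objects
|————run_window.py entry point; shows the splash screen and starts the main program
|————mainwindow.py main window
|————setting.py parameter settings window
|————message_box.py message box window
|————DQL.py AI driver: selects and starts the game and starts the deep reinforcement learning core
|————DQLBrain.py deep reinforcement learning core
|————game_setting.py registry of the available games (number of actions, module name, etc.); every new game must be registered here
|————flask_tk.py web server
|————jumpMan.py Jump Man game
|————greedySnake.py Greedy Snake game
|————resource folder of window images
|————save_networks model checkpoints produced by training
|————templates
   |————index.html front-end page template
|————static
   |————exporting.js
   |————highcharts-zh_CN.js
   |————highstock.js
   |————jquery.js
|————temp.db temporary database through which the server and the AI exchange data
|————greedy_snake.data-00000-of-00001
|————greedy_snake.index
|————greedy_snake.meta these three files form one trained model
|————greedy_snake.db.bak
|————greedy_snake.db.dat
|————greedy_snake.db.dir these three files form one saved project file
|————setting_resource.py resources for the settings window
|————resource_message_box.py resources for the message box window
|————resource.py resources for the main window
|————document.py generates a report automatically from a database file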
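game_setting.py is the one place where a new game has to be registered. Judging from how DQL.py uses it, each entry maps a game's title to its number of actions (`"action"`) and the module that implements it (`"class"`). That module must define a `Game` class constructed from a pygame screen, whose `frameStep(action)` returns `(observation, reward, terminal, score)`. The sketch below mirrors that lookup with `importlib`. The registry contents and the in-memory `demo_game` module are hypothetical stand-ins, not the project's real entries.

```python
import importlib
import sys
import types

# Hypothetical registry in the shape DQL.py reads from gameSetting.getSetting():
# "action" is the number of discrete actions, "class" is the module to import.
GAMES = {
    "Demo": {"action": 3, "class": "demo_game"},
}


def load_game(title, screen, games=GAMES):
    """Import the registered module by name and build its Game, as DQL.py does."""
    info = games[title]
    module = importlib.import_module(info["class"])
    return module.Game(screen), info["action"]


class DemoGame:
    """Toy stand-in for the Game class that a real game module defines."""

    def __init__(self, screen):
        self.screen = screen
        self.score = 0

    def frameStep(self, action):
        # action is a one-hot list; this toy game rewards choosing the first action
        reward = 1 if action[0] == 1 else 0
        self.score += reward
        observation = [[0] * 4 for _ in range(4)]  # placeholder for a screen image
        terminal = False
        return observation, reward, terminal, self.score


# Register the stand-in as an in-memory module so import_module finds it without a file
demo_module = types.ModuleType("demo_game")
demo_module.Game = DemoGame
sys.modules["demo_game"] = demo_module

if __name__ == "__main__":
    game, n_actions = load_game("Demo", screen=None)
    print(n_actions)                      # 3
    print(game.frameStep([1, 0, 0])[1:])  # (1, False, 1)
```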

3. Implementation

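The demo consists of four main parts: the main window, the algorithm and game core, the server, and the part that manages versioned database (project) files.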

[Figure: relationships between the modules]
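The AI thread and the web server never talk to each other directly; they share temp.db. DQL.py opens one long-lived connection (`self.data_base`), and the main window's `updateState` calls `commit()` on it only once every five seconds. Until that commit happens, other connections, such as the server's, cannot see the new scores. The sketch below demonstrates this with two connections to a temporary file. The helper `count_scores` and the file location are illustrative.

```python
import os
import sqlite3
import tempfile
from contextlib import closing


def count_scores(path):
    """Count the rows a fresh connection (like the web server's) can see."""
    with closing(sqlite3.connect(path)) as reader:
        return reader.execute("SELECT COUNT(*) FROM scores").fetchall()[0][0]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "temp.db")

        # The AI side: one long-lived connection, as in DQL.py
        writer = sqlite3.connect(path, check_same_thread=False)
        writer.execute("CREATE TABLE IF NOT EXISTS scores (time integer, score integer)")
        writer.commit()

        # The insert implicitly opens a transaction that other connections cannot see yet
        writer.execute("INSERT INTO scores VALUES (?, ?)", (1, 5))
        print(count_scores(path))  # 0

        # What updateState does every five seconds
        writer.commit()
        print(count_scores(path))  # 1
        writer.close()
```

One consequence of this design is that the chart can lag behind training by several seconds.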

  • Splash screen

    import sys
    from mainWindow import MAINWINDOW
    from PyQt5.QtWidgets import QApplication, QSplashScreen
    from PyQt5 import QtCore, QtGui

    if __name__ == '__main__':
        app = QApplication(sys.argv)

        # Create the splash screen
        splash = QSplashScreen(QtGui.QPixmap("启动界面.png"))

        # Show the splash screen
        splash.show()

        # Timer that measures how long the splash screen has been shown
        timer = QtCore.QElapsedTimer()
        timer.start()

        # Keep the splash screen up for 3 s while still processing events
        while timer.elapsed() < 3000:
            app.processEvents()

        # Create and show the main window
        MainWindow = MAINWINDOW()
        MainWindow.show()

        # Close the splash screen once the main window is up
        splash.finish(MainWindow)

        sys.exit(app.exec_())
  • Main window (the layout was built entirely with Qt Designer)

    import gameSetting
    import resource
    from PyQt5 import QtWidgets,QtCore,QtGui
    from collections import deque
    from threading import Thread
    from multiprocessing import Process
    import shelve
    import sqlite3
    import socket
    import pyperclip
    from DQL import AI
    import setting
    import messageBox
    import webServers
    import glob
    import shutil
    
    game_start=False
    
    class myThread(Thread):
        def __init__(self,game,model,replay_memory,timestep,setting):
            Thread.__init__(self)
            self.game=game
            self.model=model
            self.setting=setting
            self.replay_memory=replay_memory
            self.timestep=timestep
    
        def run(self):
            self.AI = AI(self.game,self.model,self.replay_memory,self.timestep,int(self.setting["Explore"]),float(self.setting["Initial"]),float(self.setting["Final"]),float(self.setting["Gamma"]),int(self.setting["Replay"]),int(self.setting["Batch"]),)
            self.AI.playGame()
    
        def stop(self):
            self.AI.closeGame()
    
    class MAINWINDOW(QtWidgets.QWidget):
        def __init__(self, parent=None):
    
            # Initialize the parent class, forwarding the optional parent widget
            super().__init__(parent)
    
            # Configure the frameless main window
            self.setObjectName("Form")
            self.setEnabled(True)
            self.resize(681, 397)
            self.setStyleSheet("background-color: rgb(255, 255, 255);")
            self.setWindowFlags(QtCore.Qt.FramelessWindowHint)
    
            # Progress bar
            self.progressBar = QtWidgets.QProgressBar(self)
            self.progressBar.setEnabled(True)
            self.progressBar.setGeometry(QtCore.QRect(140, 348, 291, 23))
            self.progressBar.setProperty("value", 0)
            self.progressBar.setTextVisible(False)
            self.progressBar.setObjectName("progressBar")
    
            # Start/stop button
            self.control = QtWidgets.QPushButton(self)
            self.control.setGeometry(QtCore.QRect(10, 325, 71, 71))
            self.control.setStyleSheet("border-image: url(:/bottom/resource/开始按钮.png);")
            self.control.setText("")
            self.control.setObjectName("control")
            self.control_state=False
    
            # Game selection drop-down
            self.game_selection = QtWidgets.QComboBox(self)
            self.game_selection.setEnabled(True)
            self.game_selection.setGeometry(QtCore.QRect(530, 343, 141, 31))
            self.game_selection.setAutoFillBackground(False)
            self.game_selection.setStyleSheet("QComboBox{border-image: url(:/list/resource/下拉框.png)} \n""QComboBox::drop-down {image: url(:/bottom/resource/下拉框按钮.png)  }")
            self.game_selection.setEditable(False)
            self.game_selection.setInsertPolicy(QtWidgets.QComboBox.NoInsert)
            self.game_selection.setIconSize(QtCore.QSize(0, 0))
            self.game_selection.setFrame(False)
            self.game_selection.setObjectName("game_selection")
    
            # Mode toggle button (new mode / load mode)
            self.mode = QtWidgets.QPushButton(self)
            self.mode.setGeometry(QtCore.QRect(440, 340, 71, 41))
            self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n""")
            self.mode.setText("")
            self.mode.setObjectName("mode")
            self.mode_state = False
    
            # Background image
            self.label = QtWidgets.QLabel(self)
            self.label.setGeometry(QtCore.QRect(0, 0, 681, 331))
            self.label.setStyleSheet("border-image: url(:/image/resource/Background.png);")
            self.label.setText("")
            self.label.setObjectName("label")
    
            # Settings button
            self.setting = QtWidgets.QPushButton(self)
            self.setting.setGeometry(QtCore.QRect(570, 10, 31, 21))
            self.setting.setStyleSheet("border-image: url(:/bottom/resource/菜单.png);")
            self.setting.setText("")
            self.setting.setObjectName("setting")
    
            # "Copy server address" button (drawn with the minimize icon)
            self.pushButton_3 = QtWidgets.QPushButton(self)
            self.pushButton_3.setGeometry(QtCore.QRect(610, 10, 31, 23))
            self.pushButton_3.setStyleSheet("border-image: url(:/bottom/resource/最小化.png);")
            self.pushButton_3.setText("")
            self.pushButton_3.setObjectName("pushButton_3")
    
            # Close button
            self.bottom_close = QtWidgets.QPushButton(self)
            self.bottom_close.setGeometry(QtCore.QRect(650, 10, 21, 23))
            self.bottom_close.setStyleSheet("border-image: url(:/bottom/resource/关闭.png);")
            self.bottom_close.setText("")
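            self.bottom_close.setObjectName("bottom_close")

            # Set the title, create the child windows, load the game list and start the server
            self.init_window(self)

            # Connect the buttons to their handlers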
            self.connectBottom()
            QtCore.QMetaObject.connectSlotsByName(self)
    
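        # Initialize the window
        def init_window(self, Form):
            _translate = QtCore.QCoreApplication.translate
            Form.setWindowTitle(_translate("Form", "Deep Reinforcement Learning Toolbox"))

            # Create the child windows
            self.setting_form = setting.SETTING()
            self.message_box = messageBox.MESSAGE_BOX()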
            self.message_box=messageBox.MESSAGE_BOX()
    
            # Fill the drop-down with the registered games
            game_setting_dict = gameSetting.getSetting()
            for i,game in enumerate(game_setting_dict.keys()):
                self.game_selection.addItem("")
                self.game_selection.setItemText(i, _translate("Form", game))
            self.game_selection.setCurrentText(_translate("Form", list(game_setting_dict.keys())[0]))
            self.game_selection.setCurrentIndex(0)
    
            # Start the web server in a daemon process
            flask_process = Process(target=webServers.start)
            flask_process.daemon = True
            flask_process.start()
    
        # Connect every button to its handler in one place
        def connectBottom(self):
            self.control.clicked.connect(self.loadGame)
            self.bottom_close.clicked.connect(self.closeWindow)
            self.mode.clicked.connect(self.setMode)
            self.setting.clicked.connect(self.openSetting)
            self.pushButton_3.clicked.connect(self.getIp)
    
        # Let the user drag the frameless window
        def mousePressEvent(self, event):
            if event.button() == QtCore.Qt.LeftButton:
                self.m_drag = True
                self.m_DragPosition = event.globalPos() - self.pos()
                event.accept()
                self.setCursor(QtGui.QCursor(QtCore.Qt.OpenHandCursor))
    
        def mouseMoveEvent(self, QMouseEvent):
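            # Move only while the left button is held; m_drag does not exist until the first press
            if (QMouseEvent.buttons() & QtCore.Qt.LeftButton) and getattr(self, "m_drag", False):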
                self.move(QMouseEvent.globalPos() - self.m_DragPosition)
                QMouseEvent.accept()
    
        def mouseReleaseEvent(self, QMouseEvent):
            self.m_drag = False
            self.setCursor(QtGui.QCursor(QtCore.Qt.ArrowCursor))
    
        # Start/stop button handler
        def loadGame(self):
            self.mode.setEnabled(False)
            self.setting.setEnabled(False)
    
            # Flag that a game has been started
            global game_start
            game_start=True
    
            # control_state: False = no game running yet, True = game running; the button image follows it
            if self.control_state:
                self.closeWindow()
            else:
                # Switch the button to its stop state
                self.control.setStyleSheet("border-image: url(:/bottom/resource/终止按钮.png);")
                self.control_state =True
    
                # Variables the AI needs (new-mode defaults)
                self.program_name = ""
                game=self.game_selection.currentText()
                model = ""
                replay_memory = deque()
                self.actual_timestep=0
                setting=self.setting_form.getSetting()
    
                # Load mode: replace the variables above with those stored in an existing project
                if self.mode_state:
                    program_path = QtWidgets.QFileDialog.getOpenFileName(self, "Choose the project to load",
                                                                   "../",
                                                                   "Model File (*.dat)")
                    try:
                        # Project name: full path without the 7-character ".db.dat" suffix
                        self.program_name=program_path[0][:-7]

                        # Open the project file
                        with shelve.open(self.program_name+'.db') as f:
                            # Load the project information
                            game=f["game"]
                            model = self.program_name
                            replay_memory = f["replay"]
                            setting=f["setting"]
                            self.actual_timestep = int(f["timestep"])
                            self.setting_form.updateSetting(setting)
                            self.updateDataset(f["result"])
                    except Exception:
                        # Cancelled dialog or unreadable project: keep whatever was loaded so far
                        pass
    
                # Start the game thread
                self.game_thread = myThread(game,model,replay_memory,self.actual_timestep,setting)
                self.game_thread.start()
    
                # Refresh the window state every 5 seconds
                self.state_Timer = QtCore.QTimer()
                self.state_Timer.timeout.connect(self.updateState)
                self.state_Timer.start(5000)
    
        # Close the window
        def closeWindow(self):
            timestep=0
    
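            # If the game never started or ran only briefly, just exit.
            # self.state may not exist yet, because the game sometimes takes more than five seconds to start.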
            try:
                timestep=self.state["TIMESTEP"]
            except:
                pass
    
            if timestep>1000:
                # Ask whether to save
                reply = self.message_box.exec_()
                if reply:
                    # Close the game window
                    try:
                        self.game_thread.AI.closeGame()
                    except:
                        pass
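                    # New mode: ask where to save the project
                    if not self.program_name:
                        save_program_path = QtWidgets.QFileDialog.getSaveFileName(self, "Choose where to save the project",
                                                                             "../",
                                                                             "Program File(*.db)")

                        # getSaveFileName returns a (file name, filter) tuple; the name is empty if the user cancelled
                        if save_program_path[0]:

                            # Project path without the ".db" suffix (splitting on "." would break on dotted folder names)
                            program_name = save_program_path[0]
                            if program_name.endswith(".db"):
                                program_name = program_name[:-3]

                            # Write the project file
                            self.saveProgram(save_program_path, 0)

                            # Save the model
                            self.saveModel(program_name)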
    
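                    # Load mode: overwrite the project that was loaded
                    else:
                        program_name = self.program_name
                        try:
                            # saveProgram reads the path from element [0], like the tuple getSaveFileName returns
                            self.saveProgram((program_name + '.db',), 1)
                        except Exception:
                            pass

                        # Save the model
                        self.saveModel(program_name)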
    
            # Empty the temporary database
            with sqlite3.connect('temp.db', check_same_thread=False) as f:
                c = f.cursor()
                c.execute('delete from scores')
                f.commit()
    
            # Close the main window and stop the timer and the server
            self.close()
    
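        # Write the project file (shared by new mode and load mode).
        # save_program_path is a file-dialog tuple whose element [0] is the path, or a plain path;
        # is_loaded is 1 for a project loaded from disk, 0 for a new one.
        def saveProgram(self, save_program_path, is_loaded):
            path = save_program_path if isinstance(save_program_path, str) else save_program_path[0]
            with shelve.open(path) as f:
                # Settings the AI ran with
                f["setting"] = self.setting_form.getSetting()

                # Current state of the AI
                state = self.game_thread.AI.getState()

                f["game"] = self.game_selection.currentText()
                f["epsilon"] = state["EPSILON"]
                # Score history from the temporary database, with time multiplied by 1000
                f["result"] = [[i[0] * 1000, i[1]] for i in
                               sqlite3.connect('temp.db', check_same_thread=False).cursor().execute(
                                   'select * from scores').fetchall()]
                f["replay"] = self.game_thread.AI.getReplay()

                # A loaded project adds this run's timesteps to those already stored;
                # a new project has no "timestep" key yet
                if is_loaded and "timestep" in f:
                    f["timestep"] = int(state["TIMESTEP"]) + int(f["timestep"])
                else:
                    f["timestep"] = int(state["TIMESTEP"])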
    
        # Refresh the main window state (called by a 5-second timer)
        def updateState(self):
            # Try to read the game state; skip this round if the game has not started yet
            try:
                self.state = self.game_thread.AI.getState()
            except:
                pass
            else:
                actual_timestep=self.state["TIMESTEP"]
                self.progressBar.setToolTip("Timestep:"+str(actual_timestep)+"    STATE:"+self.state["STATE"]+"     EPSILON:"+str(self.state["EPSILON"]))
                self.progressBar.setProperty("value",min(float(actual_timestep)/float(self.setting_form.getSetting()["Explore"])*100,100))
    
            # Commit the AI's pending score inserts to temp.db only once every 5 seconds, to keep things fast
            try:
                self.game_thread.AI.data_base.commit()
            except:
                pass
    
    
        # Toggle between new mode and load mode
        def setMode(self):
            if not self.mode_state:
                self.mode_state = True
                self.mode.setStyleSheet("border-image: url(:/bottom/resource/加载模式.png);\n""")
            else:
                self.mode_state = False
                self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n""")
    
    
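        # Copy this machine's LAN address plus the server port to the clipboard
        def getIp(self):
            # Connecting a UDP socket sends no packets; it only makes the OS choose the outgoing interface
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                sock.connect(('8.8.8.8', 80))
                ip = sock.getsockname()[0]
            except OSError:
                # No network route: fall back to the loopback address
                ip = '127.0.0.1'
            finally:
                sock.close()
            pyperclip.copy(ip + ':9090')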
    
    
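        # Copy the score history of a loaded project back into the temporary database
        def updateDataset(self, results):
            # temp.db is an SQLite database, not a shelve file
            with sqlite3.connect('temp.db', check_same_thread=False) as f:
                c = f.cursor()
                c.execute('create table if not exists scores (time integer, score integer)')
                for result in results:
                    # saveProgram stored time * 1000, so divide that factor back out
                    c.execute("insert into scores values (?, ?)", (result[0] // 1000, result[1]))
                f.commit()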
    
    
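        # Save the model: copy the checkpoint files next to the project file
        def saveModel(self, program_name):
            # Checkpoints are assumed to be named network-dqn-<step>.<suffix>; copy them in step order
            # so that the files of the newest checkpoint are copied last and overwrite older ones
            files = glob.glob("./saved_networks/network-dqn-*")
            files.sort(key=lambda name: int(name.split("network-dqn-")[-1].split(".")[0]))
            for file in files:
                postfix = file.split('.')[-1]
                try:
                    shutil.copy(file, program_name + '.' + postfix)
                except OSError:
                    pass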
    
    
        # Settings button: show the settings window
        def openSetting(self):
            self.setting_form.show()
  • Settings window

    from PyQt5 import QtCore, QtGui, QtWidgets
    import setting_resource
    
    class SETTING(QtWidgets.QWidget):
        def __init__(self):
    
            # Initialize the parent class
            super().__init__()
    
            # Initialize the window
            self.setObjectName("Dialog")
            self.resize(547, 402)
            self.setStyleSheet("")
    
            # OK button
            self.pushButton = QtWidgets.QPushButton(self)
            self.pushButton.setGeometry(QtCore.QRect(160, 320, 75, 23))
            self.pushButton.setStyleSheet("color: rgb(255, 255, 255);\n""border-image: url(:/image/resource/设定确定按钮.png);")
            self.pushButton.setText("")
            self.pushButton.setObjectName("pushButton")
    
            # Cancel button
            self.pushButton_2 = QtWidgets.QPushButton(self)
            self.pushButton_2.setGeometry(QtCore.QRect(320, 320, 75, 23))
            self.pushButton_2.setStyleSheet("color: rgb(255, 255, 255);\n""border-image: url(:/image/resource/设定取消按钮.png);")
            self.pushButton_2.setText("")
            self.pushButton_2.setObjectName("pushButton_2")
    
            # One edit box, slider and label per parameter
            self.line_explore = QtWidgets.QLineEdit(self)
            self.line_explore.setGeometry(QtCore.QRect(450, 60, 61, 20))
            self.line_explore.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_explore.setObjectName("line_explore")
            self.line_initial = QtWidgets.QLineEdit(self)
            self.line_initial.setGeometry(QtCore.QRect(450, 100, 61, 20))
            self.line_initial.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_initial.setObjectName("line_Initial")
            self.line_final = QtWidgets.QLineEdit(self)
            self.line_final.setGeometry(QtCore.QRect(450, 140, 61, 20))
            self.line_final.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_final.setObjectName("line_final")
            self.line_gamma = QtWidgets.QLineEdit(self)
            self.line_gamma.setGeometry(QtCore.QRect(450, 180, 61, 20))
            self.line_gamma.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_gamma.setObjectName("line_gamma")
            self.line_replay = QtWidgets.QLineEdit(self)
            self.line_replay.setGeometry(QtCore.QRect(450, 220, 61, 20))
            self.line_replay.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_replay.setObjectName("line_replay")
            self.line_batch = QtWidgets.QLineEdit(self)
            self.line_batch.setGeometry(QtCore.QRect(450, 260, 61, 20))
            self.line_batch.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_batch.setObjectName("line_batch")
            self.exploreSlider = QtWidgets.QSlider(self)
            self.exploreSlider.setGeometry(QtCore.QRect(120, 60, 300, 19))
            self.exploreSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.exploreSlider.setMinimum(200000)
            self.exploreSlider.setMaximum(10000000)
            self.exploreSlider.setProperty("value", 200000)
            self.exploreSlider.setOrientation(QtCore.Qt.Horizontal)
            self.exploreSlider.setObjectName("exploreSlider")
            self.label = QtWidgets.QLabel(self)
            self.label.setGeometry(QtCore.QRect(50, 60, 48, 19))
            self.label.setStyleSheet("color: rgb(255, 255, 255);")
            self.label.setObjectName("label")
            self.label_2 = QtWidgets.QLabel(self)
            self.label_2.setGeometry(QtCore.QRect(50, 100, 48, 19))
            self.label_2.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_2.setObjectName("label_2")
            self.initialSlider = QtWidgets.QSlider(self)
            self.initialSlider.setGeometry(QtCore.QRect(120, 100, 300, 19))
            self.initialSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.initialSlider.setMaximum(1000)
            self.initialSlider.setProperty("value", 0)
            self.initialSlider.setOrientation(QtCore.Qt.Horizontal)
            self.initialSlider.setObjectName("initialSlider")
            self.label_3 = QtWidgets.QLabel(self)
            self.label_3.setGeometry(QtCore.QRect(50, 140, 42, 19))
            self.label_3.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_3.setObjectName("label_3")
            self.finalSlider = QtWidgets.QSlider(self)
            self.finalSlider.setGeometry(QtCore.QRect(120, 140, 300, 19))
            self.finalSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.finalSlider.setMaximum(1000)
            self.finalSlider.setProperty("value", 0)
            self.finalSlider.setOrientation(QtCore.Qt.Horizontal)
            self.finalSlider.setObjectName("finalSlider")
            self.label_4 = QtWidgets.QLabel(self)
            self.label_4.setGeometry(QtCore.QRect(50, 180, 42, 19))
            self.label_4.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_4.setObjectName("label_4")
            self.gammaSlider = QtWidgets.QSlider(self)
            self.gammaSlider.setGeometry(QtCore.QRect(120, 180, 300, 19))
            self.gammaSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.gammaSlider.setMaximum(100)
            self.gammaSlider.setProperty("value", 99)
            self.gammaSlider.setOrientation(QtCore.Qt.Horizontal)
            self.gammaSlider.setObjectName("gammaSlider")
            self.label_6 = QtWidgets.QLabel(self)
            self.label_6.setGeometry(QtCore.QRect(50, 220, 42, 19))
            self.label_6.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_6.setObjectName("label_6")
            self.replaySlider = QtWidgets.QSlider(self)
            self.replaySlider.setGeometry(QtCore.QRect(120, 220, 300, 19))
            self.replaySlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.replaySlider.setMaximum(100000)
            self.replaySlider.setProperty("value", 50000)
            self.replaySlider.setOrientation(QtCore.Qt.Horizontal)
            self.replaySlider.setObjectName("replaySlider")
            self.label_7 = QtWidgets.QLabel(self)
            self.label_7.setGeometry(QtCore.QRect(50, 260, 36, 19))
            self.label_7.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_7.setObjectName("label_7")
            self.batchSlider = QtWidgets.QSlider(self)
            self.batchSlider.setGeometry(QtCore.QRect(120, 260, 300, 19))
            self.batchSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.batchSlider.setMaximum(100)
            self.batchSlider.setProperty("value", 32)
            self.batchSlider.setOrientation(QtCore.Qt.Horizontal)
            self.batchSlider.setObjectName("batchSlider")
            self.label_5 = QtWidgets.QLabel(self)
            self.label_5.setGeometry(QtCore.QRect(0, 0, 551, 411))
            self.label_5.setStyleSheet("background-image: url(:/background/resource/设定背景.png);")
            self.label_5.setText("")
            self.label_5.setObjectName("label_5")
    
            # Stacking order: raise the background first, then every control above it
            self.label_5.raise_()
            self.pushButton.raise_()
            self.pushButton_2.raise_()
            self.line_explore.raise_()
            self.line_initial.raise_()
            self.line_final.raise_()
            self.line_gamma.raise_()
            self.line_replay.raise_()
            self.line_batch.raise_()
            self.exploreSlider.raise_()
            self.label.raise_()
            self.label_2.raise_()
            self.initialSlider.raise_()
            self.label_3.raise_()
            self.finalSlider.raise_()
            self.label_4.raise_()
            self.gammaSlider.raise_()
            self.label_6.raise_()
            self.replaySlider.raise_()
            self.label_7.raise_()
            self.batchSlider.raise_()
    
            # Set the title, label texts and default values
            self.retranslateUi(self)
    
            # Keep each edit box and its slider in sync
            self.connect()
    
            # Connect the OK and Cancel buttons
            self.pushButton.clicked.connect(self.saveSetting)
            self.pushButton_2.clicked.connect(self.cancel)
            QtCore.QMetaObject.connectSlotsByName(self)
    
        def retranslateUi(self, Dialog):
            _translate = QtCore.QCoreApplication.translate
            Dialog.setWindowTitle(_translate("Dialog", "Settings"))

            # Default texts for the edit boxes and labels
            self.line_explore.setText(_translate("Dialog", "200000"))
            self.line_initial.setText(_translate("Dialog", "0"))
            self.line_final.setText(_translate("Dialog", "0"))
            self.line_gamma.setText(_translate("Dialog", "0.99"))
            self.line_replay.setText(_translate("Dialog", "50000"))
            self.line_batch.setText(_translate("Dialog", "32"))
            self.label.setText(_translate("Dialog", "Explore:"))
            self.label_2.setText(_translate("Dialog", "Initial:"))
            self.label_3.setText(_translate("Dialog", "Final:"))
            self.label_4.setText(_translate("Dialog", "Gamma:"))
            self.label_6.setText(_translate("Dialog", "Replay:"))
            self.label_7.setText(_translate("Dialog", "Batch:"))
    
            # Default settings
            self.setting={"Explore":200000,"Initial":0,"Final":0,"Gamma":0.99,"Replay":50000,"Batch":32}
    
        # Connect each slider and edit box in both directions
        def connect(self):
    
            self.exploreSlider.valueChanged.connect(self.changeLineExplore)
            self.line_explore.textChanged.connect(self.changeSliderExplore)
    
            self.initialSlider.valueChanged.connect(self.changeLineInitial)
            self.line_initial.textChanged.connect(self.changeSliderInitial)
    
            self.finalSlider.valueChanged.connect(self.changeLineFinal)
            self.line_final.textChanged.connect(self.changeSliderFinal)
    
            self.gammaSlider.valueChanged.connect(self.changeLineGamma)
            self.line_gamma.textChanged.connect(self.changeSliderGamma)
    
            self.replaySlider.valueChanged.connect(self.changeLineReplay)
            self.line_replay.textChanged.connect(self.changeSliderReplay)
    
            self.batchSlider.valueChanged.connect(self.changeLineBatch)
            self.line_batch.textChanged.connect(self.changeSliderBatch)
    
        def changeLineExplore(self):
            try:
                self.line_explore.setText(str(self.exploreSlider.value()))
            except:
                pass
    
        def changeSliderExplore(self):
            try:
                self.exploreSlider.setValue(int(self.line_explore.text()))
            except:
                pass
    
        def changeLineInitial(self):
            try:
                self.line_initial.setText(str(self.initialSlider.value()/1000))
            except:
                pass
    
        def changeSliderInitial(self):
            try:
                # round(), not int(): products such as 0.29 * 100 land just below the integer
                self.initialSlider.setValue(round(float(self.line_initial.text()) * 1000))
            except ValueError:
                pass
    
        def changeLineFinal(self):
            try:
                self.line_final.setText(str(self.finalSlider.value()/1000))
            except:
                pass
    
        def changeSliderFinal(self):
            try:
                # Multiply the parsed number, not the text (str * 1000 repeats the string)
                self.finalSlider.setValue(round(float(self.line_final.text()) * 1000))
            except ValueError:
                pass
    
        def changeLineGamma(self):
            try:
                self.line_gamma.setText(str(self.gammaSlider.value()/100))
            except:
                pass
    
        def changeSliderGamma(self):
            try:
                # round(), not int(): 0.29 * 100 == 28.999999999999996, which int() turns into 28
                self.gammaSlider.setValue(round(100 * float(self.line_gamma.text())))
            except ValueError:
                pass
    
        def changeLineReplay(self):
            try:
                self.line_replay.setText(str(self.replaySlider.value()))
            except:
                pass
    
        def changeSliderReplay(self):
            try:
                self.replaySlider.setValue(int(self.line_replay.text()))
            except:
                pass
    
        def changeLineBatch(self):
            try:
                self.line_batch.setText(str(self.batchSlider.value()))
            except:
                pass
    
        def changeSliderBatch(self):
            try:
                self.batchSlider.setValue(int(self.line_batch.text()))
            except:
                pass
    
        # Let other windows read the AI settings
        def getSetting(self):
            return self.setting
    
        # Save the settings
        def saveSetting(self):
            self.setting={"Explore":self.line_explore.text(),"Initial":self.line_initial.text(),"Final":self.line_final.text(),"Gamma":self.line_gamma.text(),"Replay":self.line_replay.text(),"Batch":self.line_batch.text()}  # TODO: check that every value is a number
            self.hide()
    
        # Discard the changes and hide the window
        def cancel(self):
            self.hide()
            return 0
    
        # Update the settings from a loaded project file
        def updateSetting(self,setting):
            self.setting={"Explore":setting["Explore"],"Initial":setting["Initial"],"Final":setting["Final"],"Gamma":setting["Gamma"],"Replay":setting["Replay"],"Batch":setting["Batch"]}  # TODO: check that every value is a number
            self.line_explore.setText(str(setting["Explore"]))
            self.line_final.setText(str(setting["Final"]))
            self.line_initial.setText(str(setting["Initial"]))
            self.line_gamma.setText(str(setting["Gamma"]))
            self.line_replay.setText(str(setting["Replay"]))
            self.line_batch.setText(str(setting["Batch"]))
  • Deep reinforcement learning
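    This part is based on the code at https://blog.csdn.net/songrotek/article/details/50951537. I won't repeat the theory of deep reinforcement learning here; that blog explains it in detail.
    It has two parts: DQL.py manages the games and the algorithm, and DQLBrain.py is the core deep reinforcement learning algorithm. Each is shown below: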

    • DQL.py
      import cv2
      from DQLBrain import Brain
      import numpy as np
      from collections import deque
      import sqlite3
      import pygame
      import time
      import gameSetting
      import importlib

      # Settings shared by all games
      SCREEN_X = 288
      SCREEN_Y = 512
      FPS = 60
      
      class AI:
          def __init__(self, title,model_path,replay_memory,current_timestep,explore,initial_epsilon,final_epsilon,gamma,replay_size,batch_size):
              # Score buffer and the registry of available games
              self.scores = deque()
              self.games_info = gameSetting.getSetting()
      
              # Connect to the temporary database and make sure the scores table exists
              self.data_base = sqlite3.connect('temp.db', check_same_thread=False)
              self.c = self.data_base.cursor()
              self.c.execute('create table if not exists scores (time integer, score integer)')
      
              # Create the deep reinforcement learning agent
              self.brain = Brain(self.games_info[title]["action"],model_path,replay_memory,current_timestep,explore,initial_epsilon,final_epsilon,gamma,replay_size,batch_size)
      
              # Create the game window
              self.startGame(title,SCREEN_X,SCREEN_Y)
      
              # Import the module registered for this game title and create the game
              game=importlib.import_module(self.games_info[title]['class'])
              self.game=game.Game(self.screen)
      
          def startGame(self,title,SCREEN_X, SCREEN_Y):
              # Initialize pygame and the window
              pygame.init()
              screen_size = (SCREEN_X, SCREEN_Y)
              pygame.display.set_caption(title)
      
              # Create the screen surface
              self.screen = pygame.display.set_mode(screen_size)
      
              # Create the game clock
              self.clock = pygame.time.Clock()
      
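          # Preprocess each frame to reduce its complexity
          def preProcess(self, observation):
      
              # Resize the 288x512 frame to 80x80 and convert it from RGB (3 channels) to grayscale (1 channel)
              observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
      
              # Binarize: every pixel brighter than 1 becomes white (255), everything else black (0)
              _, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
      
              # Return shape (80, 80, 1); the trailing channel axis lets frames be stacked into the network input tensor
              return np.reshape(observation, (80, 80, 1))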
      
          # Run the game
          def playGame(self):
      
              # Start the game with an arbitrary first action (this one-hot vector assumes three actions)
              observation0, reward0, terminal, score = self.game.frameStep(np.array([1, 0, 0]))
              observation0 = self.preProcess(observation0)
              self.brain.setInitState(observation0[:,:,0])
      
              # Main training loop
              while True:
                  action = self.brain.getAction()
                  next_observation, reward, terminal, score = self.game.frameStep(action)
      
                  # The game window was closed
                  if terminal == -1:
                      self.closeGame()
                      return
      
                  # Otherwise keep playing: store the transition and (after observing) train
                  next_observation = self.preProcess(next_observation)
                  self.brain.setPerception(next_observation, action, reward, terminal)
      
                  # Record the score of every finished episode
                  if terminal:
                      t = int(time.time())
                      self.c.execute("insert into scores values (?, ?)", (t, score))
                      # Commit so the Flask server, which uses its own connection, can see the new row
                      self.data_base.commit()
      
          # Close the game
          def closeGame(self):
              pygame.quit()
              self.brain.close()
              time.sleep(0.5)  # make sure pending writes from the brain have finished
              self.data_base.commit()
              self.data_base.close()
      
          # Get the current training state (shown in the main window)
          def getState(self):
              return self.brain.getState()
      
          # Get the current replay memory so it can be saved into the project file
          def getReplay(self):
              return self.brain.replay_memory
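The frame pipeline is easy to check in isolation. The following numpy-only sketch mirrors `preProcess` together with the four-frame stacking done by `Brain.setInitState` and `Brain.setPerception` in DQLBrain.py. It is an illustration with stated substitutions, not the project's code: nearest-neighbour sampling stands in for `cv2.resize`, and standard luma weights stand in for `cv2.cvtColor`. The function names are hypothetical.

```python
import numpy as np

FRAME_SIZE = 80  # same 80x80 resolution as preProcess

def preprocess(frame, size=FRAME_SIZE):
    """Numpy-only stand-in for AI.preProcess: (H, W, 3) uint8 -> (size, size, 1) uint8."""
    h, w = frame.shape[:2]
    # Nearest-neighbour resize (the real code uses cv2.resize, which interpolates)
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    small = frame[rows][:, cols].astype(np.float32)
    # Grayscale with standard luma weights (the real code uses cv2.cvtColor)
    gray = small @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
    # Same rule as cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY): > 1 -> 255, else 0
    binary = np.where(gray > 1, 255, 0).astype(np.uint8)
    return binary.reshape(size, size, 1)

def init_state(obs):
    """Like Brain.setInitState: repeat the first frame four times -> (80, 80, 4)."""
    return np.stack([obs[:, :, 0]] * 4, axis=2)

def next_state(state, obs):
    """Like Brain.setPerception: drop the oldest frame and append the newest one."""
    return np.append(state[:, :, 1:], obs, axis=2)
```

Stacking four consecutive binary frames gives the network a sense of motion, since a single still frame cannot show which way the snake or the jumping man is moving.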
      
    • DQLBrain.py
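      import random
      from collections import deque
      
      import numpy as np
      import tensorflow as tf  # this code uses the TensorFlow 1.x graph API (placeholders, sessions)
      
      # Number of steps to only observe (fill the replay buffer) before training starts
      observe=100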

      class Brain:
          def __init__(self, actions,model_path,replay_memory=None,current_timestep=0,explore=200000.,initial_epsilon=0.0,final_epsilon=0.0,gamma=0.99,replay_size=50000,batch_size=32):
      
              # Hyperparameters:
      
              # Discount factor for future rewards
              self.gamma = gamma
      
              # Number of steps to observe before training starts
              self.observe = observe
      
              # Number of steps over which epsilon is annealed
              self.explore = explore
      
              # Initial exploration rate (epsilon)
              self.initial_epsilon = initial_epsilon
      
              # Final exploration rate
              self.final_epsilon = final_epsilon
      
              # Maximum size of the replay buffer
              self.replay_size = replay_size
      
              # Minibatch size
              self.batch_size = batch_size
      
              # Copy the online network into the target network every update_time training steps
              self.update_time = 100
      
              self.whole_state = dict()
      
              # Initialize the replay buffer: use the one loaded from a project file, otherwise a fresh deque
              # (a deque() default argument would be shared by every Brain instance)
              self.replay_memory = replay_memory if replay_memory is not None else deque()
      
              # Step counters
              self.timestep = 0
              self.initial_timestep=current_timestep
              self.accual_timestep=self.initial_timestep+self.timestep
      
              # In load mode, restore the epsilon the saved model had reached, based on its recorded timestep
              self.epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) / self.explore * self.accual_timestep
              if self.epsilon<self.final_epsilon:
                  self.epsilon=self.final_epsilon
              self.actions = actions
      
              # Online Q network (the one being trained)
              self.state_input, self.QValue, self.conv1_w, self.conv1_b, self.conv2_w, self.conv2_b, self.conv3_w, self.conv3_b, self.fc1_w, self.fc1_b, self.fc2_w, self.fc2_b = self.createQNetwork()
      
              # Target Q network, used to compute the TD targets
              self.state_inputT, self.QValueT, self.conv1_wT, self.conv1_bT, self.conv2_wT, self.conv2_bT, self.conv3_wT, self.conv3_bT, self.fc1_wT, self.fc1_bT, self.fc2_wT, self.fc2_bT = self.createQNetwork()
              # Ops that copy the online weights into the target network
              self.copyTargetQNetwork = [self.conv1_wT.assign(self.conv1_w), self.conv1_bT.assign(self.conv1_b), self.conv2_wT.assign(self.conv2_w), self.conv2_bT.assign(self.conv2_b), self.conv3_wT.assign(self.conv3_w), self.conv3_bT.assign(self.conv3_b), self.fc1_wT.assign(self.fc1_w), self.fc1_bT.assign(self.fc1_b), self.fc2_wT.assign(self.fc2_w), self.fc2_bT.assign(self.fc2_b)]
      
              # Loss: mean squared error between the TD target y and Q(s, a) of the action taken
              self.action_input = tf.placeholder("float", [None, self.actions])
              self.y_input = tf.placeholder("float", [None])
              Q_Action = tf.reduce_sum(tf.multiply(self.QValue, self.action_input), axis=1)
              self.cost = tf.reduce_mean(tf.square(self.y_input - Q_Action))
              self.optimizer = tf.train.AdamOptimizer(1e-6).minimize(self.cost)
      
              # Saving and restoring the model
              self.saver = tf.train.Saver(max_to_keep=1)
              self.session = tf.InteractiveSession()
              self.session.run(tf.global_variables_initializer())
      
          def createQNetwork(self):
      
              # Weights and biases
              # Conv layer 1: 8x8 kernels, 4 input channels (the stacked frames), 32 filters
              W_conv1 = self.weightVariable([8, 8, 4, 32])
              b_conv1 = self.biasVariable([32])
      
              # Conv layer 2: 4x4 kernels, 32 -> 64 channels
              W_conv2 = self.weightVariable([4, 4, 32, 64])
              b_conv2 = self.biasVariable([64])
      
              # Conv layer 3: 3x3 kernels, 64 -> 64 channels
              W_conv3 = self.weightVariable([3, 3, 64, 64])
              b_conv3 = self.biasVariable([64])
      
              # Fully connected layer: 1600 -> 512
              W_fc1 = self.weightVariable([1600, 512])
              b_fc1 = self.biasVariable([512])
      
              # Output layer: 512 -> one Q value per action
              W_fc2 = self.weightVariable([512, self.actions])
              b_fc2 = self.biasVariable([self.actions])
      
              # input layer
              stateInput = tf.placeholder("float", [None, 80, 80, 4])
      
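              # Build the network (SAME padding: output size = ceil(input size / stride))
              # Hidden layers
              # 80x80x4 -> 20x20x32 (stride 4)
              h_conv1 = tf.nn.relu(self.conv2d(stateInput, W_conv1, 4) + b_conv1)
      
              # 20x20x32 -> 10x10x32 (2x2 max pooling)
              h_pool1 = self.maxPool_2x2(h_conv1)
      
              # 10x10x32 -> 5x5x64 (stride 2)
              h_conv2 = tf.nn.relu(self.conv2d(h_pool1, W_conv2, 2) + b_conv2)
      
              # 5x5x64 -> 5x5x64 (stride 1)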
              h_conv3 = tf.nn.relu(self.conv2d(h_conv2, W_conv3, 1) + b_conv3)
      
              #5*5*64 to 1*1600
              h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])
              h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)
      
              # Output layer: one Q value per action (no activation)
              QValue = tf.matmul(h_fc1, W_fc2) + b_fc2
      
              return stateInput, QValue, W_conv1, b_conv1, W_conv2, b_conv2, W_conv3, b_conv3, W_fc1, b_fc1, W_fc2, b_fc2
      
          def trainQNetwork(self):
      
              # Sample a minibatch from the replay buffer
              minibatch = random.sample(self.replay_memory, self.batch_size)
              state_batch = [data[0] for data in minibatch]
              action_batch = [data[1] for data in minibatch]
              reward_batch = [data[2] for data in minibatch]
              nextState_batch = [data[3] for data in minibatch]
      
              # Compute the TD targets: y = r for terminal transitions, otherwise y = r + gamma * max_a Q_target(s', a)
              y_batch = []
              QValue_batch = self.QValueT.eval(feed_dict={self.state_inputT: nextState_batch})
              for i in range(0, self.batch_size):
                  terminal = minibatch[i][4]
                  if terminal:
                      y_batch.append(reward_batch[i])
                  else:
                      y_batch.append(reward_batch[i] + self.gamma * np.max(QValue_batch[i]))
              self.optimizer.run(feed_dict={self.y_input: y_batch, self.action_input: action_batch, self.state_input: state_batch})
      
              # Save a checkpoint every 1000 timesteps
              if self.timestep % 1000 == 0:
                  self.saver.save(self.session, './saved_networks/network' + '-dqn', global_step=self.timestep+self.initial_timestep)
      
              # Copy the online network into the target network
              if self.timestep % self.update_time == 0:
                  self.session.run(self.copyTargetQNetwork)
      
          def setPerception(self, nextObservation, action, reward, terminal):
      
              new_state = np.append(self.current_state[:, :, 1:], nextObservation, axis=2)
              self.replay_memory.append((self.current_state, action, reward, new_state, terminal))
      
              # Keep the replay buffer no larger than replay_size
              if len(self.replay_memory) > self.replay_size:
                  self.replay_memory.popleft()
              if self.timestep > self.observe:
                  self.trainQNetwork()
      
              # Report the current training phase to the main window
              if self.timestep <= self.observe:
                  state = "observe"
              elif self.timestep <= self.observe + self.explore:
                  state = "explore"
              else:
                  state = "train"
      
              self.whole_state={"TIMESTEP":self.timestep +self.initial_timestep,"STATE":state, "EPSILON":self.epsilon,"ACTUAL":int(self.timestep+self.initial_timestep)}
      
              self.current_state = new_state
              self.timestep += 1
      
          def getAction(self):
              QValue = self.QValue.eval(feed_dict={self.state_input: [self.current_state]})[0]
              action = np.zeros(self.actions)
      
              # epsilon-greedy action selection
              if random.random() <= self.epsilon:
                  action_index = random.randrange(self.actions)
                  action[action_index] = 1
              else:
                  action_index = np.argmax(QValue)
                  action[action_index] = 1
      
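              # Anneal epsilon linearly with the total number of steps (including those of a loaded model),
              # starting once the observation phase of this session is over
              actual_timestep = self.timestep + self.initial_timestep
              if self.epsilon > self.final_epsilon and self.timestep > self.observe:
                  self.epsilon = max(self.final_epsilon, self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) / self.explore * actual_timestep)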
      
              return action
      
          def setInitState(self, observation):
              self.current_state = np.stack((observation, observation, observation, observation), axis=2)
      
          def weightVariable(self, shape):
              initial = tf.truncated_normal(shape, stddev=0.01)
              return tf.Variable(initial)
      
          def biasVariable(self, shape):
              initial = tf.constant(0.01, shape=shape)
              return tf.Variable(initial)
      
          def conv2d(self, x, W, stride):
              return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")
      
          def maxPool_2x2(self, x):
              return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
      
          def close(self):
              self.session.close()
      
          def getState(self):
              return self.whole_state
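Two numeric rules drive `Brain`: the linear epsilon schedule and the TD target computed in `trainQNetwork`, which together with the squared-error loss define each training step. The sketch below restates them in plain numpy so they can be checked by hand. The function names are hypothetical, and the sketch illustrates the formulas rather than replacing the TensorFlow graph.

```python
import numpy as np

def epsilon_at(step, initial_epsilon, final_epsilon, explore):
    """Linear annealing as in Brain: epsilon falls from initial_epsilon to
    final_epsilon over `explore` steps and then stays at final_epsilon."""
    eps = initial_epsilon - (initial_epsilon - final_epsilon) / explore * step
    return max(eps, final_epsilon)

def td_targets(rewards, next_q, terminals, gamma):
    """Vectorised version of the loop in trainQNetwork:
    y = r if the transition ended the episode, else r + gamma * max_a Q_target(s', a)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    best_next = np.max(np.asarray(next_q, dtype=np.float64), axis=1)
    done = np.asarray(terminals, dtype=bool)
    return np.where(done, rewards, rewards + gamma * best_next)

def dqn_loss(q_values, actions_onehot, targets):
    """Mean squared error between y and Q(s, a) of the taken action (Brain's self.cost)."""
    q_taken = np.sum(np.asarray(q_values) * np.asarray(actions_onehot), axis=1)
    return float(np.mean((np.asarray(targets) - q_taken) ** 2))
```

With the default `explore = 200000`, `epsilon_at` reaches `final_epsilon` after 200000 total steps and stays there. Masking Q with the one-hot action vector, as `Brain` does, means only the Q value of the action actually taken receives a gradient.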
      
  • Server

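  The chart is drawn with the Highcharts API. After putting the four files listed above into the static folder, write the page template index.html in the templates folder (to keep it easy to follow, the page is deliberately very plain):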

	<!DOCTYPE html>
	<html>
	<head>
	<meta charset="utf-8">

	<script src='/static/jquery.js'></script>
	<script src='/static/highstock.js'></script>
	<script src='/static/exporting.js'></script>

	</head>
	<body>

		<div id="container" style="min-width:310px;height:400px"></div>

		<script>
	$(function () {
		// Use the local time zone; otherwise timestamps in UTC+8 are shown eight hours off
		Highcharts.setOptions({
			global: {
				useUTC: false
			}
		});
		$.getJSON('/data', function (data) {
			// Create the chart
			$('#container').highcharts('StockChart', {
			chart:{
			events:{
			
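				load:function(){
					var series = this.series[0];
					// Number of points already drawn from the initial /data request
					var shown = data.length;
					setInterval(function(){
						$.getJSON('/data',function(res){
							// /data returns every row each time, so add only the rows not plotted yet
							$.each(res.slice(shown),function(i,v){
								series.addPoint(v);
							});
							shown = res.length;
						});
					},3000);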
				}
			}
			},
				rangeSelector : {
					selected : 1
				},
				title : {
					text : 'Score per episode'
				},
				series : [{
					name : 'Training performance',
					data : data,
					tooltip: {
						valueDecimals: 2
					}
				}]
			});
		});
	});
	</script>
	</body>
	</html>
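The `useUTC: false` setting matters because Highcharts receives plain Unix timestamps in milliseconds and has to pick a time zone for the axis labels. As a worked example (in Python; `render_both` is a hypothetical helper), the same x value is labeled eight hours apart in UTC and in UTC+8:

```python
from datetime import datetime, timedelta, timezone

def render_both(epoch_ms):
    """Format one Highcharts x value (milliseconds since the Unix epoch) in UTC
    and in UTC+8, i.e. with and without useUTC on a machine in China."""
    utc = datetime.fromtimestamp(epoch_ms / 1000, tz=timezone.utc)
    utc_plus_8 = utc.astimezone(timezone(timedelta(hours=8)))
    return utc.strftime("%Y-%m-%d %H:%M"), utc_plus_8.strftime("%Y-%m-%d %H:%M")
```

For example, `render_both(1538017320000)` gives `("2018-09-27 03:02", "2018-09-27 11:02")`.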

  You also need a Python file, Webservice.py, that serves this template and the score data:

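    from flask import Flask, render_template
    import sqlite3
    import json
    
    app = Flask(__name__)
    
    # Serve the front-end template
    @app.route('/')
    def index():
        return render_template("index.html")
    
    
    # Data endpoint: every recorded score as [timestamp in milliseconds, score], the format Highcharts expects
    @app.route('/data')
    def data():
        # Open a short-lived connection per request; sharing one cursor between request threads is not thread-safe
        conn = sqlite3.connect('temp.db')
        try:
            rows = conn.execute('select time, score from scores order by rowid').fetchall()
        except sqlite3.OperationalError:
            # The scores table does not exist until the AI side has started
            rows = []
        finally:
            conn.close()
        return json.dumps([[t * 1000, score] for t, score in rows])
    
    # Start the server on port 9090; host 0.0.0.0 listens on every network interface of this machine
    def start():
        app.run(host='0.0.0.0', port=9090)
    
    if __name__ == '__main__':
        start()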
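Because the game loop and the Flask server use separate SQLite connections, a new score only becomes visible to `/data` after the writer commits. The sketch below uses hypothetical helper names and shows both sides, plus the seconds-to-milliseconds conversion Highcharts needs; it is meant to be run against a throwaway database file.

```python
import sqlite3
import time

def record_score(conn, score, now=None):
    """Writer side (like AI.playGame): store one finished episode and commit."""
    t = int(time.time() if now is None else now)
    conn.execute("insert into scores values (?, ?)", (t, score))
    conn.commit()  # other connections, such as the Flask server's, only see committed rows

def read_series(path):
    """Reader side (like the /data route): [milliseconds, score] pairs for Highcharts."""
    conn = sqlite3.connect(path)
    try:
        rows = conn.execute("select time, score from scores order by rowid").fetchall()
    finally:
        conn.close()
    return [[t * 1000, score] for t, score in rows]
```

Calling `read_series` on the writer's file right after `record_score` returns the new row; an insert that has not been committed yet does not show up.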

Closing remarks

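  That said, PyQt5 and TensorFlow seem to conflict, so the program occasionally crashes while running. In addition, the server cannot be reached from machines outside the local network. If you know how to fix either problem, please tell me in the comments below. Thanks! One more time: the GitHub repository is https://github.com/qq303067814/DQLearning-Toolbox. If you want to dig further into any part of this walkthrough, read the source code directly or ask in the comments.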

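Note: Copyright belongs to the author. This article was posted by demo大师 on the author's behalf; reproduction requires the author's permission.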

posted on 2018-09-27 11:02 by demo例子集