day2021_10_12
-
Yesterday I typed out a small Q-learning example. It is a maze-style task: the agent starts on the red square and has to reach the yellow circle, and toward the end of training it basically follows one fixed route.
Main reference:
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow
- This is the maze environment:
# -*- coding: utf-8 -*-
# @Time: 2021/10/11 19:39
# @Author: 闲卜
# @File: maze_env.py
# @Software: PyCharm
import numpy as np
import tkinter as tk
import time

MAZE_H = 4   # maze height in cells
MAZE_W = 4   # maze width in cells
UNIT = 40    # pixel size of one cell


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # actions: up, down, left, right
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title("maze")
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_W * UNIT))
        self._build_maze()  # build the maze

    # build the maze
    def _build_maze(self):
        self.canvas = tk.Canvas(
            self, bg="white",
            height=MAZE_H * UNIT,
            width=MAZE_W * UNIT
        )
        # draw the grid lines
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # starting point (centre of the top-left cell)
        origin = np.array([20, 20])
        # two "hell" cells (black squares) that end the episode with a penalty
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black'
        )
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black'
        )
        # the agent (red square)
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red'
        )
        # the goal (yellow circle)
        oval_center = origin + np.array([UNIT * 2, UNIT * 2])
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow'
        )
        self.canvas.pack()

    # reset the canvas to its initial state
    def reset(self):
        self.update()
        time.sleep(0.1)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red'
        )
        return self.canvas.coords(self.rect)

    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:    # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if s[1] < UNIT * (MAZE_H - 1):
                base_action[1] += UNIT
        elif action == 2:  # right
            if s[0] < UNIT * (MAZE_W - 1):
                base_action[0] += UNIT
        elif action == 3:  # left
            if s[0] > UNIT:
                base_action[0] -= UNIT
        self.canvas.move(self.rect, base_action[0], base_action[1])
        s_ = self.canvas.coords(self.rect)
        # reward function
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False
        return s_, reward, done

    def render(self):
        time.sleep(0.1)
        self.update()


# quick manual test: always move down for 100 episodes
def update():
    for t in range(100):
        s = env.reset()
        while True:
            env.render()
            a = 1
            s, r, done = env.step(a)
            if done:
                break
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    env.after(100, update)
    env.mainloop()
- This is the learning part, the Q-learning algorithm:
# -*- coding: utf-8 -*-
# @Time: 2021/10/11 21:15
# @Author: 闲卜
# @File: RL_brain.py
# @Software: PyCharm
import numpy as np
import pandas as pd


class QLearningTable:
    def __init__(self, actions, learning=0.1, reward_decay=0.9, e_greed=0.9):
        self.actions = actions        # list of available actions
        self.lr = learning            # learning rate
        self.gamma = reward_decay     # discount factor
        self.epsilon = e_greed        # epsilon-greedy threshold
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def choose_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            # exploit: pick an action with the largest Q value (ties broken randomly)
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # explore: pick a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

    def check_state_exist(self, state):
        # append a zero-initialised row for any state seen for the first time
        if state not in self.q_table.index:
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )
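One caveat about check_state_exist: it relies on DataFrame.append, which works on the pandas versions current in 2021 but was deprecated and later removed in pandas 2.0. A rough, untested sketch of the same method using pd.concat instead (only the row construction changes):
    def check_state_exist(self, state):
        # same behaviour as above, but without DataFrame.append,
        # which no longer exists in pandas >= 2.0
        if state not in self.q_table.index:
            new_row = pd.DataFrame(
                [[0.0] * len(self.actions)],
                columns=self.q_table.columns,
                index=[state],
            )
            self.q_table = pd.concat([self.q_table, new_row])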
- This is the entry script:
from maze_env import Maze
from RL_brain import QLearningTable


def update():
    for episode in range(10):
        observation = env.reset()  # reset the agent to the starting coordinates
        while True:
            env.render()  # refresh the window
            action = RL.choose_action(str(observation))
            observation_, reward, done = env.step(action)
            RL.learn(str(observation), action, reward, str(observation_))
            observation = observation_
            if done:
                break
    print(RL.q_table)


if __name__ == '__main__':
    env = Maze()
    RL = QLearningTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()
-
Tonight I implemented the Sarsa algorithm myself on top of this code. The main differences show up in the two places below.
Sarsa's learn takes one extra parameter a_, the next action to be executed: Sarsa updates the table using the Q value of the action it will actually take next (plus the reward or penalty just received),
whereas Q-learning updates the table using the maximum Q value over the actions available in the next state (the change this requires in the training loop is sketched after the two snippets below).
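Written as the standard textbook update rules (not copied from my code), the only difference is in the target term:

Q-learning: $Q(s,a) \leftarrow Q(s,a) + \alpha \,[\, r + \gamma \max_{a'} Q(s',a') - Q(s,a) \,]$

Sarsa: $Q(s,a) \leftarrow Q(s,a) + \alpha \,[\, r + \gamma \, Q(s',a') - Q(s,a) \,]$

where $a'$ in the Sarsa rule is the action that will actually be executed in $s'$.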
- Sarsa's learn:
def learn(self, s, a, r, s_, a_):
    self.if_exist_action(s_)
    q_predict = self.q_table.loc[s, a]
    if s_ != 'terminal':
        # use the Q value of the action that will actually be taken next
        q_target = r + self.gamma * self.q_table.loc[s_, a_]
    else:
        q_target = r
    self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
- Q-learning's learn:
def learn(self, s, a, r, s_):
    self.check_state_exist(s_)
    q_predict = self.q_table.loc[s, a]
    if s_ != 'terminal':
        q_target = r + self.gamma * self.q_table.loc[s_, :].max()
    else:
        q_target = r
    self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
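Because learn now needs a_, the episode loop also has to choose the next action before updating and then carry it over into the next step. A minimal sketch of that loop, assuming a SarsaTable class with the same choose_action/learn interface as QLearningTable above (the names here are illustrative, not my exact code):
def update():
    for episode in range(10):
        observation = env.reset()
        action = RL.choose_action(str(observation))            # choose the first action up front
        while True:
            env.render()
            observation_, reward, done = env.step(action)
            action_ = RL.choose_action(str(observation_))      # the action that will actually be taken next
            RL.learn(str(observation), action, reward, str(observation_), action_)
            observation, action = observation_, action_        # carry a_ over into the next step
            if done:
                break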
Problems I ran into today
- First, when passing in the state observation: it is position information stored as a list, which cannot be used as an index into the Q table, so it has to be converted to a string first.
- Second, when locating values in the Q table I kept forgetting to add .loc, which caused key-not-found errors (a short illustration of both points follows this list).
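A tiny standalone illustration of both points, using a throwaway DataFrame rather than the real maze code:
import pandas as pd

q_table = pd.DataFrame(columns=[0, 1, 2, 3], dtype=float)

observation = [5.0, 5.0, 35.0, 35.0]   # canvas.coords() returns a list of floats
state = str(observation)               # a list cannot serve as a single row label, a string can

# create the row first, then always read/write cells through .loc
q_table.loc[state] = [0.0, 0.0, 0.0, 0.0]
q_table.loc[state, 1] += 0.1           # writing q_table[state, 1] instead raises a KeyError
print(q_table)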
Plans for later
- Design the database for the engineering-training course
- Finish the algorithms homework problems and the database
- If there is time, look over the material the engineering training needs, or the TensorFlow framework