[RL學習篇] [#1]解決grid_mdp.py 不能執行的問題?
在《深入淺出強化學習：原理入門》一書中，第一個例子的 grid_mdp.py 無法執行，會出現錯誤。
於是我研究了一下，把它改成下列程式後就可以執行了，提供給有需要的人下載參考。
import logging
import numpy
import random
from gym import spaces
import gym
from gym.utils import seeding

logger = logging.getLogger(__name__)


class GridEnv(gym.Env):
    """Eight-state grid world used as the first example of the book.

    Layout of the state ids (their drawing coordinates are ``self.x`` and
    ``self.y``, indexed by ``state - 1``)::

        1  2  3  4  5
        6     7     8

    States 6, 7 and 8 are terminal.  Actions are the strings
    ``'n'``, ``'e'``, ``'s'``, ``'w'``.  Moving south from state 1 or 5 gives
    reward -1.0 (a skull), moving south from state 3 gives +1.0 (the gold);
    every other transition gives 0.0.  A (state, action) pair missing from
    the transition table leaves the agent where it is.

    Both the public ``step/reset/render/seed`` names and the older
    underscore-prefixed ``_step/_reset/_render/_seed`` names are provided,
    so the environment works whichever set the installed gym dispatches to.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = [1, 2, 3, 4, 5, 6, 7, 8]  # state space
        # Screen coordinates of each state's cell centre (index = state - 1).
        self.x = [140, 220, 300, 380, 460, 140, 300, 460]
        self.y = [250, 250, 250, 250, 250, 150, 150, 150]
        # Terminal states are stored as dict keys; the values are unused.
        self.terminate_states = {6: 1, 7: 1, 8: 1}
        self.action_s = ['n', 'e', 's', 'w']

        # Rewards keyed by "<state>_<action>"; a missing key means 0.0.
        self.rewards = {
            '1_s': -1.0,
            '3_s': 1.0,
            '5_s': -1.0,
        }
        # Transitions keyed by "<state>_<action>" -> next state; a missing
        # key means the agent stays in its current state.
        self.t = {
            '1_s': 6,
            '1_e': 2,
            '2_w': 1,
            '2_e': 3,
            '3_s': 7,
            '3_w': 2,
            '3_e': 4,
            '4_w': 3,
            '4_e': 5,
            '5_s': 8,
            '5_w': 4,
        }
        self.gamma = 0.8  # discount factor
        # NOTE: do not assign ``self.seed`` here -- an instance attribute with
        # that name would shadow the seed() method and make env.seed() fail.
        self.viewer = None
        self.state = None

    def seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed_used]``.

        NOTE(review): reset() draws from the stdlib ``random`` module, not
        from ``self.np_random``, so seeding here does not affect resets.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    _seed = seed

    def getTerminal(self):
        """Return the terminal-state dict (same object as getTerminate_states)."""
        return self.terminate_states

    def getGamma(self):
        """Return the discount factor."""
        return self.gamma

    def getStates(self):
        """Return the list of state ids."""
        return self.states

    def getAction(self):
        """Return the list of action strings."""
        return self.action_s

    def getTerminate_states(self):
        """Return the terminal-state dict (same object as getTerminal)."""
        return self.terminate_states

    def setAction(self, s):
        """Set the current state to ``s`` (despite its name, this sets the state)."""
        self.state = s

    def step(self, action):
        """Apply ``action`` from the current state.

        Returns a ``(next_state, reward, done, info)`` tuple and advances
        ``self.state`` to ``next_state``.  Stepping from a terminal state
        returns ``(state, 0, True, {})`` without moving.
        """
        state = self.state
        if state in self.terminate_states:
            return state, 0, True, {}
        key = "%d_%s" % (state, action)  # dict key combining state and action

        # Unknown (state, action) pairs leave the agent in place.
        next_state = self.t.get(key, state)
        # Advance the environment; without this, successive steps would all
        # start from the same state and render() would never move the robot.
        self.state = next_state

        is_terminal = next_state in self.terminate_states
        r = self.rewards.get(key, 0.0)
        return next_state, r, is_terminal, {}

    _step = step

    def reset(self):
        """Put the agent in a uniformly random state and return that state.

        Any of the eight states can be chosen, including terminal ones.
        """
        self.state = self.states[int(random.random() * len(self.states))]
        return self.state

    _reset = reset

    def render(self, mode='human', close=False):
        """Draw the grid, skulls, gold and robot.

        Returns the result of ``Viewer.render`` (an RGB array is requested
        when ``mode == 'rgb_array'``).  Returns None when ``close`` is true
        or when no state has been set yet.
        """
        if close:
            self.close()
            return None

        screen_width = 600
        screen_height = 400

        if self.viewer is None:
            # NOTE(review): this module ships only with older gym releases.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            self._build_scene(rendering)

        if self.state is None:
            return None
        self.robotrans.set_translation(self.x[self.state - 1],
                                       self.y[self.state - 1])
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    _render = render

    def close(self):
        """Close the render window if one is open; safe to call repeatedly."""
        # getattr: close() may run (e.g. from gym's cleanup) before __init__
        # has assigned self.viewer.
        viewer = getattr(self, 'viewer', None)
        if viewer is not None:
            viewer.close()
            self.viewer = None

    def _build_scene(self, rendering):
        """Create the grid lines, skulls, gold and robot and add them to the viewer."""
        # Endpoints of the black grid lines, in the original line1..line11 order.
        segments = [
            ((100, 300), (500, 300)),
            ((100, 200), (500, 200)),
            ((100, 300), (100, 100)),
            ((180, 300), (180, 100)),
            ((260, 300), (260, 100)),
            ((340, 300), (340, 100)),
            ((420, 300), (420, 100)),
            ((500, 300), (500, 100)),
            ((100, 100), (180, 100)),
            ((260, 100), (340, 100)),
            ((420, 100), (500, 100)),
        ]
        for index, (start, end) in enumerate(segments, 1):
            line = rendering.Line(start, end)
            line.set_color(0, 0, 0)
            # Keep the self.line1 ... self.line11 attributes for existing code.
            setattr(self, 'line%d' % index, line)
            self.viewer.add_geom(line)

        # Black skulls on states 6 and 8, yellow gold on state 7.
        self.kulo1 = self._make_disc(rendering, 40, (140, 150), (0, 0, 0))
        self.kulo2 = self._make_disc(rendering, 40, (460, 150), (0, 0, 0))
        self.gold = self._make_disc(rendering, 40, (300, 150), (1, 0.9, 0))

        # The robot's transform is moved to the current state on every render().
        self.robot = rendering.make_circle(30)
        self.robotrans = rendering.Transform()
        self.robot.add_attr(self.robotrans)
        self.robot.set_color(0.8, 0.6, 0.4)

        for geom in (self.kulo1, self.kulo2, self.gold, self.robot):
            self.viewer.add_geom(geom)

    @staticmethod
    def _make_disc(rendering, radius, center, color):
        """Return a circle of ``radius`` translated to ``center`` with ``color``."""
        disc = rendering.make_circle(radius)
        disc.add_attr(rendering.Transform(translation=center))
        disc.set_color(*color)
        return disc
#### Note : ####
2018.05.17 修正小 bug