[RL學習篇] [#1]解決grid_mdp.py 不能執行的問題?

在《深入淺出強化學習:原理入門》一書中,第一個例子的 grid_mdp.py 無法直接執行,執行時會出現錯誤。

於是我研究了一下之後,改成下列程式碼就可以執行了,給有需要的人下載回去看看。

 

  1 import logging
  2 import numpy
  3 import random
  4 from gym import spaces
  5 import gym
  6 from gym.utils import seeding
  7 logger = logging.getLogger(__name__)
  8 class GridEnv(gym.Env):
  9     metadata = {
 10         'render.modes': ['human', 'rgb_array'],
 11         'video.frames_per_second': 2
 12     }
 13 
 14     def __init__(self):
 15         self.states = [1,2,3,4,5,6,7,8] #状态空间
 16         self.x=[140,220,300,380,460,140,300,460]
 17         self.y=[250,250,250,250,250,150,150,150]
 18         self.terminate_states = dict()  #终止状态为字典格式
 19         self.terminate_states[6] = 1
 20         self.terminate_states[7] = 1
 21         self.terminate_states[8] = 1
 22         self.action_s = ['n','e','s','w']
 23 
 24 
 25         self.rewards = dict();        #回报的数据结构为字典
 26         self.rewards['1_s'] = -1.0
 27         self.rewards['3_s'] = 1.0
 28         self.rewards['5_s'] = -1.0
 29         self.t = dict();             #状态转移的数据格式为字典
 30         self.t['1_s'] = 6
 31         self.t['1_e'] = 2
 32         self.t['2_w'] = 1
 33         self.t['2_e'] = 3
 34         self.t['3_s'] = 7
 35         self.t['3_w'] = 2
 36         self.t['3_e'] = 4
 37         self.t['4_w'] = 3
 38         self.t['4_e'] = 5
 39         self.t['5_s'] = 8
 40         self.t['5_w'] = 4
 41         self.gamma = 0.8         #折扣因子
 42         self.seed = None
 43         self.viewer = None
 44         self.state = None
 45 
 46     def _seed(self, seed=None):
 47         self.np_random, seed = seeding.np_random(seed)
 48         return [seed]
 49 
 50     def getTerminal(self):
 51         return self.terminate_states
 52 
 53     def getGamma(self):
 54         return self.gamma
 55 
 56     def getStates(self):
 57         return self.states
 58 
 59     def getAction(self):
 60         return self.action_s
 61 
 62     def getTerminate_states(self):
 63         return self.terminate_states
 64 
 65     def setAction(self,s):
 66         self.state=s
 67 
 68     def _step(self, action):
 69         #系统当前状态
 70         state = self.state
 71         if state in self.terminate_states:
 72             return state, 0, True, {}
 73         key = "%d_%s"%(state, action)   #将状态和动作组成字典的键值
 74 
 75         #状态转移
 76         if key in self.t:
 77             next_state = self.t[key]
 78         else:
 79             next_state = state
 80 
 81         is_terminal = False
 82         if next_state in self.terminate_states:
 83             is_terminal = True
 84 
 85         if key not in self.rewards:
 86             r = 0.0
 87         else:
 88             r = self.rewards[key]
 89         return next_state, r, is_terminal, {}
 90 
 91     def _reset(self):
 92         self.state = self.states[int(random.random() * len(self.states))]
 93         return self.state
 94 
 95     def _render(self, mode='human', close=False):
 96         if close:
 97             if self.viewer is not None:
 98                 self.viewer.close()
 99                 self.viewer = None
100             return
101 
102         screen_width = 600
103         screen_height = 400
104 
105         if self.viewer is None:
106             from gym.envs.classic_control import rendering
107             self.viewer = rendering.Viewer(screen_width, screen_height)
108             #创建网格世界
109             self.line1 = rendering.Line((100,300),(500,300))
110             self.line2 = rendering.Line((100, 200), (500, 200))
111             self.line3 = rendering.Line((100, 300), (100, 100))
112             self.line4 = rendering.Line((180, 300), (180, 100))
113             self.line5 = rendering.Line((260, 300), (260, 100))
114             self.line6 = rendering.Line((340, 300), (340, 100))
115             self.line7 = rendering.Line((420, 300), (420, 100))
116             self.line8 = rendering.Line((500, 300), (500, 100))
117             self.line9 = rendering.Line((100, 100), (180, 100))
118             self.line10 = rendering.Line((260, 100), (340, 100))
119             self.line11 = rendering.Line((420, 100), (500, 100))
120 
121             #创建第一个骷髅
122             self.kulo1 = rendering.make_circle(40)
123             self.circletrans = rendering.Transform(translation=(140,150))
124             self.kulo1.add_attr(self.circletrans)
125             self.kulo1.set_color(0,0,0)
126 
127             #创建第二个骷髅
128             self.kulo2 = rendering.make_circle(40)
129             self.circletrans = rendering.Transform(translation=(460, 150))
130             self.kulo2.add_attr(self.circletrans)
131             self.kulo2.set_color(0, 0, 0)
132 
133             #创建金条
134             self.gold = rendering.make_circle(40)
135             self.circletrans = rendering.Transform(translation=(300, 150))
136             self.gold.add_attr(self.circletrans)
137             self.gold.set_color(1, 0.9, 0)
138 
139             #创建机器人
140 
141             self.robot= rendering.make_circle(30)
142             self.robotrans = rendering.Transform()
143             self.robot.add_attr(self.robotrans)
144             self.robot.set_color(0.8, 0.6, 0.4)
145 
146             self.line1.set_color(0, 0, 0)
147             self.line2.set_color(0, 0, 0)
148             self.line3.set_color(0, 0, 0)
149             self.line4.set_color(0, 0, 0)
150             self.line5.set_color(0, 0, 0)
151             self.line6.set_color(0, 0, 0)
152             self.line7.set_color(0, 0, 0)
153             self.line8.set_color(0, 0, 0)
154             self.line9.set_color(0, 0, 0)
155             self.line10.set_color(0, 0, 0)
156             self.line11.set_color(0, 0, 0)
157 
158             self.viewer.add_geom(self.line1)
159             self.viewer.add_geom(self.line2)
160             self.viewer.add_geom(self.line3)
161             self.viewer.add_geom(self.line4)
162             self.viewer.add_geom(self.line5)
163             self.viewer.add_geom(self.line6)
164             self.viewer.add_geom(self.line7)
165             self.viewer.add_geom(self.line8)
166             self.viewer.add_geom(self.line9)
167             self.viewer.add_geom(self.line10)
168             self.viewer.add_geom(self.line11)
169             self.viewer.add_geom(self.kulo1)
170             self.viewer.add_geom(self.kulo2)
171             self.viewer.add_geom(self.gold)
172             self.viewer.add_geom(self.robot)
173 
174         if self.state is None: return None
175         #self.robotrans.set_translation(self.x[self.state-1],self.y[self.state-1])
176         self.robotrans.set_translation(self.x[self.state-1], self.y[self.state- 1])
177         return self.viewer.render(return_rgb_array=mode == 'rgb_array')

#### Note : ####

2018.05.17 修整小bug

 


 

posted @ 2018-05-16 17:31  Harris_Li  阅读(262)  评论(0)  编辑  收藏  举报