[RL學習篇][#3] 自動學習grid_mdp最佳的策略

本文修改 policy_iteration.py 程式，讓它可以執行 [#1] 的程式，並找出最佳動作。

 

 1 # /bin/python
 2 import numpy;
 3 import random;
 4 import gym;
 5 #from grid_mdp import Grid_Mdp
 6 
 7 
 8 class Policy_Value:
 9     def __init__(self, grid_mdp):
10         self.v = [0.0 for i in range(len(grid_mdp.env.states) + 1)] # 初始變數v <-- 值函數
11 
12         self.pi = dict()
13         for state in grid_mdp.env.states:
14             if state in grid_mdp.env.terminate_states: continue
15             self.pi[state] = grid_mdp.env.action_s[0] #初始pi <-- 策略pi
16 
17     def policy_improve(self, grid_mdp):
18 
19         for state in grid_mdp.env.states:
20             grid_mdp.env.setAction(state)  # upate state
21             if state in grid_mdp.env.terminate_states: continue
22 
23             a1 = grid_mdp.env.action_s[0]
24             s, r, t, z = grid_mdp.env._step(a1)
25             v1 = r + grid_mdp.env.gamma * self.v[s]
26 
27             for action in grid_mdp.env.action_s:
28                 s, r, t, z = grid_mdp.env._step(action)
29                 if v1 < r + grid_mdp.env.gamma * self.v[s]: # 當action有更好的值,則更新動作
30                     a1 = action
31                     v1 = r + grid_mdp.env.gamma * self.v[s]
32 
33             self.pi[state] = a1   # 紀錄最佳動作
34 
35     def policy_evaluate(self, grid_mdp):
36         for i in range(1000):
37             delta = 0.0
38             for state in grid_mdp.env.states:
39                 grid_mdp.env.setAction(state) # upate state
40                 if state in grid_mdp.env.terminate_states: continue
41                 action = self.pi[state]
42 
43                 s, r, t, z = grid_mdp.env.step(action)
44                 new_v = r + grid_mdp.env.gamma * self.v[s]
45                 delta += abs(self.v[state] - new_v)
46                 self.v[state] = new_v
47 
48             if delta < 1e-6:
49                 break;
50 
51     def policy_iterate(self, grid_mdp):
52         for i in range(100):
53             self.policy_evaluate(grid_mdp);
54             self.policy_improve(grid_mdp);
55 
56 
if __name__ == "__main__":
    env = gym.make('GridWorld-v0')

    policy_value = Policy_Value(env)
    policy_value.policy_iterate(env)

    # Report the non-terminal states, i.e. those that have a policy entry
    # (1-5 in this grid), instead of a hard-coded range(1, 6).
    states = sorted(policy_value.pi)

    print("value:")
    for state in states:
        print("%d:%f\t" % (state, policy_value.v[state]))
    print("")

    print("policy:")
    for state in states:
        print("%d->%s\t" % (state, policy_value.pi[state]))
    print("")

執行結果如下:

-----------------------------------------------------------------------------------------------------------------------------------------------------

/home/lsa-dla/anaconda3/envs/tensorflow/bin/python /home/lsa-dla/PycharmProjects/grid_mdp/lsa_test2.py
WARN: Environment '<class 'gym.envs.classic_control.grid_mdp.GridEnv'>' has deprecated methods. Compatibility code invoked.
value:
1:0.640000
2:0.800000
3:1.000000
4:0.800000
5:0.640000

policy:
1->e
2->e
3->s
4->w
5->w


Process finished with exit code 0

 ------------------------------------------------------------------------------------------------------------------------------------------------------

reference:

[1]  Reinforcement_Learning_Blog/2.强化学习系列之二:模型相关的强化学习/

 

posted @ 2018-05-17 19:44  Harris_Li  阅读(339)  评论(0)  编辑  收藏  举报