Q-learning的算法:
(1)先初始化一个Q table,Q table的行数是state的个数,列数是action的个数。
(2)先随机选择一个作为初始状态S1,根据一些策略选择此状态下的动作,比如贪心策略,假设选择的动作为A1。
(3)判断由A1动作之后的状态S2是不是终止状态,如果是终止状态,返回的reward,相当于找到了宝藏,游戏结束,如果不是最终状态,在S2状态时选择此时使Q值最大的action作为下一步的动作。可以得到一个实际的Q值。Q(S1,A1)=R+λ*maxQ(S2)。更新Q table中的Q(S1,A1)。Q(S1,A1)=Q(S1,A1)+α*[R+λ*maxQ(S2)-Q(S1,A1)], []里面是实际的Q值减去估计的Q值。
简单的代码如下:
1 #coding=utf-8 2 import numpy as np 3 import pandas as pd 4 import time 5 #计算机产生一段伪随机数,每次运行的时候产生的随机数都是一样的 6 np.random.seed(2) 7 #创建几个全局变量 8 N_STATES=6#状态的个数,一共有六个状态0-5状态 9 ACTIONS=["left","right"]#action只有两个左和右 10 EPSILON=0.9#贪心策略 11 ALPHA=0.1#学习率 12 LAMBDA=0.9#discount factor 13 MAX_EPISODEs=10#一共训练10次 14 FRESH_TIME=0.1 15 #初始化一个Q-table,我觉得Q-table里面的值初始化成什么样子应该不影响最终的结果 16 def build_q_table(n_states,actions): 17 table=pd.DataFrame( 18 np.zeros((n_states,len(actions))), 19 columns=actions, 20 ) 21 # print(table) 22 return(table) 23 # build_q_table(N_STATES,ACTIONS) 24 def choose_action(state,q_table): 25 state_action=q_table.iloc[state,:] 26 if (np.random.uniform()>EPSILON) or (state_action.all()==0): 27 action_name=np.random.choice(ACTIONS) 28 else: 29 action_name=state_action.idxmax() 30 return action_name 31 def get_env_feedback(s,A): 32 if A=="right": 33 if s==N_STATES-2: 34 s_="terminal" 35 R=1 36 else: 37 s_=s+1 38 R=0 39 else: 40 R=0 41 if s==0: 42 s_=s 43 else: 44 s_=s-1 45 return s_,R 46 def update_env(S,episode,step_couter): 47 env_list=["-"]*(N_STATES-1)+["T"] 48 if S=="terminal": 49 interaction="Episode %s:total_steps=%s"%(episode+1,step_couter) 50 print("\r{}".format(interaction),end='') 51 time.sleep(2) 52 print('\r ',end='') 53 else: 54 env_list[S]='0' 55 interaction=''.join(env_list) 56 print("\r{}".format(interaction),end='') 57 time.sleep(FRESH_TIME) 58 def rl(): 59 #先初始化一个Q table 60 q_table=build_q_table(N_STATES,ACTIONS) 61 for episode in range(MAX_EPISODEs): 62 step_counter=0 63 #选择一个初始的S 64 S=0 65 is_terminal=False 66 update_env(S,episode,step_counter) 67 #如果S不是终止状态的话,选择动作,得到环境给出的一个反馈S_(新的状态)和R(奖励) 68 while not is_terminal: 69 A=choose_action(S,q_table) 70 S_,R=get_env_feedback(S,A) 71 q_predict=q_table.ix[S,A] 72 if S_!="terminal": 73 #算出来实际的Q值 74 q_target=R+LAMBDA*q_table.iloc[S_,:].max() 75 else: 76 q_target=R 77 is_terminal=True 78 q_table.ix[S,A]+=ALPHA*(q_target-q_predict) 79 S=S_ 80 update_env( 81 S,episode,step_counter+1 82 ) 83 step_counter=step_counter+1 84 return q_table 85 86 if __name__=="__main__": 87 q_table=rl() 88 print("\r\nQ-table:\n") 89 print(q_table)