【Python】Q-Learning处理CartPole-v1

上一篇配置成功gym环境后,就可以利用该环境做强化学习仿真了。

这里首先用之前学习过的qlearning来处理CartPole-v1模型。

CartPole-v1是一个倒立摆模型,目标是通过左右移动滑块保证倒立杆能够尽可能长时间倒立,最长步骤为500步。

模型控制量是左0、右1两个。

模型状态量为下面四个:

Num Observation Min Max
0 Cart Position -4.8 4.8
1 Cart Velocity -Inf Inf
2 Pole Angle -0.418rad 0.418rad
3 Pole Angular Velocity -Inf Inf

由于要用qtable,但是状态量是连续的,所以我们要先对状态做离散化处理,对应state_dig函数。

然后就是按照qlearning公式迭代即可。

这里在选控制量的时候用了衰减ε-greedy策略,即根据迭代次数,逐步更相信模型的结果而不是随机的结果。

qlearning走迷宫当时的ε-greedy有10%的概率用随机的控制量,衰减ε-greedy策略相对更合理一些。

代码如下:

import gym
import random
import numpy as np

Num = 10
rate = 0.5
factor = 0.9

p_bound = np.linspace(-2.4,2.4,Num-1)
v_bound = np.linspace(-3,3,Num-1)
ang_bound = np.linspace(-0.5,0.5,Num-1)
angv_bound = np.linspace(-2.0,2.0,Num-1)

def state_dig(state):                   #离散化
    p,v,ang,angv = state
    digital_state = (np.digitize(p, p_bound),
            np.digitize(v, v_bound),
            np.digitize(ang, ang_bound), 
            np.digitize(angv, angv_bound))
    return digital_state

if __name__ == '__main__':

    env = gym.make('CartPole-v1')

    action_space_dim = env.action_space.n  
    q_table = np.zeros((Num,Num,Num,Num, action_space_dim))

    for i in range(3000):
        state = env.reset()
        digital_state = state_dig(state)
                
        step = 0
        while True:
            if i%10==0:
                env.render()
            
            step +=1
            epsi = 1.0 / (i + 1)
            if random.random() < epsi:
                action = random.randrange(action_space_dim)
            else:
                action = np.argmax(q_table[digital_state])

            next_state, reward, done, _ = env.step(action)
            next_digital_state = state_dig(next_state)
  
            if done: 
                if step < 400:
                    reward = -1  
                else:   
                    reward = 1
            else:
                reward = 0

            current_q = q_table[digital_state][action]      #根据公式更新qtable
            q_table[digital_state][action] += rate * (reward + factor * max(q_table[next_digital_state])  - current_q) 

            digital_state = next_digital_state

            if done:
                print(step)
                break

最终结果基本都能维持到500步左右,不过即使到500后,随着模型迭代,状态也可能不稳定。

posted @ 2024-04-29 21:08  Dsp Tian  阅读(175)  评论(0编辑  收藏  举报