强化学习-表格型算法Q学习稳定倒立摆小车

[[Q 学习]] 是表格型算法的一种,主要维护了一个 Q-table,里面是 状态-动作 对的价值,分别由一个状态和一个动作来索引。

这里以一个经典的道理摆小车问题来说明如何使用 [[Q 学习]] 算法。
这里会用到两个类,agentbrainbrain 类中来维护 [[强化学习的基本概念|强化学习]] 算法的具体执行,agent 是一层封装,以后也可以用其他算法来实现 brain 类。整个的逻辑也可以参考[[强化学习基本程序框架]]。
首先是 agent

class Agent():
    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)

    def update_Q_fun(self, observation, reward, action, next_observation):
        self.brain.update_Q_table( observation, reward, action, next_observation)
        
    def get_action(self, observation,step):
        action = self.brain.decide_action(observation, step)
        return action
        

其中 get_action 就是根据状态选择一个动作,可以不放到 brain 类里面,一般都是 \(\epsilon\) -贪心算法在动作空间里面选动作。update_Q_fun 用来更新 Q-table,如果是其他算法,比如说 [[DQN]],换个名字就行。

然后是 brain

class Brain():
    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions
        self.Q_table = np.random.uniform(low=0, high=1, size=(NUM_DIZITIZED**num_states, num_actions))
    
    def bins(self,clip_min, clip_max, num ):
        return bins(clip_min, clip_max, num)
    
    def digitize_state(self,observation) :
        cart_pos, cart_v, pole_angle, pole_v = observation
        digitized = [
        np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),
        np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)) ,
        np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)) ,
        np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED) )
    ]
        return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])
    
    def update_Q_table(self, observation, reward, action, next_observation):
        state = self.digitize_state(observation) 
        state_next = self.digitize_state(observation_next)
        Max_Q_next = np.max(self.Q_table[state_next][:])
        self.Q_table[state,action] = self.Q_table[state,action] + ETA * (reward + GAMMA * Max_Q_next - self.Q_table[state,action])
        
    def decide_action(self, observation,episode):  
        state = self.digitize_state(observation)
        epsilon = 0.5 * (1 / (episode + 1))
        
        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.Q_table[state][:])
        else:
            action = np.random.choice(self.num_actions)
        return action

update_Q_table 就是根据时序差分的公式更新 Q-table。

\[Q(s_t,a_t)\leftarrow Q(s_t,a_t)+\alpha[R_t+\gamma\max_aQ(s_{t+1},a)-Q(s_t,a_t)] \]

其中,\(\alpha\) 是学习率,\(\gamma\) 是奖励累积的折扣系数。如果这里的 \(\max_aQ(s_{t+1},a)\) 换成 \(Q(s_{t+1},a_{t+1})\) 的话,就是 [[sarsa 算法]]。
decide_action 就是前面提到的 \(\epsilon\) -贪心算法选取动作,这里的 \(\epsilon\) 是随 episode 的数量衰减的。
digitize_state 是为了处理连续状态的。因为倒立摆小车的位置、速度、杆的角度这些信息是连续变量(尽管是在计算机中仿真,我们也认为是连续的),所以为了能在表格中维护,需要将状态进行离散化处理,比如位置在什么范围内就认为其状态是 1。为了减少内存的占用,示例里 NUM_DIZITIZED 等于 6,意思是只用 6 个数来划分表示单一维度里面的连续区间的状态。实际上,如果状态空间任一维度都很大或者状态空间本身就是连续的,后面会有 [[DQN]] 等算法可以处理。

仿真代码:

frames=[]
#环境初始化
env=gym.make('CartPole-v0')
observation = env.reset()#需要先重置环境

NUM_DIZITIZED = 6
GAMMA=0.99   # 时间折扣率
ETA=0.5       # 学习系数
MAX_STEPS=200
NUM_EPISODES = 200

agent = Agent(6,2)
complete_episodes = 0 
is_episode_final = False 

for episode in range(NUM_EPISODES):
    observation = env.reset()

    for step in range(0,MAX_STEPS):
        if is_episode_final:
            frames.append(env.render(mode='rgb_array')) #将各个时刻的图像添加到帧中
        
        action = agent.get_action (observation, episode)
        observation_next, _, done, _ = env.step(action)

		# 自定义的奖励部分
		# 如果结束的时候,已经稳定了190步,就给1的奖励,否则-1.没结束的时候奖励是0
        if done: 
            if step < 190:
                reward = -1 
                complete_episode = 0 
            else:
                reward = 1
                complete_episodes += 1 
        else:
            reward = 0 
        
        agent.update_Q_fun(observation,reward,action,observation_next)
        
        observation= observation_next
        
        if done:
            print(f'{episode} Episode: Finished after {step + 1} time steps')
            break
            
    if complete_episodes >= 10:
        print('10回合连续成功')
        is_episode_final = True
        
display_frames_as_gif(frames)

More Reading

[[边做边学深度强化学习:PyTorch程序设计实践]]

Reference

posted @ 2024-07-07 21:46  pomolnc  阅读(17)  评论(0编辑  收藏  举报