强化学习代码实战-04时序差分算法(N步Sarsa)

import numpy as np
import random

# 获取一个格子的状态
def get_state(row, col):
    if row != 3:
        return 'ground'

    if row == 3 and col == 0:
        return 'ground'

    if row == 3 and col == 11:
        return 'terminal'

    return 'trap'

# 根据当前所处的格子,选取一个动作
def get_action(row, col):
    # 以一定的概率探索
    if random.random() < 0.1:
        return random.choice(range(4))
    
    # 返回当前Q表格中分数最高的动作
    return Q[row, col].argmax()

# 在某一状态下执行动作,获得对应奖励
def move(row, col, action):
    #如果当前已经在陷阱或者终点,则不能执行任何动作
    if get_state(row, col) in ['trap', 'terminal']:
        return row, col, 0

    #
    if action == 0:
        row -= 1

    #
    if action == 1:
        row += 1

    #
    if action == 2:
        col -= 1

    #
    if action == 3:
        col += 1

    #不允许走到地图外面去
    row = max(0, row)
    row = min(3, row)
    col = max(0, col)
    col = min(11, col)

    #是陷阱的话,奖励是-100,否则都是-1
    reward = -1
    if get_state(row, col) == 'trap':
        reward = -100

    return row, col, reward

# 初始化Q表格,每个格子采取每个动作的分数,刚开始都是未知的故为零
Q = np.zeros([4, 12, 4])

# 存储历史的状态、动作和奖励,后期要回溯这些历史数据
state_list = []
action_list = []
reward_list = []

# 获取5个时间步的更新分数
def get_update_list(next_row, next_col, next_action):
    #初始化的target是最后一个state和最后一个action的分数
    target = Q[next_row, next_col, next_action]

    #计算每一步的target
    #每一步的tagret等于下一步的tagret*0.9,再加上本步的reward
    #时间从后往前回溯,越以前的tagret会累加的信息越多
    #[4, 3, 2, 1, 0]
    target_list = []
    for i in reversed(range(5)):
        target = 0.9 * target + reward_list[i]
        target_list.append(target)

    #把时间顺序正过来
    target_list = list(reversed(target_list))

    #计算每一步的value
    value_list = []
    for i in range(5):
        row, col = state_list[i]
        action = action_list[i]
        value_list.append(Q[row, col, action])


    #计算每一步的更新量
    update_list = []
    for i in range(5):
        #根据时序差分算法,当前state,action的分数 = 下一个state,action的分数*gamma + reward
        #此处是求两者的差,越接近0越好
        update = target_list[i] - value_list[i]

        #这个0.1相当于lr
        update *= 0.1

        update_list.append(update)

    return update_list

def train():
    for epoch in range(1500):
        # 初始化起点及动作
        row = random.choice(range(4))
        col = 0
        action = get_action(row, col)
        
        # 初始化历史存储
        state_list.clear()
        action_list.clear()
        reward_list.clear()
        
        rewards_sum = 0
        
        # 不停迭代更新直至进入陷阱或结束
        while get_state(row, col) not in ["terminal", "trap"]:
            # 一步移动,获取下一状态信息
            next_row, next_col, reward = move(row, col, action)
            next_action = get_action(next_row, next_col)
            rewards_sum += reward
            # 记录历史数据
            state_list.append([row, col])
            action_list.append(action)
            reward_list.append(reward)
            # 积累到5步之后再更新参数
            if len(state_list) == 5:
                update_list = get_update_list(next_row, next_col, next_action)
                # 每次只更新第一步分数,修正Q
                row, col = state_list[0]
                action = action_list[0]
                update = update_list[0]
                Q[row, col, action] += update
                # 更新过了,移除第一步数据,为之后的数据保持新空间
                state_list.pop(0)
                action_list.pop(0)
                reward_list.pop(0)   # 奖励pop,不是update_list
            # 状态更新
            row = next_row
            col = next_col
            action = next_action
        # 结束后历史还有数据,将其更新完
        for i in range(len(state_list)):
            row, col = state_list[i]
            action = action_list[i]
            update = update_list[i]
            Q[row, col, action] += update
        if epoch % 100 == 0:
            print(f"epoch:{epoch}, reward:{rewards_sum}")
        

#打印游戏,方便测试
def show(row, col, action):
    graph = [
        '', '', '', '', '', '', '', '', '', '', '', '', '', '',
        '', '', '', '', '', '', '', '', '', '', '', '', '', '',
        '', '', '', '', '', '', '', '', '', '', '', '', '', '',
        '', '', '', '', '', ''
    ]

    action = {0: '', 1: '', 2: '', 3: ''}[action]

    graph[row * 12 + col] = action

    graph = ''.join(graph)

    for i in range(0, 4 * 12, 12):
        print(graph[i:i + 12])
show(1,1,2)

from IPython import display
import time


def play():
    #起点
    row = random.choice(range(4))
    col = 0

    #最多玩N步
    for _ in range(200):

        #获取当前状态,如果状态是终点或者掉陷阱则终止
        if get_state(row, col) in ['trap', 'terminal']:
            break

        #选择最优动作
        action = Q[row, col].argmax()

        #打印这个动作
        display.clear_output(wait=True)
        time.sleep(0.1)
        show(row, col, action)

        #执行动作
        row, col, reward = move(row, col, action)


play()

 重点:结合N步的reward和最后一个状态的价值估计,玩N步会后更新第一步的参数

posted @ 2022-11-11 15:18  今夜无风  阅读(102)  评论(0编辑  收藏  举报