强化学习代码实战-03动态规划算法(策略迭代)
# 获取一个格子的状态 def get_state(row, col): if row!=3: return 'ground' if row == 3 and col == 11: return 'terminal' if row == 3 and col == 0: return 'ground' return 'trap' # 在某一状态下执行动作,获得对应奖励 def move(row, col, action): # 状态检查-进入陷阱或结束,则不能执行任何动作,获得0奖励 if get_state(row, col) in ["trap", "terminal"]: return row, col, 0 # 执行上下左右动作后,对应的位置变化 if action == 0: row -= 1 if action == 1: row += 1 if action == 2: col -= 1 if action == 3: col += 1 # 最小不能小于零,最大不能大于3 row = max(0, row) row = min(3, row) col = max(0, col) col = min(11, col) # 掉进trap奖励-100,其余每走一步奖励-1,让agent尽快完成任务 reward = -1 if get_state(row, col) == 'trap': reward = -100 return row, col, reward move(0, 0, 0) # 初始化每个格子的价值,最开始我们不知道价值情况,故都设置为零 values = np.zeros([4, 12]) # 初始化每个格子采用动作的概率(策略),未知,故平均概率 pi = np.ones([4, 14, 4]) * 0.25 values, pi[0] # 计算在一个状态下执行相应动作的得分,走一步(奖励+折扣回报) def get_qsa(row, col, action): """动作价值""" # 在当前状态下执行动作,获取下一个状态和reward next_row, next_col, reward = move(row, col, action) # 查找Q表格(values)计算下一状态的分数,并打一个折扣 value = values[next_row, next_col] * 0.9 # 如果下一个状态是陷阱或者重点,则下一个状态的分数是0,游戏结束没有得分 if get_state(next_row, next_col) in ["trap", "terminal"]: value = 0 return reward + value # 策略评估(用格子的得分多少评估策略的好坏) def get_values(): # 初始化新的values,重新评估所有格子的分数 new_values = np.zeros([4, 12]) # 遍历所有的格子 for row in range(4): for col in range(12): # 计算当前格子每个动作的分数 action_value = np.zeros(4) for action in range(4): action_value[action] = get_qsa(row, col, action) # 动作的得分乘以对应策略概率 action_value *= pi[row, col] # 期望 new_values[row, col] = action_value.sum() return new_values # 策略提升 def get_pi(): # 初始新的每个格子下采用动作的概率,重新评估 new_pi = np.zeros([4, 12, 4]) # 遍历 for row in range(4): for col in range(12): # 格子每个动作的分数 action_value = np.zeros(4) for action in range(4): action_value[action] = get_qsa(row, col, action) # 贪婪方法找到最高得分的策略,并统计个数 count = (action_value == action_value.max()).sum() # 让这些动作均分概率,其余赋值为零.理论证明会收敛 for action in range(4): if action_value[action] == action_value.max(): new_pi[row, col, action] = 1 / count else: new_pi[row, col, action] = 0 return new_pi # 循环迭代策略评估和策略提升,寻找最优解 for _ in range(10): for _ in range(100): values = get_values() pi = get_pi() values, pi
文献参考:https://hrl.boyuai.com/chapter/1/%E5%8A%A8%E6%80%81%E8%A7%84%E5%88%92%E7%AE%97%E6%B3%95
时刻记着自己要成为什么样的人!