强化学习-学习笔记7 | Sarsa算法原理与推导
Sarsa算法 是 TD算法的一种,之前没有严谨推导过 TD 算法,这一篇就来从数学的角度推导一下 Sarsa 算法。注意,这部分属于 TD算法的延申。
7. Sarsa算法
7.1 推导 TD target
推导:Derive。
这一部分就是Sarsa 最重要的内核。
折扣回报:
即 将之后 都提出一个 项,后面括号中的式子意义正为
通常认为奖励 依赖于 t 时刻的状态 与 动作 以及 t+1 时刻的状态 。
当时对于为什么依赖于 有疑问,我回去翻看了 学习笔记1:https://www.cnblogs.com/Roboduster/p/16442003.html ,发现并强调了以下这一点:
“值得注意的是,这个 r1 是什么时候给的?是在状态 state s2 的时候给的。”
状态价值函数 是回报 的期望;
- 用折扣回报的变换式,把替换掉:
- 有两项期望,分解开:
下面研究上式的第二项:
其等于
Q 是 U 的期望:所以 ,期望的期望还是原来的期望;这里是逆用这个性质。这么做是为了让等式两边都有 函数,如下:
于是便得到:
右侧有一个期望,但直接求期望很困难,所以通常是对期望求蒙特卡洛近似。
- 近似为观测到奖励
- 用观测到的 来近似
- 得到蒙特卡洛近似值
- 将这个值表示为 TD target
TD learning 目标:让 来接近部分真实的奖励 。
完全是估计,而 包含了一部分真实奖励,所以 更可靠。
7.2 Sarsa算法过程
这是一种TD 算法。
a. 表格形式
如果我们想要学习动作价值 ,假设状态和动作都是有限的,可以画一个表来表示:
- 表每个元素代表一个动作价值;
- 用 Sarsa 算法更新表格,每次更新一个元素;
-
在表格形式中,每次观测到一个四元组,称为一个 transition
-
根据策略函数 随机采样计算下一个动作,记作;
-
计算TD target: ,
前一部分是观测到的奖励,后面一部分是对未来动作的打分, 可以通过查表得知。
表最开始是通过一定方式初始化的(比如随机),然后通过不断计算来更新表格。
通过查表,还知道的值,可以计算:
-
TD error:;
-
最后用 来更新:,并写入表格相应的位置
是学习率。通过TD error 更新,可以让 Q 更好的接近 。
每一步中,Sarsa 算法用 来更新 ,sarsa,这就是算法名字的由来。
b. 神经网络形式
值得留意的是表格形式的假设:假设状态和动作都是有限的,而当状态和动作很多,表格就会很大,很难学习。
-
用神经网络-价值网络 来近似,Sarsa算法可以训练这个价值网络。
- actor-critic 那篇用过 Sarsa 算法,想不起来往下看:
- q 和 Q 都与 策略函数 有关。
- 网络参数 初始时随机初始化,后续不断更新。
输入状态是 s ,输出就是所有动作的价值
- actor-critic 方法中,q 作为 critic 用来评估 actor;用 sarsa 这一 TD 学习算法更新的价值网络。
- TD target:
- TD error:
- Loss: ,我们的目的是通过更新网络参数 w 来降低 Loss;
- 梯度:
- 梯度下降更新 w:
7.3 一些解惑 / 有什么不同
这一篇跟第二篇价值学习内容看似很接近,甚至在第四篇 actor-critic 中也有提及,可能会困惑 这个第七篇有什么特别的,我也困惑了一会儿,然后我发现是自己的学习不够仔细:
第二篇和第四篇的 价值网络 学习方法并不同。虽然都用到了 以TD target 为代表的TD 算法。但是两者的学习函数并不相同!
Sarsa算法 学习动作价值函数
Actor-Critic 中的价值网络j就是用 Sarsa 训练的
而第二篇 DQN 中的 TD 学习 是训练最优动作价值函数:
而这种方法在下一篇中很快会提及,这就是 Q-learning 方法。
参考:
本文作者:climerecho
本文链接:https://www.cnblogs.com/Roboduster/p/16454245.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步