强化学习之Sarsa (时间差分学习)
上篇文章讲到Q-learning, Sarsa与Q-learning的在决策上是完全相同的,不同之处在于学习的方式上
这次我们用openai gym的Taxi来做演示
Taxi是一个出租车的游戏,把顾客送到目的地+20分,每走一步-1分,如果在路上把乘客赶下车的话扣10分
简要
Sarsa是一种在线学习算法,也就是on-polic,Sarsa在每次更新算法时都是基于确定的action,而Q-learning还没有确定
Sarsa相对比较保守,他的每一步行动都是基于下一个Q(s',a')来完成的
我们来看Sarsa的算法部分
是不是看起来很眼熟,没错和Q-learning的区别很小
Q-learning每次都时action'都选择最大化,而Sarsa每次更新都会选择下一个action,在我们对代码中对应的代码也就是
obervation_, reward, done, info=env.step(action)
action_=choise(obervation_)
游戏开始
首先我们初始化游戏环境
import gym import numpy as np env=gym.make('Taxi-v2') env.seed(1995) MAX_STEP=env.spec.timestep_limit ALPHA=0.01 EPS=1 GAMMA=0.8
TRACE_DACAY=0.9
q_table=np.zeros([env.observation_space.n,env.action_space.n],dtype=np.float32)
eligibility_trace=np.zeros([env.observation_space.n,env.action_space.n],dtype=np.float32)
对没错,Sarsa还是需要Q表来保存经验的,细心的小伙伴们一定发现我们多了一个eligibility_trace的变量,这个是做什么用的呢,这个是用来保存每个回合的每一步的,在新的回合开始后就会清零
Sarsa的决策上还是和Q-learning相同的
def choise(obervation): if np.random.uniform()<EPS: action=env.action_space.sample() else: action=np.argmax(q_table[obervation]) return action
下面是我们的核心部分,就是学习啦^_^
#这里是Q-learning的学习更新部分
def learn(state,action,reward,obervation_): q_table[state][action]+=ALPHA*(reward+GAMMA*(max(q_table[obervation_])-q_table[state,action]))
#这里是Sarsa的学习更新部分
def learn(state,action,reward,obervation_,action_): global q_table,eligibility_trace error=reward + GAMMA * q_table[obervation_,action_] - q_table[state, action] eligibility_trace[state]*=0 eligibility_trace[state][action]=1 q_table+=ALPHA*error*eligibility_trace eligibility_trace*=GAMMA*TRACE_DACAY
哒当,我用红线标示出来了,聪明的你一定发现了不同对吧
青色标示出来的代表的意思是没经历一轮,我们让他+1证明这是获得reward中不可获取的一步
最后一行
eligibility_trace*=GAMMA*TRACE_DACAY
随着时间来衰减eligibility_trace的值,离获取reward越远的步,他的必要性也就越小
GAME OVER
让我们大干一场吧
下面是所有的代码,小伙伴们快来运行把
import gym import numpy as np env=gym.make('Taxi-v2') env.seed(1995) MAX_STEP=env.spec.timestep_limit ALPHA=0.01 EPS=1 GAMMA=0.8 TRACE_DACAY=0.9 q_table=np.zeros([env.observation_space.n,env.action_space.n],dtype=np.float32) eligibility_trace=np.zeros([env.observation_space.n,env.action_space.n],dtype=np.float32) def choise(obervation): if np.random.uniform()<EPS: action=env.action_space.sample() else: action=np.argmax(q_table[obervation]) return action def learn(state,action,reward,obervation_,action_): global q_table,eligibility_trace error=reward + GAMMA * q_table[obervation_,action_] - q_table[state, action] eligibility_trace[state]*=0 eligibility_trace[state][action]=1 q_table+=ALPHA*error*eligibility_trace eligibility_trace*=GAMMA*TRACE_DACAY SCORE=0 for exp in xrange(50000): obervation=env.reset() EPS-= 0.001 action=choise(obervation) eligibility_trace*=0 for i in xrange(MAX_STEP): # env.render() obervation_, reward, done, info=env.step(action) action_=choise(obervation_) learn(obervation,action,reward,obervation_,action_) obervation=obervation_ action=action_ SCORE+=reward if done: break if exp % 1000 == 0: print 'esp,score (%d,%d)' % (exp, SCORE) SCORE = 0 print 'fenshu is %d'%SCORE
欢迎大家一起来学习^_^
最后附上一幅结果图
效率明显提高了^_^
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· 展开说说关于C#中ORM框架的用法!
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?