之前讲到Sarsa和Q Learning都不太适合解决大规模问题,为什么呢?
因为传统的强化学习都有一张Q表,这张Q表记录了每个状态下,每个动作的q值,但是现实问题往往极其复杂,其状态非常多,甚至是连续的,
比如足球场上足球的位置,此时,内存将无力承受这张Q表。
价值函数近似
既然Q表太大,那么怎么办呢?
假设我们可以找到一种方法来预测q值,那么在某个状态下,就可以估计其每个动作的q值,这样就不需要Q表了,这就是价值函数近似。
假设这个函数由参数w描述,那么
状态价值函数就表示为 v(s)≈f(s, w)
动作状态价值函数就表示为 q(s, a)≈f(s, a, w)
当然这个函数可以是线性的wx+b,也可以是决策树,神经网络等等,以神经网络最为常用。
价值函数近似有如下三种形式
1. 输入s,输出v
2. 输入s a,输出q
3.输入s,输出每个动作的q
Deep Q Learning 算法简介
先大概回顾Q Learning算法,初始状态S下ε-贪婪法选择动作,执行,进入状态S’,贪婪选择动作,更新q值,切换状态到S',
Deep Q Learning 大致思路也是如此,只是在某些细节有变化,具体如下
这是正版描述,结合实例翻译大致如下
输入:{S A R γ ε },迭代轮数M,每次迭代次数T,初始的神经网络模型,初始的记忆库(空),batch_size训练样本数,
输出:预测q值的神经网络
for episode = 1:M
初始化状态s
for step = 1:T
### 制造记忆库
用神经网络计算s下所有动作的q值,(相当于Q-Learing中状态s对应的q值)这就是q估计
基于ε-贪婪法选择动作A,执行动作A,获得奖励R,进入状态S’
把 {S A R S'}存入记忆库
s=S'
### 训练神经网络
从记忆库随机取batch_size样本
用神经网络计算S下所有动作的q值,(相当于Q-Learing中状态S对应的q值)这就是q估计
用神经网络计算S’下所有动作的q值,(相当于Q-Learning中状态S’对应的q值)
找到样本中状态S下执行的动作A和奖励R,
计算 q 现实,R+γmaxaf(S',w), (相当于Q-Learning中的 R+γmaxq(S',a)) , 当然如果S’为终止状态,那只有R
利用 (q现实-q估计)2作为损失函数更新神经网络参数
当然,这只是大致思路,具体算法时可以根据经验适当调整。
传统的Q Learning 是边实验边学习,而神经网络需要历史数据,
Deep Q Learning采用记忆库的方式解决这个问题,所有在实际算法中,往往需要先实验几次,以建立记忆库,
Deep Q Learning把这个方法叫 experience replay,经验回放,这里不一定要“亲自”去实验,也可以用“别人”实验的结果作为记忆。
实例
openAI中的例子很多,由于gym环境不能很好地支持windows,故选择了这个例子。
import gym import tensorflow as tf import numpy as np import random from collections import deque # Hyper Parameters for DQN GAMMA = 0.9 # discount factor for target Q INITIAL_EPSILON = 0.5 # starting value of epsilon FINAL_EPSILON = 0.01 # final value of epsilon REPLAY_SIZE = 10000 # experience replay buffer size BATCH_SIZE = 32 # size of minibatch class DQN(): # DQN Agent def __init__(self, env): # init experience replay self.replay_buffer = deque() # 记忆库 # init some parameters self.time_step = 0 # self.epsilon = INITIAL_EPSILON self.state_dim = env.observation_space.shape[0] # 状态 self.action_dim = env.action_space.n # 动作 self.create_Q_network() self.create_training_method() # Init session self.session = tf.InteractiveSession() self.session.run(tf.global_variables_initializer()) def create_Q_network(self): # 神经网络,输入s,输出q value # network weights W1 = self.weight_variable([self.state_dim,20]) b1 = self.bias_variable([20]) W2 = self.weight_variable([20,self.action_dim]) b2 = self.bias_variable([self.action_dim]) # input layer self.state_input = tf.placeholder("float",[None,self.state_dim]) # hidden layers h_layer = tf.nn.relu(tf.matmul(self.state_input,W1) + b1) # Q Value layer self.Q_value = tf.matmul(h_layer,W2) + b2 def create_training_method(self): # 训练方法 self.action_input = tf.placeholder("float",[None,self.action_dim]) # one hot presentation self.y_input = tf.placeholder("float",[None]) Q_action = tf.reduce_sum(tf.multiply(self.Q_value,self.action_input), reduction_indices = 1) # q 估计 self.cost = tf.reduce_mean(tf.square(self.y_input - Q_action)) self.optimizer = tf.train.AdamOptimizer(0.0001).minimize(self.cost) def perceive(self,state,action,reward,next_state,done): # 存储记忆 并 训练网络 one_hot_action = np.zeros(self.action_dim) one_hot_action[action] = 1 self.replay_buffer.append((state,one_hot_action,reward,next_state,done)) if len(self.replay_buffer) > REPLAY_SIZE: # 记忆大于经验回放大小,就删掉之前的 self.replay_buffer.popleft() if len(self.replay_buffer) > BATCH_SIZE: # 记忆大于batch,就开始训练网络 self.train_Q_network() def train_Q_network(self): # 训练网络 self.time_step += 1 # Step 1: obtain random minibatch from replay memory # 随机取样本 minibatch = random.sample(self.replay_buffer,BATCH_SIZE) state_batch = [data[0] for data in minibatch] action_batch = [data[1] for data in minibatch] reward_batch = [data[2] for data in minibatch] next_state_batch = [data[3] for data in minibatch] # Step 2: calculate y y_batch = [] # q 现实 Q_value_batch = self.Q_value.eval(feed_dict={self.state_input:next_state_batch}) # 神经网络预测q值,注意是下个状态的 for i in range(0,BATCH_SIZE): done = minibatch[i][4] if done: # 回合结束 y_batch.append(reward_batch[i]) else : y_batch.append(reward_batch[i] + GAMMA * np.max(Q_value_batch[i])) # 更新q值 self.optimizer.run(feed_dict={ self.y_input:y_batch, self.action_input:action_batch, self.state_input:state_batch }) def egreedy_action(self,state): Q_value = self.Q_value.eval(feed_dict = {self.state_input:[state]})[0] self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON)/10000 if random.random() <= self.epsilon: return random.randint(0,self.action_dim - 1) else: return np.argmax(Q_value) def action(self,state): return np.argmax(self.Q_value.eval(feed_dict = {self.state_input:[state]})[0]) def weight_variable(self,shape): initial = tf.truncated_normal(shape) return tf.Variable(initial) def bias_variable(self,shape): initial = tf.constant(0.01, shape = shape) return tf.Variable(initial) # --------------------------------------------------------- # Hyper Parameters ENV_NAME = 'CartPole-v0' EPISODE = 10000 # Episode limitation STEP = 300 # Step limitation in an episode TEST = 10 # The number of experiment test every 100 episode def main(): # initialize OpenAI Gym env and dqn agent env = gym.make(ENV_NAME) agent = DQN(env) for episode in range(EPISODE): # initialize task state = env.reset() # Train # 生成记忆 for step in range(STEP): action = agent.egreedy_action(state) # e-greedy action for train next_state,reward,done,_ = env.step(action) # Define reward for agent # reward_agent = -1 if done else 0.1 agent.perceive(state,action,reward,next_state,done) state = next_state if done: break # Test every 100 episodes if episode % 100 == 0: total_reward = 0 for i in range(TEST): state = env.reset() for j in range(STEP): env.render() action = agent.action(state) # direct action for test state,reward,done,_ = env.step(action) total_reward += reward if done: break ave_reward = total_reward/TEST print('episode: ',episode,'Evaluation Average Reward:',ave_reward) if ave_reward >= 200: break if __name__ == '__main__': main()
这里做了一些简单优化:
1. 探索率逐渐减小
2. 记忆库不能太大,超过限值就删除最早的记忆,最早的记忆太过久远,参考意义不大,删掉就不会被随机选中。
3. 在记忆库小于batch时,只实验,不训练。
总结
这里虽然解决了大规模问题,但是Deep Q Learning可能出现不收敛的情况,所以产生了很多Deep Q Learning的变种来优化该算法。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)