七、策略梯度算法
1 简介
之前介绍的 Q-learning、DQN 和 DQN 改进算法都属于基于价值(value-based)的方法,其中 Q-learning 是处理有限状态的算法,而 DQN 是处理连续状态的算法。处理基于价值的方法,还有一种基于策略(policy-based)的方法。对比两者,基于值函数的方法主要是学习值函数,然后根据值函数导出一个策略,学习过程中并不存在一个显式的策略;而基于策略的方法则是直接显式地学习一个目标策略。策略梯度是基于策略的方法的基础,本章从策略梯度算法说起。
2 策略梯度
基于策略的方法首先需要将策略参数化。假设目标策略 \(π_θ\) 是一个随机性策略,并且处处可微,其中 \(θ\) 是对应的参数。我们可以用一个线性模型或者神经网络模型来为这样一个策略函数建模,输入某个状态,然后输出一个动作的概率分布。我们的目标是要寻找一个最优策略并最大化这个策略在环境中的期望回报。我们将策略学习的目标函数定义为
其中,\(s_0\) 表示初始状态。现在有了目标函数,我们将目标函数对策略 \(θ\) 求导,得到导数后,就可以用梯度上升方法来最大化这个目标函数,从而得到最优策略。省略详细的推导过程,求导结果如下


其中,\(η\) 是动作 \(a\) 的状态分布,\(∇_θπ\) 是策略 \(π\) 相对于参数 \(θ\) 的梯度。
这个梯度可以用来更新策略。需要注意的是,因为上式中期望 \(E\) 的下标是 \(π(S,θ)\),所以策略梯度算法为在线策略(on-policy)算法,即必须使用当前策略 \(π(S,θ)\) 采样得到的数据来计算梯度。直观理解一下策略梯度这个公式,可以发现在每一个状态下,梯度的修改是让策略更多地去采样到带来较高 Q 值的动作,更少地去采样到带来较低 Q 值的动作,如图所示。

在计算策略梯度的公式中,我们需要用到 \(q_π(s,a)\),可以用多种方式对它进行估计。接下来要介绍的 REINFORCE 算法便是采用了蒙特卡洛方法来估计 \(q_π(s,a)\),对于一个有限步数的环境来说,REINFORCE 算法中的策略梯度为:
3 REINFORCE
3.1 算法
我们使用的是CartPole-v1环境,需要导入第五章中的rl_utils.py文件
import gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils
class PolicyNet(torch.nn.Module):
def __init__(self, state_dim, hidden_dim, action_dim):
super(PolicyNet, self).__init__()
self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
# softmax()将输出转换为概率分布的形式,针对每个可能的动作输出一个概率值
# dim=1指应用softmax的维度,dim=1为动作维度,确保动作输出的概率和为1
return F.softmax(self.fc2(x), dim=1)
class REINFORCE:
def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)
self.gamma = gamma
self.device = device
def take_action(self, state): # 根据动作概率分布随机采样
state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
probs = self.policy_net(state)
# 根据动作概率创建一个Categorical(多项分布)实例
action_dist = torch.distributions.Categorical(probs)
action = action_dist.sample() # 调用sample()随机采样
return action.item() # 将采样得到的tensor转换回标准数据类型Number
def update(self, transition_dict):
reward_list = transition_dict['rewards']
state_list = transition_dict['states']
action_list = transition_dict['actions']
G = 0
self.optimizer.zero_grad()
for i in reversed(range(len(reward_list))): # 从最后一步算起
reward = reward_list[i]
state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)
action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device)
log_prob = torch.log(self.policy_net(state).gather(1, action))
G = self.gamma * G + reward
loss = -log_prob * G # 每一步的损失函数
loss.backward() # 反向传播计算梯度
self.optimizer.step() # 梯度下降
learning_rate = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env_name = "CartPole-v1"
env = gym.make(env_name, new_step_api=True)
env.reset(seed=0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = REINFORCE(state_dim, hidden_dim, action_dim, learning_rate, gamma, device)
return_list = []
for i in range(10):
with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
for i_episode in range(int(num_episodes / 10)):
episode_return = 0
transition_dict = {
'states': [],
'actions': [],
'next_states': [],
'rewards': [],
'dones': []
}
state = env.reset()
done = False
while not done:
action = agent.take_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
if terminated or truncated:
done = True
transition_dict['states'].append(state)
transition_dict['actions'].append(action)
transition_dict['next_states'].append(next_state)
transition_dict['rewards'].append(reward)
transition_dict['dones'].append(done)
state = next_state
episode_return += reward
return_list.append(episode_return)
agent.update(transition_dict)
if (i_episode + 1) % 10 == 0:
pbar.set_postfix({
'episode':
'%d' % (num_episodes / 10 * i + i_episode + 1),
'return':
'%.3f' % np.mean(return_list[-10:])
})
pbar.update(1)
Iteration 0: 100%|██████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 50.00it/s, episode=100, return=65.700]
Iteration 1: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.07it/s, episode=200, return=142.200]
Iteration 2: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00, 7.94it/s, episode=300, return=190.200]
Iteration 3: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:18<00:00, 5.46it/s, episode=400, return=185.500]
Iteration 4: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00, 5.77it/s, episode=500, return=252.200]
Iteration 5: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:20<00:00, 4.99it/s, episode=600, return=245.500]
Iteration 6: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:24<00:00, 4.15it/s, episode=700, return=319.200]
Iteration 7: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:26<00:00, 3.80it/s, episode=800, return=319.700]
Iteration 8: 100%|█████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00, 2.96it/s, episode=900, return=409.500]
Iteration 9: 100%|████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:19<00:00, 5.23it/s, episode=1000, return=222.900]
在 CartPole-v1 环境中,满分就是 500 分,我们发现 REINFORCE 算法效果比较好,可以达到 405 分。接下来我们绘制训练过程中每一条轨迹的回报变化图。由于回报抖动比较大,往往会进行平滑处理。
接下来绘制训练过程中每一条轨迹的回报变化图。由于回报抖动比较大,往往会进行平滑处理:
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('REINFORCE on {}'.format(env_name))
plt.show()
mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('REINFORCE on {}'.format(env_name))
plt.show()


可以看到,随着收集到的轨迹越来越多,REINFORCE 算法有效地学习到了最优策略。不过,相比于前面的 DQN 算法,REINFORCE 算法使用了更多的序列,这是因为 REINFORCE 算法是一个在线策略算法,之前收集到的轨迹数据不会被再次利用。此外,REINFORCE 算法的性能也有一定程度的波动,这主要是因为每条采样轨迹的回报值波动比较大,这也是 REINFORCE 算法主要的不足。
参考资料
https://hrl.boyuai.com/chapter/2/策略梯度算法/
https://www.bilibili.com/video/BV1sd4y167NS/?p=47&spm_id_from=pageDriver&vd_source=f7563459deb4ecb3add61713c7d5d111
浙公网安备 33010602011771号