强化学习代码实战-07 Actor-Critic 算法
Actor(策略网络)和 Critic(价值网络)
- Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
- Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。
import random import gym import torch import numpy as np from matplotlib import pyplot as plt from IPython import display env = gym.make("CartPole-v0") # 智能体状态 state = env.reset() # 动作空间 actions = env.action_space.n print(state, actions) # Actor使用策略梯度更新(接收状态,输出策略),Critic使用价值函数更新(接收状态,输出价值) actor_model = torch.nn.Sequential(torch.nn.Linear(4, 128), torch.nn.ReLU(), torch.nn.Linear(128, 2), torch.nn.Softmax(dim=1)) critic_model = torch.nn.Sequential(torch.nn.Linear(4, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1)) def get_action(state): state = torch.FloatTensor(state).reshape(1,4) prob = actor_model(state) action = random.choices(range(2), weights=prob[0].tolist(), k=1)[0] return action def get_data(): state = env.reset() states = [] actions = [] rewards = [] next_states = [] dones = [] done = False while not done: action = get_action(state) next_state, reward, done, _ = env.step(action) states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(done) state = next_state states = torch.FloatTensor(states).reshape(-1, 4) rewards = torch.FloatTensor(rewards).reshape(-1, 1) actions = torch.LongTensor(actions).reshape(-1, 1) next_states = torch.FloatTensor(next_states).reshape(-1, 4) dones = torch.LongTensor(dones).reshape(-1, 1) return states, rewards, actions, next_states, dones def test(): state = env.reset() rewards_sum = 0 done = False while not done: action = get_action(state) state, reward, done, _ = env.step(action) rewards_sum += reward return rewards_sum def train(): optimizer = torch.optim.Adam(actor_model.parameters(), lr=1e-3) optimizer_td = torch.optim.Adam(critic_model.parameters(), lr=1e-2) # 玩N局游戏,每局游戏训练一次 for epoch in range(1000): states, rewards, actions, next_states, dones = get_data() # 分batch优化 current_values = critic_model(states) next_state_values = critic_model(next_states) * 0.98 next_state_values *= (1 - dones) next_values = rewards + next_state_values # 时序差分误差.单纯使用值,不反向传播梯度. detach:阻断反向梯度传播 delta = (next_values - current_values).detach() # actor重新评估动作计算得分 probs = actor_model(states) probs = probs.gather(dim=1, index=actions) actor_loss = (-probs.log() * delta).mean() # 时序差分loss。均方误差 critic_loss = torch.nn.MSELoss()(current_values, next_values.detach()) optimizer.zero_grad() actor_loss.backward() optimizer.step() optimizer_td.zero_grad() critic_loss.backward() optimizer_td.step() if epoch % 100 == 0: result = sum([test() for _ in range(50)]) / 50 print(epoch, result)
时刻记着自己要成为什么样的人!