强化学习原理源码解读004:A3C (Asynchronous Advantage Actor-Critic)

目录

  A3C原理

  源码实现

  参考资料


 

 针对A2C的训练慢的问题,DeepMind团队于2016年提出了多进程版本的A2C,即A3C。

A3C原理

 

同时开多个worker,最后会把所有的经验集合在一起

一开始有一个全局的网络,假设参数是θ1

每一个worker使用一个cpu去跑,工作之前就把全局的参数拷贝过来

每一个actor和环境做互动,为了收集到各种各样的数据,制定策略收集比较多样性的数据

计算梯度

更新全局的参数为θ2

所有的actor都是并行的

可以再开一个进程用于测试全局模型的表现

 返回目录

 

源码实现


import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import time
import matplotlib.pyplot as plt
# Hyperparameters
n_train_processes = 3
learning_rate = 0.0002
update_interval = 5
gamma = 0.98
max_train_ep = 300
max_test_ep = 400


class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)

    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v


def train(global_model, rank):
    local_model = ActorCritic()
    local_model.load_state_dict(global_model.state_dict())

    optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)

    env = gym.make('CartPole-v1')

    for n_epi in range(max_train_ep):
        done = False
        s = env.reset()
        while not done:
            s_lst, a_lst, r_lst = [], [], []
            for t in range(update_interval):
                prob = local_model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)

                s_lst.append(s)
                a_lst.append([a])
                r_lst.append(r/100.0)

                s = s_prime
                if done:
                    break

            s_final = torch.tensor(s_prime, dtype=torch.float)
            R = 0.0 if done else local_model.v(s_final).item()
            td_target_lst = []
            for reward in r_lst[::-1]:
                R = gamma * R + reward
                td_target_lst.append([R])
            td_target_lst.reverse()

            s_batch, a_batch, td_target = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
                torch.tensor(td_target_lst)
            advantage = td_target - local_model.v(s_batch)

            pi = local_model.pi(s_batch, softmax_dim=1)
            pi_a = pi.gather(1, a_batch)
            loss = -torch.log(pi_a) * advantage.detach() + \
                F.smooth_l1_loss(local_model.v(s_batch), td_target.detach())

            optimizer.zero_grad()
            loss.mean().backward()
            for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
                global_param._grad = local_param.grad
            optimizer.step()
            local_model.load_state_dict(global_model.state_dict())

    env.close()
    print("Training process {} reached maximum episode.".format(rank))


def test(global_model):
    env = gym.make('CartPole-v1')
    score = 0.0
    print_interval = 20
    x = []
    y = []

    for n_epi in range(max_test_ep):
        done = False
        s = env.reset()
        while not done:
            prob = global_model.pi(torch.from_numpy(s).float())
            a = Categorical(prob).sample().item()
            s_prime, r, done, info = env.step(a)
            s = s_prime
            score += r

        if n_epi % print_interval == 0 and n_epi != 0:
            print("# of episode :{}, avg score : {:.1f}".format(
                n_epi, score/print_interval))
            x.append(n_epi)
            y.append(score / print_interval)
            score = 0.0
            time.sleep(1)
    env.close()
    plt.plot(x, y)
    plt.savefig('pic_saved/res_A3C.jpg')
    plt.show()


if __name__ == '__main__':
    global_model = ActorCritic()
    global_model.share_memory()

    processes = []
    for rank in range(n_train_processes + 1):  # + 1 for test process
        if rank == 0:
            p = mp.Process(target=test, args=(global_model,))
        else:
            p = mp.Process(target=train, args=(global_model, rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
View Code

 

横坐标表示训练轮数,纵坐标表示智能体得分的能力(满分500分),可以看到A3C在较短的时间内就能达到满分的水平,效果确实不错。

 返回目录

 

参考资料

https://github.com/seungeunrho/minimalRL

https://www.bilibili.com/video/BV1UE411G78S?from=search&seid=10996250814942853843

paper:Actor-Critic Algorithms

paper:Asynchronous Methods for Deep Reinforcement Learning

 返回目录

 

posted @ 2020-10-01 22:23  黎明程序员  阅读(925)  评论(0编辑  收藏  举报