Fork me on GitHub

用 Java 训练出一只“不死鸟”

作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手机游戏,因其简单的玩法但极度困难的设定迅速走红全网。随着深度学习(DL)与增强学习(RL)等前沿算法的发展,我们可以使用 Java 非常方便地训练出一个智能体来控制 Flappy Bird。

故事开始于《GitHub 上的大佬们打完招呼,会聊些什么?》,那么,今天我们就来一起看一下如何用 Java 训练出一个不死鸟。游戏项目我们使用了一个仅用 Java 基本类库编写的 FlappyBird 游戏。在训练方面,我们使用 DeepJavaLibrary 一个基于 Java 的深度学习框架来构建增强学习训练网络并进行训练。经过了差不多 300 万步(四小时)的训练后,小鸟已经可以获得最高 8000 多分的成绩,灵活穿梭于水管之间。

在本文中,我们将从原理开始一步一步实现增强学习算法并用它对游戏进行训练。如果任何一个时刻不清楚如何继续进行下去,可以参阅项目的源码。

项目地址:https://github.com/kingyuluk/RL-FlappyBird

增强学习(RL)的架构

在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。

CNN 训练简述

CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2) 的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。

训练数据

在小鸟采取动作后,我们会得到 preObservation and currentObservation 即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal 组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。

public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}

训练的三个周期

训练分为 3 个不同的周期以更好地生成训练数据:

  • Observe(观察) 周期:随机产生训练数据
  • Explore (探索) 周期:随机与推理动作结合更新训练数据
  • Training (训练) 周期:推理动作主导产生新数据

通过这种训练模式,我们可以更好的达到预期效果。

处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。

public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else return baseAgent.chooseAction(env, training);
}

训练逻辑

首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));

postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:

// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length]; 
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。

卷积神经网络模型(CNN)

我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

训练过程

DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。

  • 实现 RlAgent 接口即可构建一个可以进行训练的智能体。
  • 在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。
  • 创建 ReplayBuffer 可以存储并动态更新训练数据。

在实现这些接口后,只需要调用 step 方法:

RlEnv.step(action, training);

这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment 方法中调用 step 方法,然后只需要重复执行 runEnvironment 方法,即可不断地生成用于训练的数据。

public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}

我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。

在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。

神经网络输入预处理

首先将图像大小 resize 成 80x80 并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。

public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                    .toNDArray(NDManager.newBaseManager(),
                     Image.Flag.GRAYSCALE) ,80,80));
}

然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}

一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
    callables.add(new TrainerCallable(model, agent));
}

总结

这个模型在 NVIDIA T4 GPU 训练了大概 4 个小时,更新了 300 万步。训练后的小鸟已经可以完全自主控制动作灵活穿梭与管道之间。训练后的模型也同样上传到了仓库中供您测试。在此项目中 DJL 提供了强大的训练 API 以及模型库支持,使得在 Java 开发过程中得心应手。

本项目完整代码:https://github.com/kingyuluk/RL-FlappyBird

posted @ 2020-12-23 09:08  削微寒  阅读(1566)  评论(1编辑  收藏  举报