Flappy Bird is a mobile game released in 2013 that became hugely popular thanks to its simple controls: the player either flaps or does nothing. With the growth of Deep Learning (DL) and Reinforcement Learning (RL), we can now train an AI agent to control the bird's actions. Today, we will walk through the process of creating such an agent in Java. For the game itself, we used a simple open-source Flappy Bird game written in Java. For training, we used Deep Java Library (DJL), a Java-based deep learning framework, to build the training network and the RL algorithm.

In this article, we will start from the basics of RL and walk through the key components of the training architecture. If at any time you cannot follow our code, or would like to try the game yourself, you can refer to our RL repo.

The RL Architecture

In this section, we introduce the main algorithm and network we used, to help you understand how we trained the model. This project takes a similar approach to DeepLearningFlappyBird, a Python Flappy Bird RL implementation. The RL algorithm is Q-learning, with a Convolutional Neural Network (CNN) approximating the Q-function, which estimates how much future reward each action is worth in a given state. At each game step, we store the current state of the bird, the action the agent took, the reward it received, and the next state of the bird. These records are the training data for the CNN.
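To make the Q-learning part concrete, here is a minimal sketch, in plain Java rather than DJL, of the target value that a stored step is trained toward. The class name QTarget and the discount factor gamma are illustrative assumptions, not names taken from the project.

```java
/**
 * Hypothetical helper (not from the original project) computing the Q-learning
 * target for one stored step:
 *   target = reward                                  if the step ended the game
 *   target = reward + gamma * max over a' of Q(next, a')   otherwise
 */
final class QTarget {

    private QTarget() {}

    /**
     * @param reward   reward received for the action taken
     * @param terminal whether the action ended the game
     * @param nextQ    the network's Q-value estimates for the next observation, one per action
     * @param gamma    discount factor in [0, 1] (an assumed hyperparameter)
     */
    static float compute(float reward, boolean terminal, float[] nextQ, float gamma) {
        if (terminal) {
            // No future reward can follow a terminal step
            return reward;
        }
        // Value of the best action the agent could take from the next observation
        float best = Float.NEGATIVE_INFINITY;
        for (float q : nextQ) {
            best = Math.max(best, q);
        }
        return reward + gamma * best;
    }
}
```

For example, with gamma = 0.5, a reward of 0.1 and next-state estimates of 1.0 and 2.0, the target is 0.1 + 0.5 × 2.0 = 1.1.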

CNN Training Overview

The input for training is a sequence of four consecutive frames. We stack these four images to form an “observation” of the bird; in other words, the observation is time-series data represented as a series of images. Each image is converted to grayscale to reduce the training load. The array representation of a batch of observations is (batch size, 4 (frames), 80 (width), 80 (height)), where each element is the pixel value at that position in that frame. The CNN maps this input to an output of shape (batch size, 2). The second dimension holds one estimated Q-value for each of the two possible next actions (flap and no-flap); the higher value marks the action the network prefers.
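The observation format can be sketched without DJL. The FrameStack class below is hypothetical: it assumes each frame has already been resized to 80x80 and arrives as packed 0xRRGGBB pixels, and the grayscale weights (ITU-R BT.601) and the choice to repeat the first frame at the start of a game are our assumptions rather than details from the project. The returned array is indexed as [frame][row][column], oldest frame first.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Hypothetical sketch of building a (4, 80, 80) observation: each frame is
 * converted to grayscale and the four most recent frames are stacked.
 */
final class FrameStack {
    static final int FRAMES = 4;
    static final int SIZE = 80;

    // Most recent frames, oldest at the head of the deque
    private final Deque<float[][]> frames = new ArrayDeque<>();

    /** Converts an 80x80 grid of packed 0xRRGGBB pixels to grayscale values in [0, 1]. */
    static float[][] toGray(int[][] rgb) {
        float[][] gray = new float[SIZE][SIZE];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                int p = rgb[y][x];
                int r = (p >> 16) & 0xFF;
                int g = (p >> 8) & 0xFF;
                int b = p & 0xFF;
                // ITU-R BT.601 luma weights, normalized to [0, 1]
                gray[y][x] = (0.299f * r + 0.587f * g + 0.114f * b) / 255f;
            }
        }
        return gray;
    }

    /** Adds a new frame and returns the stacked observation. */
    float[][][] push(int[][] rgbFrame) {
        float[][] gray = toGray(rgbFrame);
        if (frames.isEmpty()) {
            // At the start of a game, repeat the first frame to fill the stack
            for (int i = 0; i < FRAMES; i++) {
                frames.addLast(gray);
            }
        } else {
            // Drop the oldest frame and append the newest one
            frames.removeFirst();
            frames.addLast(gray);
        }
        return frames.toArray(new float[0][][]);
    }
}
```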

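To compute the loss, we take the network's output for the action that was actually recorded and compare it with a target built from the recorded reward and the network's estimate for the next observation. The model is then updated through backpropagation and parameter optimization. Because the agent keeps playing during training, the training data is continuously refreshed with new steps, which helps the model reach a better result.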
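A minimal sketch of that loss computation with plain arrays follows. It assumes one-hot action encoding and mean squared error as the loss function; the class QLoss and its method are hypothetical. Multiplying by the one-hot action is what restricts the loss to the action that was actually taken.

```java
/**
 * Hypothetical sketch of the Q-learning loss for a batch, using plain arrays
 * instead of DJL NDArrays. Mean squared error is an assumed choice of loss.
 */
final class QLoss {

    private QLoss() {}

    /**
     * @param predictedQ network output, shape (batch size, 2)
     * @param actions    recorded actions, one-hot, shape (batch size, 2)
     * @param targets    Q-learning targets, shape (batch size)
     */
    static float meanSquaredError(float[][] predictedQ, int[][] actions, float[] targets) {
        float sum = 0f;
        for (int i = 0; i < predictedQ.length; i++) {
            // Select the Q-value of the recorded action via the one-hot mask
            float chosenQ = 0f;
            for (int a = 0; a < predictedQ[i].length; a++) {
                chosenQ += predictedQ[i][a] * actions[i][a];
            }
            float diff = chosenQ - targets[i];
            sum += diff * diff;
        }
        return sum / predictedQ.length;
    }
}
```

For a batch with outputs {1, 2} and {3, 4}, actions {0, 1} and {1, 0}, and targets 1 and 5, the selected values are 2 and 3, so the loss is ((2 - 1)² + (3 - 5)²) / 2 = 2.5.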

Training data

After each action, we have two observations: preObservation (before the action) and currentObservation (after it). As mentioned before, these are stacks of images that represent a short sequence of movement. We then combine preObservation, currentObservation, action, reward and terminal into a step and store it in the replayBuffer. The replayBuffer is the training dataset: it has a limited size and is continuously updated with the latest steps.

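public void step(NDList action, boolean training) {
    // The action is one-hot encoded; a 1 at index 1 means "flap"
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    // Advance the game by one frame
    stepFrame();
    // The old observation becomes preObservation; build the new one from the latest frame
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    // Bundle the transition: observations before and after, action, reward, and terminal flag
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        // Only add steps to the training dataset while training
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        // Start a new game once the current one ends
        restartGame();
    }
}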
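The replayBuffer itself can be sketched as a ring buffer. The SimpleReplayBuffer class below is hypothetical and is not DJL's implementation; it assumes that the oldest step is overwritten once the buffer is full and that training batches are sampled uniformly at random.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical fixed-size replay buffer. Once full, each new step overwrites
 * the oldest one, so the dataset reflects the agent's most recent play.
 */
final class SimpleReplayBuffer<T> {
    private final Object[] steps;
    private final Random random;
    private int next = 0; // slot the next step is written to
    private int size = 0; // number of slots currently filled

    SimpleReplayBuffer(int capacity, long seed) {
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive");
        }
        steps = new Object[capacity];
        random = new Random(seed);
    }

    void addStep(T step) {
        steps[next] = step;
        // Wrap around so that, when full, the oldest step is replaced
        next = (next + 1) % steps.length;
        size = Math.min(size + 1, steps.length);
    }

    int size() {
        return size;
    }

    /** Samples batchSize steps uniformly at random, with replacement. */
    @SuppressWarnings("unchecked")
    List<T> sample(int batchSize) {
        if (size == 0) {
            throw new IllegalStateException("buffer is empty");
        }
        List<T> batch = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            batch.add((T) steps[random.nextInt(size)]);
        }
        return batch;
    }
}
```

Sampling at random, rather than replaying steps in order, breaks up the strong correlation between consecutive frames, which is the usual motivation for a replay buffer in Q-learning.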

Three stages of RL

Training goes through three stages, each designed to generate better training data:

  • Observation Stage: Most actions are random with a small portion of actions coming from the AI agent
  • Exploration Stage: Random actions and AI agent actions are combined
  • Training Stage: Actions are primarily produced by the AI agent
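One way to track these stages is to count steps. The sketch below is hypothetical: the TrainingStages class and both thresholds are illustrative assumptions rather than the project's settings.

```java
/**
 * Hypothetical sketch of deciding the current training stage from the global
 * step count. The thresholds are illustrative, not values from the project.
 */
final class TrainingStages {

    enum Stage { OBSERVE, EXPLORE, TRAIN }

    static final long OBSERVE_STEPS = 1_000;   // assumed length of the observation stage
    static final long EXPLORE_STEPS = 100_000; // assumed length of the exploration stage

    private TrainingStages() {}

    static Stage stageFor(long step) {
        if (step < OBSERVE_STEPS) {
            return Stage.OBSERVE;
        } else if (step < OBSERVE_STEPS + EXPLORE_STEPS) {
            return Stage.EXPLORE;
        }
        return Stage.TRAIN;
    }
}
```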

During the exploration stage, we choose between a random action and the AI agent's action for the bird. At the beginning of training, random actions dominate, since the AI agent's actions are generally poor at that point. We then gradually increase the probability of taking the AI agent's action until it eventually becomes the only decision maker. The parameter that controls the ratio between random and AI agent actions is called epsilon: it is the probability of taking a random action, and it decreases over the course of training.

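public NDList chooseAction(RlEnv env, boolean training) {
    // While training, explore with probability epsilon; exploreRate lowers epsilon
    // as the counter of decisions grows
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    }
    // Otherwise, let the underlying agent (the Q-network) choose the action
    return baseAgent.chooseAction(env, training);
}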
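The exploreRate tracker above supplies epsilon for each decision. As a self-contained illustration, here is a hypothetical EpsilonGreedy class that decays epsilon linearly from a start value to a floor. The linear shape, the parameter names, and the seeded java.util.Random are assumptions; the project's tracker may follow a different schedule.

```java
import java.util.Random;

/**
 * Hypothetical epsilon-greedy chooser with a linear epsilon schedule:
 * epsilon falls from start to end over decaySteps decisions, then stays at end.
 */
final class EpsilonGreedy {
    private final float start;
    private final float end;
    private final int decaySteps;
    private final Random random;
    private int counter = 0; // decisions made so far

    EpsilonGreedy(float start, float end, int decaySteps, long seed) {
        this.start = start;
        this.end = end;
        this.decaySteps = decaySteps;
        this.random = new Random(seed);
    }

    /** Epsilon after the given number of decisions. */
    float epsilon(int step) {
        if (step >= decaySteps) {
            return end;
        }
        return start + (end - start) * step / decaySteps;
    }

    /** Returns a random action index with probability epsilon, else the highest-Q action. */
    int chooseAction(float[] qValues) {
        if (random.nextFloat() < epsilon(counter++)) {
            return random.nextInt(qValues.length);
        }
        // Greedy choice: index of the largest Q-value
        int best = 0;
        for (int a = 1; a < qValues.length; a++) {
            if (qValues[a] > qValues[best]) {
                best = a;
            }
        }
        return best;
    }
}
```

With start = 1.0, end = 0.0 and decaySteps = 100, epsilon is 0.5 after 50 decisions and 0.0 once 100 decisions have been made, at which point the Q-network alone decides.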


Train “undying” Flappy Bird using Reinforcement Learning on Java