This repository contains JAX code for the paper

**Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation**

by Evgenii Nikishin, Romina Abachi, Rishabh Agarwal, and Pierre-Luc Bacon.

Model-based reinforcement learning typically trains the dynamics and reward functions by minimizing prediction error. But prediction error is only a proxy for the agent's ultimate goal, maximizing the sum of rewards, which leads to an objective mismatch. We propose an end-to-end algorithm called *Optimal Model Design* (OMD) that optimizes the returns directly for model learning. OMD leverages the implicit function theorem to differentiate through the agent's optimization and update the model parameters.
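The core mechanism, differentiating through a fixed point via the implicit function theorem, can be illustrated with a minimal JAX sketch. The update `f` below is illustrative only, not the repository's actual training code:

```python
import jax
import jax.numpy as jnp

def f(w, theta):
    # Contractive update whose fixed point w* depends on theta.
    return jnp.tanh(theta + 0.5 * w)

def solve(theta, num_iters=100):
    # Plain fixed-point iteration to find w* with f(w*, theta) = w*.
    w = 0.0
    for _ in range(num_iters):
        w = f(w, theta)
    return w

def grad_via_ift(theta):
    # Implicit function theorem: at the fixed point w* = f(w*, theta),
    # dw*/dtheta = (df/dtheta) / (1 - df/dw), evaluated at w*.
    w_star = solve(theta)
    df_dw = jax.grad(f, argnums=0)(w_star, theta)
    df_dtheta = jax.grad(f, argnums=1)(w_star, theta)
    return df_dtheta / (1.0 - df_dw)

theta = 0.3
# The implicit gradient matches differentiating through the unrolled solver,
# without backpropagating through every iteration.
g_unroll = jax.grad(solve)(theta)
g_ift = grad_via_ift(theta)
```

In OMD the fixed point is the inner optimization of the agent rather than a scalar iteration, but the gradient structure is the same: the implicit derivative depends only on the solution, not on the solver's trajectory.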

Please cite our work if you find it useful in your research:

```
@article{nikishin2021control,
title={Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation},
author={Nikishin, Evgenii and Abachi, Romina and Agarwal, Rishabh and Bacon, Pierre-Luc},
journal={arXiv preprint arXiv:2106.03273},
year={2021}
}
```

We assume that you use Python 3. To install the necessary dependencies, run the following commands:

```
virtualenv ~/env_omd
source ~/env_omd/bin/activate
pip install -r requirements.txt
```

To use JAX with GPU support, follow the official JAX installation instructions. To install MuJoCo, follow the official MuJoCo installation instructions.
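A quick way to check which backend JAX picked up (with a GPU build installed, a GPU device appears in the list; otherwise JAX falls back to CPU):

```python
import jax

# Lists the devices JAX can see; useful for confirming the GPU build works.
print(jax.devices())
```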

For historical reasons, the code is divided into three parts: tabular, CartPole, and MuJoCo.

All results for the tabular experiments can be reproduced by running the `tabular.ipynb` notebook.

To open the notebook in Google Colab, use this link.

To train the OMD agent on CartPole, use the following commands:

```
cd cartpole
python train.py --agent_type omd
```

We also provide the implementation of the corresponding MLE and VEP baselines. To train the agents, change the `--agent_type` flag to `mle` or `vep`.

To train the OMD agent on MuJoCo HalfCheetah-v2, use the following commands:

```
cd mujoco
python train.py --config.algo=omd
```

To train the MLE baseline, change the `--config.algo` flag to `mle`.

- Tabular experiments are based on the code from the library for fixed points in JAX
- Code for MuJoCo is based on the implementation of SAC in JAX
- Code for CartPole reuses parts of the SAC implementation in PyTorch
- For experimentation, we used a modification of the Slurm runner

**Author:** evgenii-nikishin


**Official Website:** https://github.com/evgenii-nikishin/omd

**License:** MIT
