JaxMARL: Multi-Agent RL Environments and Algorithms in JAX
- URL: http://arxiv.org/abs/2311.10090v5
- Date: Sat, 02 Nov 2024 22:26:28 GMT
- Title: JaxMARL: Multi-Agent RL Environments and Algorithms in JAX
- Authors: Alexander Rutherford, Benjamin Ellis, Matteo Gallici, Jonathan Cook, Andrei Lupu, Gardar Ingvarsson, Timon Willi, Ravi Hammond, Akbir Khan, Christian Schroeder de Witt, Alexandra Souly, Saptarashmi Bandyopadhyay, Mikayel Samvelyan, Minqi Jiang, Robert Tjarko Lange, Shimon Whiteson, Bruno Lacerda, Nick Hawes, Tim Rocktäschel, Chris Lu, Jakob Nicolaus Foerster
- Abstract summary: We present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments.
Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches.
We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge.
- Abstract: Benchmarks are crucial in the development of machine learning algorithms, with available environments significantly influencing reinforcement learning (RL) research. Traditionally, RL environments run on the CPU, which limits their scalability with typical academic compute. However, recent advancements in JAX have enabled the wider use of hardware acceleration, enabling massively parallel RL training pipelines and environments. While this has been successfully applied to single-agent RL, it has not yet been widely adopted for multi-agent scenarios. In this paper, we present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments and popular baseline algorithms. Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches, and up to 12500x when multiple training runs are vectorized. This enables efficient and thorough evaluations, potentially alleviating the evaluation crisis in the field. We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. The code is available at https://github.com/flairox/jaxmarl.
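The abstract attributes the speed-ups to running environments and training together on the accelerator and to vectorizing whole runs. Below is a minimal sketch of that pattern. It assumes a hypothetical toy two-agent environment (`reset`, `step`, and `random_policy` are illustrative names, not the JaxMARL API). Because the environment functions are pure functions of JAX arrays, an episode can be compiled with `jax.lax.scan`, batched over environments with `jax.vmap`, batched again over independent seeds, and compiled once with `jax.jit`. In a real pipeline the outer `vmap` would wrap an entire training loop rather than a random-policy rollout.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment (NOT the JaxMARL API): each of two agents holds a
# scalar position and is rewarded for moving toward the origin.
NUM_AGENTS = 2

def reset(key):
    # Sample an initial position in [-1, 1) for every agent.
    return jax.random.uniform(key, (NUM_AGENTS,), minval=-1.0, maxval=1.0)

def step(state, actions):
    # Pure function of (state, actions): this is what lets JAX trace and batch it.
    new_state = state + 0.1 * actions
    rewards = jnp.abs(state) - jnp.abs(new_state)  # positive when moving inward
    return new_state, rewards

def random_policy(key, state):
    # Uniformly random action in {-1, 0, 1} for each agent (state is unused).
    return jax.random.randint(key, (NUM_AGENTS,), -1, 2).astype(jnp.float32)

def rollout(key, num_steps=50):
    # One episode, expressed with lax.scan so it compiles to a single XLA loop.
    key, reset_key = jax.random.split(key)
    state = reset(reset_key)

    def body(carry, _):
        state, key = carry
        key, act_key = jax.random.split(key)
        actions = random_policy(act_key, state)
        state, rewards = step(state, actions)
        return (state, key), rewards

    _, rewards = jax.lax.scan(body, (state, key), None, length=num_steps)
    return rewards.sum(axis=0)  # total return per agent

# Inner vmap: parallel environments. Outer vmap: independent runs (seeds).
batched_rollout = jax.jit(jax.vmap(jax.vmap(rollout)))

num_runs, num_envs = 4, 1024
keys = jax.random.split(jax.random.PRNGKey(0), num_runs * num_envs)
keys = keys.reshape(num_runs, num_envs, 2)
returns = batched_rollout(keys)
print(returns.shape)  # (4, 1024, 2)
```

Nesting `vmap` keeps the per-environment code unbatched. JAX adds the run and environment axes, so all 4 × 1024 rollouts execute as one compiled program.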
Related papers
- NAVIX: Scaling MiniGrid Environments with JAX
We introduce NAVIX, a re-implementation of MiniGrid in JAX.
NAVIX achieves over 200 000x speed improvements in batch mode, supporting up to 2048 agents in parallel on a single Nvidia A100 80 GB.
This reduces experiment times from one week to 15 minutes, promoting faster design and more scalable RL model development.
(arXiv, 2024-07-28)
- XuanCe: A Comprehensive and Unified Deep Reinforcement Learning Library
XuanCe is a comprehensive and unified deep reinforcement learning (DRL) library.
XuanCe offers a wide range of functionalities, including over 40 classical DRL and multi-agent DRL algorithms.
XuanCe is open-source and can be accessed at https://agi-brain.com/agi-brain/xuance.git.
(arXiv, 2023-12-25)
- JaxPruner: A concise library for sparsity research
JaxPruner is an open-source library for sparse neural network research.
It implements popular pruning and sparse training algorithms with minimal memory and latency overhead.
(arXiv, 2023-04-27)
- SequeL: A Continual Learning Library in PyTorch and JAX
SequeL is a library for Continual Learning that supports both PyTorch and JAX frameworks.
It provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based approaches, replay-based approaches, and hybrid approaches.
We release SequeL as an open-source library, enabling researchers and developers to experiment with and extend the library for their own purposes.
(arXiv, 2023-04-21)
- marl-jax: Multi-Agent Reinforcement Learning Framework
We present marl-jax, a multi-agent reinforcement learning software package for training agents and evaluating their social generalization.
The package is designed for training a population of agents in multi-agent environments and evaluating their ability to generalize to diverse background agents.
(arXiv, 2023-03-24)
- Going faster to see further: GPU-accelerated value iteration and simulation for perishable inventory control using JAX
We use the Python library JAX to implement value iteration and simulators of the underlying Markov decision processes in a high-level API.
Our method can extend use of value iteration to settings that were previously considered infeasible or impractical.
We compare the performance of the optimal replenishment policies with that of policies fitted using simulation optimization in JAX, which allowed multiple candidate policy parameters to be evaluated in parallel.
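This entry describes implementing value iteration in JAX and evaluating many candidate parameters in parallel. Below is a minimal, generic sketch of tabular value iteration. It assumes a small randomly generated MDP, not the paper's perishable-inventory model, and `value_iteration` is an illustrative name. The Bellman backup is a single `einsum`, the convergence loop runs inside `jax.lax.while_loop`, and `jax.vmap` solves several discount factors at once. The discount-factor batch is only an illustrative stand-in for batched evaluation of candidate parameters.

```python
import jax
import jax.numpy as jnp

def value_iteration(P, R, gamma, tol=1e-5, max_iters=10_000):
    """Tabular value iteration (generic sketch).

    P: transition probabilities, shape (S, A, S); each P[s, a] sums to 1.
    R: expected immediate rewards, shape (S, A).
    Returns approximately optimal state values V and a greedy policy.
    """
    def q_values(V):
        # Q[s, a] = R[s, a] + gamma * sum_{s'} P[s, a, s'] * V[s']
        return R + gamma * jnp.einsum("san,n->sa", P, V)

    def cond(carry):
        _, delta, i = carry
        # Continue while the largest value change exceeds tol (capped by max_iters).
        return (delta > tol) & (i < max_iters)

    def body(carry):
        V, _, i = carry
        V_new = q_values(V).max(axis=1)  # Bellman optimality backup
        return V_new, jnp.max(jnp.abs(V_new - V)), i + 1

    init = (jnp.zeros(R.shape[0], dtype=R.dtype),
            jnp.asarray(jnp.inf, dtype=R.dtype),
            jnp.asarray(0, dtype=jnp.int32))
    V, _, _ = jax.lax.while_loop(cond, body, init)
    return V, q_values(V).argmax(axis=1)

# Hypothetical random MDP with 5 states and 3 actions.
S, A = 5, 3
k_p, k_r = jax.random.split(jax.random.PRNGKey(0))
P = jax.random.uniform(k_p, (S, A, S))
P = P / P.sum(axis=-1, keepdims=True)  # normalise each P[s, a] into a distribution
R = jax.random.uniform(k_r, (S, A))    # rewards in [0, 1)

# Solve for several discount factors in one compiled, vectorized call.
gammas = jnp.array([0.5, 0.9, 0.95])
solve_many = jax.jit(jax.vmap(value_iteration, in_axes=(None, None, 0)))
Vs, policies = solve_many(P, R, gammas)
print(Vs.shape, policies.shape)  # (3, 5) (3, 5)
```

Under `vmap`, the batched `while_loop` keeps iterating until every discount factor has converged and leaves already-converged entries unchanged.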
(arXiv, 2023-03-19)
- FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU
FlexGen is a generation engine for running large language model (LLM) inference on a single commodity GPU.
When running OPT-175B on a single 16GB GPU, FlexGen achieves significantly higher throughput compared to state-of-the-art offloading systems.
On the HELM benchmark, FlexGen can benchmark a 30B model with a 16GB GPU on 7 representative sub-scenarios in 21 hours.
(arXiv, 2023-03-13)
- EnvPool: A Highly Parallel Reinforcement Learning Environment Execution Engine
In RL training systems, parallel environment execution is often the slowest part of the whole system, yet it receives little attention.
With a curated design for parallelizing RL environments, we have improved RL environment simulation speed across different hardware setups.
On a high-end machine, EnvPool achieves 1 million frames per second for the environment execution on Atari environments and 3 million frames per second on MuJoCo environments.
(arXiv, 2022-06-21)
- ElegantRL-Podracer: Scalable and Elastic Library for Cloud-Native Deep Reinforcement Learning
We present ElegantRL-podracer, a library for cloud-native deep reinforcement learning.
It efficiently supports millions of cores to carry out massively parallel training at multiple levels.
At the low level, each pod simulates agent-environment interactions in parallel, fully utilizing nearly 7,000 GPU cores on a single GPU.
(arXiv, 2021-12-11)
This list is automatically generated from the titles and abstracts of the papers on this site.