DrJAX: Scalable and Differentiable MapReduce Primitives in JAX
- URL: http://arxiv.org/abs/2403.07128v2
- Date: Wed, 17 Jul 2024 21:41:39 GMT
- Title: DrJAX: Scalable and Differentiable MapReduce Primitives in JAX
- Authors: Keith Rush, Zachary Charles, Zachary Garrett, Sean Augenstein, Nicole Mitchell
- Abstract summary: DrJAX is a library designed to support large-scale distributed and parallel machine learning algorithms.
DrJAX embeds building blocks for MapReduce computations as primitives in JAX.
DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms.
- Score: 9.676195490442367
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at \url{https://github.com/google-research/google-research/tree/master/drjax}.
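The pattern those primitives capture can be sketched in plain JAX: broadcast shared parameters, map a per-client computation, and reduce by averaging, with gradients flowing end to end. A minimal sketch of that pattern (plain JAX only, not DrJAX's actual API; `local_loss`, the client count, and the toy data are invented for illustration):

```python
import jax
import jax.numpy as jnp

def local_loss(params, client_data):
    # Per-client ("mapped") computation: a toy quadratic loss.
    return jnp.mean((client_data - params) ** 2)

def mapreduce_loss(params, all_client_data):
    # Broadcast params to every client, map the local loss across
    # clients with vmap, then reduce by averaging the results.
    per_client = jax.vmap(local_loss, in_axes=(None, 0))(params, all_client_data)
    return jnp.mean(per_client)

params = jnp.array(0.5)
data = jnp.arange(8.0).reshape(4, 2)  # 4 clients, 2 examples each

# Because every stage is ordinary JAX, the whole map-reduce
# computation is differentiable end to end.
print(jax.grad(mapreduce_loss)(params, data))
```

Per the abstract, DrJAX's contribution is to register such building blocks as JAX primitives, so they lower directly to XLA HLO, shard across accelerators, and stay visible to JAX's autodiff.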
Related papers
- Benchmarking Predictive Coding Networks -- Made Simple [48.652114040426625]
We first propose a library called PCX, whose focus lies on performance and simplicity.
We use PCX to implement a large set of benchmarks for the community to use for their experiments.
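As a rough sketch of the computation predictive coding networks perform at inference time (generic background, not PCX's API; the energy function, weights, and shapes below are invented), latent states are updated by gradient descent on a prediction-error energy:

```python
import jax
import jax.numpy as jnp

def energy(z, x, W):
    # Squared error of the prediction W @ z against the observation x,
    # plus a simple Gaussian prior term on the latent state z.
    return 0.5 * jnp.sum((x - W @ z) ** 2) + 0.5 * jnp.sum(z ** 2)

W = jnp.eye(3)                  # toy generative weights
x = jnp.array([1.0, 2.0, 3.0])  # observation
z = jnp.zeros(3)                # latent state

# Inference: descend the energy with respect to z, holding W fixed.
for _ in range(20):
    z = z - 0.1 * jax.grad(energy)(z, x, W)
```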
arXiv Detail & Related papers (2024-07-01T10:33:44Z)
- JaxDecompiler: Redefining Gradient-Informed Software Design [0.0]
JaxDecompiler is a tool that transforms any JAX function into editable Python code.
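For context, standard JAX already exposes the intermediate representation such a tool would work from: `jax.make_jaxpr` (part of JAX itself, not JaxDecompiler) prints the jaxpr that a decompiler would translate back into editable Python:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# Print the jaxpr: the traced IR that a decompiler could turn
# back into equivalent Python source.
print(jax.make_jaxpr(f)(1.0))
```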
arXiv Detail & Related papers (2024-03-14T20:32:31Z)
- BlackJAX: Composable Bayesian inference in JAX [8.834500692867671]
BlackJAX is a library implementing sampling and variational inference algorithms.
It is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs.
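A minimal sampling loop, assuming the post-1.0 BlackJAX interface (`blackjax.nuts` returning an algorithm with `init`/`step`); the log-density and tuning values below are placeholders:

```python
import jax
import jax.numpy as jnp
import blackjax

# Log-density of a standard 2D Gaussian.
logdensity = lambda x: -0.5 * jnp.sum(x ** 2)

nuts = blackjax.nuts(logdensity, step_size=0.5,
                     inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))

key = jax.random.PRNGKey(0)
for _ in range(100):
    key, subkey = jax.random.split(key)
    state, info = nuts.step(subkey, state)
```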
arXiv Detail & Related papers (2024-02-16T16:21:02Z)
- JaxMARL: Multi-Agent RL Environments and Algorithms in JAX [105.343918678781]
We present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments.
Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches.
We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge.
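Much of that wall-clock speedup comes from keeping environments on-device and vectorizing them with `vmap`; a generic illustration of the pattern in plain JAX (not JaxMARL's API; the toy dynamics are invented):

```python
import jax
import jax.numpy as jnp

def env_step(state, action):
    # Toy scalar dynamics standing in for a real environment step.
    new_state = state + action
    reward = -jnp.abs(new_state)
    return new_state, reward

# Vectorize one step across 4096 parallel environments, then compile.
batched_step = jax.jit(jax.vmap(env_step))

states = jnp.zeros(4096)
actions = jnp.ones(4096)
states, rewards = batched_step(states, actions)
```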
arXiv Detail & Related papers (2023-11-16T18:58:43Z)
- JaxPruner: A concise library for sparsity research [46.153423603424]
JaxPruner is an open-source library for sparse neural network research.
It implements popular pruning and sparse training algorithms with minimal memory and latency overhead.
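Magnitude pruning, one of the standard algorithms in this space, is just a thresholded mask applied to the weights; a generic JAX sketch (not JaxPruner's API):

```python
import jax.numpy as jnp

def magnitude_prune(weights, sparsity=0.5):
    # Zero out the smallest-magnitude fraction of the weights.
    k = int(weights.size * sparsity)
    threshold = jnp.sort(jnp.abs(weights).ravel())[k]
    mask = (jnp.abs(weights) >= threshold).astype(weights.dtype)
    return weights * mask, mask

w = jnp.array([[0.1, -2.0], [0.05, 3.0]])
pruned, mask = magnitude_prune(w, sparsity=0.5)  # keeps -2.0 and 3.0
```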
arXiv Detail & Related papers (2023-04-27T10:45:30Z)
- SequeL: A Continual Learning Library in PyTorch and JAX [50.33956216274694]
SequeL is a library for Continual Learning that supports both PyTorch and JAX frameworks.
It provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based approaches, replay-based approaches, and hybrid approaches.
We release SequeL as an open-source library, enabling researchers and developers to easily experiment and extend the library for their own purposes.
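Regularization-based continual learning, one of the families named above, typically anchors parameters near their values from earlier tasks; an EWC-style sketch in plain JAX (generic, not SequeL's interface; the losses and Fisher estimate are placeholders):

```python
import jax
import jax.numpy as jnp

old_params = jnp.ones(3) * 0.5  # parameters learned on the previous task
fisher = jnp.ones(3)            # diagonal Fisher estimate from that task
lam = 1.0                       # regularization strength

def total_loss(params):
    task = jnp.sum((params - 1.0) ** 2)  # placeholder current-task loss
    penalty = 0.5 * lam * jnp.sum(fisher * (params - old_params) ** 2)
    return task + penalty

grad = jax.grad(total_loss)(jnp.zeros(3))
```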
arXiv Detail & Related papers (2023-04-21T10:00:22Z)
- Extending Compositional Attention Networks for Social Reasoning in Videos [84.12658971655253]
We propose a novel deep architecture for the task of reasoning about social interactions in videos.
We leverage the multi-step reasoning capabilities of Compositional Attention Networks (MAC) and propose a multimodal extension (MAC-X).
arXiv Detail & Related papers (2022-10-03T19:03:01Z)
- SMORE: Knowledge Graph Completion and Multi-hop Reasoning in Massive Knowledge Graphs [147.73127662757335]
We present scalable Multi-hOp REasoning (SMORE), the first general framework for both single-hop and multi-hop reasoning in Knowledge Graphs (KGs).
Using a single machine, SMORE can perform multi-hop reasoning in the Freebase KG (86M entities, 338M edges), which is 1,500x larger than previously considered KGs.
SMORE increases throughput (i.e., training speed) over prior multi-hop KG frameworks by 2.2x with minimal GPU memory requirements.
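In embedding-based multi-hop reasoning, the style of computation SMORE scales up, a chained query such as "entities reachable from X via r1 then r2" can be scored by composing relation transformations in embedding space; a translation-based (TransE-style) sketch, generic rather than SMORE's actual algorithm:

```python
import jax.numpy as jnp

# Toy embedding tables: 5 entities, 2 relations, dimension 4.
entities = jnp.arange(20.0).reshape(5, 4)
relations = jnp.ones((2, 4))

def multi_hop_query(head_id, relation_ids):
    # Translate the head embedding along each relation in turn,
    # then rank all entities by distance to the resulting vector.
    q = entities[head_id]
    for r in relation_ids:
        q = q + relations[r]
    scores = -jnp.linalg.norm(entities - q, axis=1)
    return jnp.argsort(-scores)  # entity ids, best match first

print(multi_hop_query(0, [0, 1]))
```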
arXiv Detail & Related papers (2021-10-28T05:02:33Z)
- SymJAX: symbolic CPU/GPU/TPU programming [9.868558660605995]
SymJAX is a symbolic programming version of JAX that simplifies graph input/output/updates and provides additional functionality for general machine learning and deep learning applications.
From a user's perspective, SymJAX offers a Theano-like experience with fast graph optimization/compilation and broad hardware support, along with Lasagne-like deep learning functionality.
arXiv Detail & Related papers (2020-05-21T13:37:25Z)