JAXbind: Bind any function to JAX
- URL: http://arxiv.org/abs/2403.08847v2
- Date: Thu, 27 Jun 2024 07:45:28 GMT
- Title: JAXbind: Bind any function to JAX
- Authors: Jakob Roth, Martin Reinecke, Gordian Edenhofer
- Abstract summary: JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives.
JAXbind allows a user to interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive.
- Score: 0.0
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: JAX is widely used in machine learning and scientific computing, the latter of which often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing the existing code in JAX is often impractical and the existing interface in JAX for binding custom code either limits the user to a single Jacobian product or requires deep knowledge of JAX and its C++ backend for general Jacobian products. With JAXbind we drastically reduce the effort required to bind custom functions implemented in other programming languages with full support for Jacobian-vector products and vector-Jacobian products to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind allows a user to interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive.
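The abstract's remark that JAX's existing interface "limits the user to a single Jacobian product" can be made concrete with a minimal sketch of that baseline. It uses only public JAX APIs (`jax.pure_callback` and `jax.custom_vjp`) and is not JAXbind's own API. The NumPy function `external_square` is a hypothetical stand-in for existing high-performance code that JAX cannot trace.

```python
import numpy as np
import jax
import jax.numpy as jnp


def external_square(x):
    # Hypothetical stand-in for external code: plain NumPy, opaque to JAX tracing.
    return np.asarray(x) ** 2


@jax.custom_vjp
def square(x):
    # pure_callback runs the Python function at runtime; JAX only needs
    # the shape and dtype of the result.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(external_square, out_spec, x)


def square_fwd(x):
    # Forward pass: compute the primal output and keep x as the residual.
    return square(x), x


def square_bwd(x, g):
    # Backward pass: the VJP of x**2 is 2 * x * g, also computed outside JAX.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    vjp = jax.pure_callback(
        lambda a, b: (2 * np.asarray(a) * np.asarray(b)).astype(np.asarray(a).dtype),
        out_spec,
        x,
        g,
    )
    return (vjp,)


# Only a vector-Jacobian product is registered; no JVP rule exists.
square.defvjp(square_fwd, square_bwd)

x = jnp.array([1.0, 2.0, 3.0])
y = jax.jit(square)(x)                            # primal values
grad = jax.grad(lambda v: square(v).sum())(x)     # reverse mode works
```

Reverse-mode differentiation (`jax.grad`) works here, but forward mode (`jax.jvp`) is rejected because only a VJP was defined. Registering both Jacobian products for such an external function is the gap the abstract says JAXbind is meant to close.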
Related papers
- JailbreakBench: An Open Robustness Benchmark for Jailbreaking Large Language Models [123.66104233291065]
Jailbreak attacks cause large language models (LLMs) to generate harmful, unethical, or otherwise objectionable content.
Evaluating these attacks presents a number of challenges, which the current collection of benchmarks and evaluation techniques do not adequately address.
JailbreakBench is an open-sourced benchmark made up of several components.
(2024-03-28T02:44:02Z)
- JaxDecompiler: Redefining Gradient-Informed Software Design [0.0]
JaxDecompiler is a tool that transforms any JAX function into editable Python code.
(2024-03-14T20:32:31Z)
- DrJAX: Scalable and Differentiable MapReduce Primitives in JAX [9.676195490442367]
DrJAX is a library designed to support large-scale distributed and parallel machine learning algorithms.
DrJAX embeds building blocks for MapReduce computations as primitives in JAX.
DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms.
(2024-03-11T19:51:01Z)
- Executable Code Actions Elicit Better LLM Agents [76.95566120678787]
This work proposes to use Python code to consolidate Large Language Model (LLM) agents' actions into a unified action space (CodeAct).
Integrated with a Python interpreter, CodeAct can execute code actions and dynamically revise prior actions or emit new actions upon new observations through multi-turn interactions.
The encouraging performance of CodeAct motivates us to build an open-source LLM agent that interacts with environments by executing interpretable code and collaborates with users using natural language.
(2024-02-01T21:38:58Z)
- Automatic Functional Differentiation in JAX [8.536145202129827]
We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators).
By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions.
(2023-11-30T17:23:40Z)
- JaxMARL: Multi-Agent RL Environments and Algorithms in JAX [105.343918678781]
We present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments.
Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches.
We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge.
(2023-11-16T18:58:43Z)
- FORFIS: A forest fire firefighting simulation tool for education and research [90.40304110009733]
We present a forest fire firefighting simulation tool named FORFIS that is implemented in Python.
Our tool is published under a v3 license and comes with a GUI as well as additional output functionality.
(2023-05-29T09:14:38Z)
- JaxPruner: A concise library for sparsity research [46.153423603424]
JaxPruner is an open-source library for sparse neural network research.
It implements popular pruning and sparse training algorithms with minimal memory and latency overhead.
(2023-04-27T10:45:30Z)
- SequeL: A Continual Learning Library in PyTorch and JAX [50.33956216274694]
SequeL is a library for Continual Learning that supports both PyTorch and JAX frameworks.
It provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based approaches, replay-based approaches, and hybrid approaches.
We release SequeL as an open-source library, enabling researchers and developers to easily experiment with and extend the library for their own purposes.
(2023-04-21T10:00:22Z)
- Equinox: neural networks in JAX via callable PyTrees and filtered transformations [4.264192013842096]
JAX and PyTorch are two popular Python autodifferentiation frameworks.
JAX is based around pure functions and functional programming.
PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions.
(2021-10-30T14:08:56Z)
- WAX-ML: A Python library for machine learning and feedback loops on streaming data [0.0]
WAX-ML is a research-oriented Python library.
It provides tools to design powerful machine learning algorithms.
It strives to complement JAX with tools dedicated to time series.
(2021-06-11T17:42:02Z)