BlackJAX: Composable Bayesian inference in JAX
- URL: http://arxiv.org/abs/2402.10797v2
- Date: Thu, 22 Feb 2024 10:58:50 GMT
- Title: BlackJAX: Composable Bayesian inference in JAX
- Authors: Alberto Cabezas, Adrien Corenflos, Junpeng Lao, Rémi Louf, Antoine
Carnec, Kaustubh Chaudhari, Reuben Cohn-Gordon, Jeremie Coullon, Wei Deng,
Sam Duffield, Gerardo Durán-Martín, Marcin Elantkowski, Dan
Foreman-Mackey, Michele Gregori, Carlos Iguaran, Ravin Kumar, Martin Lysy,
Kevin Murphy, Juan Camilo Orduz, Karm Patel, Xi Wang, Rob Zinkov
- Abstract summary: BlackJAX is a library implementing sampling and variational inference algorithms.
It is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs.
- Score: 8.834500692867671
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: BlackJAX is a library implementing sampling and variational inference
algorithms commonly used in Bayesian computation. It is designed for ease of
use, speed, and modularity by taking a functional approach to the algorithms'
implementation. BlackJAX is written in Python, using JAX to compile and run
NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The
library integrates well with probabilistic programming languages by working
directly with the (un-normalized) target log density function. BlackJAX is
intended as a collection of low-level, composable implementations of basic
statistical 'atoms' that can be combined to perform well-defined Bayesian
inference, but also provides high-level routines for ease of use. It is
designed for users who need cutting-edge methods, researchers who want to
create complex sampling methods, and people who want to learn how these work.
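The abstract's central idea is that a sampler only needs the un-normalized target log density and can be exposed as a small, composable "atom". The sketch below illustrates that style in plain JAX; it does not use or reproduce BlackJAX's actual API. `RWState`, `rw_metropolis`, and `run_chain` are hypothetical names written for this illustration, and random-walk Metropolis is chosen only because it is the simplest kernel that works from a log density alone.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    # Hypothetical chain state: current position and its un-normalized log density.
    position: jnp.ndarray
    logdensity: jnp.ndarray


def rw_metropolis(logdensity_fn: Callable, step_size: float):
    """Hypothetical random-walk Metropolis 'atom' returning pure init/step functions."""

    def init(position):
        return RWState(position, logdensity_fn(position))

    def step(rng_key, state):
        key_prop, key_accept = jax.random.split(rng_key)
        # Gaussian proposal centred on the current position.
        proposal = state.position + step_size * jax.random.normal(
            key_prop, state.position.shape
        )
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis test uses only a log-density difference, so the
        # target's normalizing constant never has to be known.
        log_ratio = proposal_logdensity - state.logdensity
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_position = jnp.where(accept, proposal, state.position)
        new_logdensity = jnp.where(accept, proposal_logdensity, state.logdensity)
        return RWState(new_position, new_logdensity)

    return init, step


def run_chain(rng_key, step, initial_state, num_steps):
    # Drive the kernel with lax.scan so the whole loop is compiled by XLA.
    def one_step(state, key):
        state = step(key, state)
        return state, state.position

    keys = jax.random.split(rng_key, num_steps)
    _, positions = jax.lax.scan(one_step, initial_state, keys)
    return positions


# Un-normalized log density of a 2-D Gaussian with mean (1, -2) and unit variance.
def logdensity_fn(x):
    return -0.5 * jnp.sum((x - jnp.array([1.0, -2.0])) ** 2)


init, step = rw_metropolis(logdensity_fn, step_size=0.8)

# vmap over independent chains, then jit the whole sampler;
# the same code can target CPU, GPU, or TPU backends.
num_chains, num_steps = 4, 5_000
initial_states = jax.vmap(init)(jnp.zeros((num_chains, 2)))
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)

sample = jax.jit(jax.vmap(lambda k, s: run_chain(k, step, s, num_steps)))
samples = sample(chain_keys, initial_states)  # shape (num_chains, num_steps, 2)
posterior_mean = samples[:, 1_000:].mean(axis=(0, 1))  # discard burn-in
```

Because `step` is a pure function of a key and a state, the same driver loop accepts any other kernel with that shape, and batching or compilation is layered on with `jax.vmap` and `jax.jit` rather than built into the kernel. BlackJAX's own kernels and adaptation routines are considerably more elaborate than this toy.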
Related papers
- Benchmarking Predictive Coding Networks -- Made Simple [48.652114040426625]
We first propose a library called PCX, whose focus lies on performance and simplicity.
We use PCX to implement a large set of benchmarks for the community to use for their experiments.
Date: 2024-07-01T10:33:44Z
- A Python library for efficient computation of molecular fingerprints [0.0]
We create a Python library that computes molecular fingerprints efficiently and provides a comprehensive interface.
The library enables the user to perform computation on large datasets using parallelism.
We show that using molecular fingerprints we can achieve results comparable to state-of-the-art ML solutions.
Date: 2024-03-27T19:02:09Z
- SynJax: Structured Probability Distributions for JAX [3.4447129363520337]
SynJax provides efficient vectorized implementation of inference algorithms for structured distributions.
We can build large-scale differentiable models that explicitly model structure in the data.
Date: 2023-08-07T04:20:38Z
- Provably Efficient Representation Learning with Tractable Planning in Low-Rank POMDP [81.00800920928621]
We study representation learning in partially observable Markov Decision Processes (POMDPs).
We first present an algorithm for decodable POMDPs that combines maximum likelihood estimation (MLE) and optimism in the face of uncertainty (OFU).
We then show how to adapt this algorithm to also work in the broader class of $\gamma$-observable POMDPs.
Date: 2023-06-21T16:04:03Z
- JaxPruner: A concise library for sparsity research [46.153423603424]
JaxPruner is an open-source library for sparse neural network research.
It implements popular pruning and sparse training algorithms with minimal memory and latency overhead.
Date: 2023-04-27T10:45:30Z
- SequeL: A Continual Learning Library in PyTorch and JAX [50.33956216274694]
SequeL is a library for Continual Learning that supports both PyTorch and JAX frameworks.
It provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based approaches, replay-based approaches, and hybrid approaches.
We release SequeL as an open-source library, enabling researchers and developers to easily experiment and extend the library for their own purposes.
Date: 2023-04-21T10:00:22Z
- JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the GPU [0.0]
We implement a trust region method on the GPU for nonlinear least squares curve fitting problems using a new deep learning Python library called JAX.
Our open source package, JAXFit, works for both unconstrained and constrained curve fitting problems.
Date: 2022-08-25T16:13:29Z
- Efficient algorithms for implementing incremental proximal-point methods [0.3263412255491401]
In machine learning, model training algorithms observe a small portion of the training set in each computational step.
Several streams of research attempt to exploit more information about the cost functions than just their gradients via the well-known proximal operators.
We devise a novel algorithmic framework that exploits convex duality theory to achieve both algorithmic efficiency and software modularity of proximal operators.
Date: 2022-05-03T12:43:26Z
- Exact Paired-Permutation Testing for Structured Test Statistics [67.71280539312536]
We provide an efficient exact algorithm for the paired-permutation test for a family of structured test statistics.
Our exact algorithm was $10$x faster than the Monte Carlo approximation with $20000$ samples on a common dataset.
Date: 2022-05-03T11:00:59Z
- Torch-Struct: Deep Structured Prediction Library [138.5262350501951]
We introduce Torch-Struct, a library for structured prediction.
Torch-Struct includes a broad collection of probabilistic structures accessed through a simple and flexible distribution-based API.
Date: 2020-02-03T16:43:02Z
This list is automatically generated from the titles and abstracts of the papers in this site.