MPAX: Mathematical Programming in JAX
- URL: http://arxiv.org/abs/2412.09734v2
- Date: Thu, 06 Feb 2025 18:36:50 GMT
- Title: MPAX: Mathematical Programming in JAX
- Authors: Haihao Lu, Zedong Peng, Jinwen Yang
- Abstract summary: MPAX is a versatile and efficient toolbox for integrating linear programming into machine learning.
It provides native support for hardware acceleration along with features such as batch solving, auto-differentiation, and device parallelism.
- Score: 4.320198313490604
- License:
- Abstract: This paper presents MPAX (Mathematical Programming in JAX), a versatile and efficient toolbox for integrating linear programming (LP) into machine learning workflows. MPAX implements two state-of-the-art first-order methods, restarted average primal-dual hybrid gradient and reflected restarted Halpern primal-dual hybrid gradient, to solve LPs in JAX. It provides native support for hardware acceleration along with features such as batch solving, auto-differentiation, and device parallelism. Extensive numerical experiments demonstrate the advantages of MPAX over existing solvers. The solver is available at https://github.com/MIT-Lu-Lab/MPAX.
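To make the idea of a first-order LP solver with batch solving concrete, here is a minimal sketch of the plain (unrestarted, non-averaged) primal-dual hybrid gradient iteration written in JAX. It is not MPAX's API and not the restarted or Halpern variants named in the abstract. The function name `pdhg_lp`, the standard-form LP (minimize c·x subject to Ax = b, x ≥ 0), the step-size rule, the iteration count, and the toy data are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=2000):
    """Plain PDHG sketch for: min c @ x  s.t.  A @ x = b, x >= 0.

    Hypothetical helper for illustration only; not part of MPAX.
    """
    # Equal primal and dual step sizes with step^2 * ||A||_2^2 < 1.
    step = 0.9 / jnp.linalg.norm(A, ord=2)

    def body(_, state):
        x, y = state
        # Primal step: gradient step on the Lagrangian, then project onto x >= 0.
        x_new = jnp.maximum(x - step * (c - A.T @ y), 0.0)
        # Dual step: ascent using the extrapolated primal point 2 * x_new - x.
        y_new = y + step * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(A.shape[1])
    y0 = jnp.zeros(A.shape[0])
    return jax.lax.fori_loop(0, num_iters, body, (x0, y0))


# Batch solving: one constraint set, several cost vectors, mapped with vmap.
batched_pdhg = jax.jit(jax.vmap(pdhg_lp, in_axes=(0, None, None)))

A = jnp.array([[1.0, 1.0]])      # single constraint: x1 + x2 = 1
b = jnp.array([1.0])
cs = jnp.array([[1.0, 2.0],      # cheaper to put all weight on x1
                [2.0, 1.0]])     # cheaper to put all weight on x2

xs, ys = batched_pdhg(cs, A, b)
objectives = jnp.sum(cs * xs, axis=1)
```

In this toy problem each LP puts all weight on the cheaper coordinate, so `xs` should end up close to the two unit vectors and both objectives close to 1. Because the loop is built from JAX primitives, the same function can be compiled with `jax.jit` and mapped over a batch with `jax.vmap` without changes, which is the general mechanism a JAX-based solver can rely on for batching.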
Related papers
- Gradients of Functions of Large Matrices [18.361820028457718]
We show how to differentiate workhorses of numerical linear algebra efficiently.
We derive previously unknown adjoint systems for Lanczos and Arnoldi iterations, implement them in JAX, and show that the resulting code can compete with Diffrax.
All this is achieved without any problem-specific code optimisation.
arXiv Detail & Related papers (2024-05-27T15:39:45Z)
- Enabling High-Sparsity Foundational Llama Models with Efficient Pretraining and Deployment [56.44025052765861]
Large language models (LLMs) have revolutionized Natural Language Processing (NLP), but their size creates computational bottlenecks.
We introduce a novel approach to create accurate, sparse foundational versions of performant LLMs.
We show a total speedup on CPUs for sparse-quantized LLaMA models of up to 8.6x.
arXiv Detail & Related papers (2024-05-06T16:03:32Z)
- JAX-SPH: A Differentiable Smoothed Particle Hydrodynamics Framework [8.977530522693444]
Particle-based fluid simulations have emerged as a powerful tool for solving the Navier-Stokes equations.
Recent addition of machine learning methods to the toolbox for solving such problems is pushing the boundary of the quality vs. speed tradeoff.
We lead the way to Lagrangian fluid simulators compatible with deep learning frameworks, and propose JAX-SPH.
arXiv Detail & Related papers (2024-03-07T18:53:53Z)
- BlackJAX: Composable Bayesian inference in JAX [8.834500692867671]
BlackJAX is a library implementing sampling and variational inference algorithms.
It is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs.
arXiv Detail & Related papers (2024-02-16T16:21:02Z)
- JaxMARL: Multi-Agent RL Environments and Algorithms in JAX [105.343918678781]
We present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments.
Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches.
We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge.
arXiv Detail & Related papers (2023-11-16T18:58:43Z)
- In Situ Framework for Coupling Simulation and Machine Learning with Application to CFD [51.04126395480625]
Recent years have seen many successful applications of machine learning (ML) to facilitate fluid dynamic computations.
As simulations grow, generating new training datasets for traditional offline learning creates I/O and storage bottlenecks.
This work offers a solution by simplifying this coupling and enabling in situ training and inference on heterogeneous clusters.
arXiv Detail & Related papers (2023-06-22T14:07:54Z)
- JaxPruner: A concise library for sparsity research [46.153423603424]
JaxPruner is an open-source library for sparse neural network research.
It implements popular pruning and sparse training algorithms with minimal memory and latency overhead.
arXiv Detail & Related papers (2023-04-27T10:45:30Z)
- Kernel methods through the roof: handling billions of points efficiently [94.31450736250918]
Kernel methods provide an elegant and principled approach to nonparametric learning, but so far could hardly be used in large scale problems.
Recent advances have shown the benefits of a number of algorithmic ideas, for example combining optimization, numerical linear algebra and random projections.
Here, we push these efforts further to develop and test a solver that takes full advantage of GPU hardware.
arXiv Detail & Related papers (2020-06-18T08:16:25Z)
- MPLP++: Fast, Parallel Dual Block-Coordinate Ascent for Dense Graphical Models [96.1052289276254]
This work introduces a new MAP-solver, based on the popular Dual Block-Coordinate Ascent principle.
Surprisingly, by making a small change to the low-performing solver, we derive the new solver MPLP++, which outperforms all existing solvers by a large margin.
arXiv Detail & Related papers (2020-04-16T16:20:53Z)
This list is automatically generated from the titles and abstracts of the papers in this site.