JaxPruner: A concise library for sparsity research
- URL: http://arxiv.org/abs/2304.14082v3
- Date: Tue, 19 Dec 2023 02:58:39 GMT
- Title: JaxPruner: A concise library for sparsity research
- Authors: Joo Hyung Lee, Wonpyo Park, Nicole Mitchell, Jonathan Pilault, Johan
Obando-Ceron, Han-Byul Kim, Namhoon Lee, Elias Frantar, Yun Long, Amir
Yazdanbakhsh, Shivani Agrawal, Suvinay Subramanian, Xin Wang, Sheng-Chun Kao,
Xingyao Zhang, Trevor Gale, Aart Bik, Woohyun Han, Milen Ferev, Zhonglin Han,
Hong-Seok Kim, Yann Dauphin, Gintare Karolina Dziugaite, Pablo Samuel Castro,
Utku Evci
- Abstract summary: JaxPruner is an open-source library for sparse neural network research.
It implements popular pruning and sparse training algorithms with minimal memory and latency overhead.
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: This paper introduces JaxPruner, an open-source JAX-based pruning and sparse
training library for machine learning research. JaxPruner aims to accelerate
research on sparse neural networks by providing concise implementations of
popular pruning and sparse training algorithms with minimal memory and latency
overhead. Algorithms implemented in JaxPruner use a common API and work
seamlessly with the popular optimization library Optax, which, in turn, enables
easy integration with existing JAX-based libraries. We demonstrate this ease of
integration by providing examples in four different codebases: Scenic, t5x,
Dopamine and FedJAX and provide baseline experiments on popular benchmarks.
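JaxPruner's own API is not reproduced here. As a rough, framework-free illustration of the simplest algorithm family the abstract refers to, the sketch below implements one-shot magnitude pruning over a flat list of weights, plus an SGD step that re-applies the binary mask so pruned weights stay at zero. Any integration with an optimizer such as Optax has to preserve this invariant in some form. The names `magnitude_mask` and `masked_sgd_step` are hypothetical, and the code is plain Python rather than JAX to keep it self-contained.

```python
from typing import List, Sequence


def magnitude_mask(weights: Sequence[float], sparsity: float) -> List[int]:
    """Hypothetical helper: 0/1 mask zeroing the `sparsity` fraction of smallest-|w| weights."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    num_pruned = int(sparsity * len(weights))  # rounds down
    # Rank indices by ascending magnitude; the first `num_pruned` are masked out.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:num_pruned]:
        mask[i] = 0
    return mask


def masked_sgd_step(
    weights: Sequence[float],
    grads: Sequence[float],
    mask: Sequence[int],
    lr: float = 0.1,
) -> List[float]:
    """Hypothetical helper: plain SGD step, then re-apply the mask so pruned weights stay zero."""
    return [(w - lr * g) * m for w, g, m in zip(weights, grads, mask)]


weights = [0.5, -0.05, 1.2, 0.01, -0.8]
mask = magnitude_mask(weights, sparsity=0.4)  # prunes the two smallest-magnitude entries
print(mask)  # [1, 0, 1, 0, 1]
weights = masked_sgd_step(weights, grads=[1.0] * 5, mask=mask)
```

Re-applying the mask after every update keeps the sparsity level fixed; deciding when and how often to recompute the mask (one-shot versus gradual schedules) is a separate design choice that this sketch leaves out.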
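Sparse training algorithms, the other family the abstract mentions, keep the network sparse throughout training and periodically change which connections are active. Below is a minimal plain-Python sketch of one well-known scheme, a SET-style update: drop the smallest-magnitude active weights, then regrow the same number at random previously inactive positions. `set_update` is a hypothetical name, this is not JaxPruner's implementation, and initialising regrown weights to zero is an assumption of the sketch.

```python
import random
from typing import List, Sequence, Tuple


def set_update(
    weights: Sequence[float],
    mask: Sequence[int],
    drop_fraction: float,
    rng: random.Random,
) -> Tuple[List[float], List[int]]:
    """Hypothetical SET-style step: drop weak active weights, regrow as many at random."""
    active = [i for i, m in enumerate(mask) if m]
    inactive = [i for i, m in enumerate(mask) if not m]
    # Never drop more than can be regrown, so the number of active weights is unchanged.
    num_swap = min(int(drop_fraction * len(active)), len(inactive))
    to_drop = sorted(active, key=lambda i: abs(weights[i]))[:num_swap]
    # Regrow only among positions that were inactive before this step.
    to_grow = rng.sample(inactive, num_swap)
    new_weights, new_mask = list(weights), list(mask)
    for i in to_drop:
        new_weights[i], new_mask[i] = 0.0, 0
    for i in to_grow:
        new_weights[i], new_mask[i] = 0.0, 1  # assumption: regrown weights start at zero
    return new_weights, new_mask


rng = random.Random(0)
w, m = set_update([0.9, 0.05, 0.0, -0.7, 0.02, 0.0], [1, 1, 0, 1, 1, 0], 0.5, rng)
print(m)  # [1, 0, 1, 1, 0, 1]
```

Because dropped and regrown positions come from disjoint sets, the total number of active connections, and hence the sparsity level, stays constant across updates.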
Related papers
- Benchmarking Predictive Coding Networks -- Made Simple
We first propose a library called PCX, whose focus lies on performance and simplicity.
We use PCX to implement a large set of benchmarks for the community to use for their experiments.
Date: 2024-07-01T10:33:44Z
- JaxUED: A simple and useable UED library in Jax
We present JaxUED, an open-source library providing minimal dependency implementations of modern Unsupervised Environment Design (UED) algorithms in Jax.
Inspired by CleanRL, we provide fast, clear, understandable, and easily modifiable implementations, with the aim of accelerating research into UED.
Date: 2024-03-19T18:40:50Z
- DrJAX: Scalable and Differentiable MapReduce Primitives in JAX
DrJAX is a library designed to support large-scale distributed and parallel machine learning algorithms.
DrJAX embeds building blocks for MapReduce computations as primitives in JAX.
DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms.
Date: 2024-03-11T19:51:01Z
- BlackJAX: Composable Bayesian inference in JAX
BlackJAX is a library implementing sampling and variational inference algorithms.
It is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs.
Date: 2024-02-16T16:21:02Z
- JaxMARL: Multi-Agent RL Environments and Algorithms in JAX
We present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments.
Our experiments show that, in terms of wall clock time, our JAX-based training pipeline is around 14 times faster than existing approaches.
We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge.
Date: 2023-11-16T18:58:43Z
- QDax: A Library for Quality-Diversity and Population-based Algorithms with Hardware Acceleration
QDax is an open-source library with a streamlined and modular API for Quality-Diversity (QD) optimization algorithms in Jax.
The library serves as a versatile tool for optimization purposes, ranging from black-box optimization to continuous control.
Date: 2023-08-07T15:29:44Z
- SequeL: A Continual Learning Library in PyTorch and JAX
SequeL is a library for Continual Learning that supports both PyTorch and JAX frameworks.
It provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based approaches, replay-based approaches, and hybrid approaches.
We release SequeL as an open-source library, enabling researchers and developers to easily experiment and extend the library for their own purposes.
Date: 2023-04-21T10:00:22Z
- GRecX: An Efficient and Unified Benchmark for GNN-based Recommendation
We present GRecX, an open-source framework for benchmarking GNN-based recommendation models.
GRecX consists of core libraries for building GNN-based recommendation benchmarks, as well as the implementations of popular GNN-based recommendation models.
We conduct experiments with GRecX, and the experimental results show that GRecX allows us to train and benchmark GNN-based recommendation baselines in an efficient and unified way.
Date: 2021-11-19T17:45:46Z
- Torch-Struct: Deep Structured Prediction Library
We introduce Torch-Struct, a library for structured prediction.
Torch-Struct includes a broad collection of probabilistic structures accessed through a simple and flexible distribution-based API.
Date: 2020-02-03T16:43:02Z
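The BlackJAX entry above mentions sampling algorithms. For readers unfamiliar with the simplest member of that family, a one-dimensional random-walk Metropolis sampler is sketched below in plain Python. `random_walk_metropolis` is a hypothetical name; this is not BlackJAX's API, which uses JAX to compile and run its samplers.

```python
import math
import random
from typing import Callable, List


def random_walk_metropolis(
    logdensity: Callable[[float], float],
    x0: float,
    num_steps: int,
    step_size: float,
    seed: int = 0,
) -> List[float]:
    """Hypothetical 1-D random-walk Metropolis sampler with Gaussian proposals."""
    rng = random.Random(seed)
    x, logp = x0, logdensity(x0)
    samples = []
    for _ in range(num_steps):
        proposal = x + step_size * rng.gauss(0.0, 1.0)
        logp_proposal = logdensity(proposal)
        # The proposal is symmetric, so accept with probability min(1, p(proposal) / p(x)).
        if rng.random() < math.exp(min(0.0, logp_proposal - logp)):
            x, logp = proposal, logp_proposal
        samples.append(x)  # on rejection the current state is repeated
    return samples


# Target: standard normal, given only its unnormalised log-density.
samples = random_walk_metropolis(lambda x: -0.5 * x * x, x0=0.0, num_steps=20_000, step_size=1.0)
mean = sum(samples) / len(samples)
```

Only the difference of log-densities enters the acceptance test, which is why the target's normalising constant is never needed.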
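The QDax entry above refers to Quality-Diversity optimization. In MAP-Elites, a classic QD algorithm, the central data structure is an archive that discretises a behaviour-descriptor space into cells and keeps only the highest-fitness solution found in each cell. A plain-Python sketch of that insertion rule follows; the function names and the assumption that descriptors lie in the unit hypercube are illustrative choices, and this is not QDax's API.

```python
from typing import Dict, Sequence, Tuple

Cell = Tuple[int, ...]
Archive = Dict[Cell, Tuple[object, float]]


def descriptor_to_cell(descriptor: Sequence[float], bins: int, low: float = 0.0, high: float = 1.0) -> Cell:
    """Map a descriptor (assumed in [low, high] per dimension) to a grid cell index."""
    cell = []
    for d in descriptor:
        d = min(max(d, low), high)  # clip into range
        idx = int((d - low) / (high - low) * bins)
        cell.append(min(idx, bins - 1))  # d == high falls into the last bin
    return tuple(cell)


def insert_elite(archive: Archive, solution: object, fitness: float, descriptor: Sequence[float], bins: int) -> bool:
    """Keep `solution` only if its cell is empty or it beats the current elite there."""
    cell = descriptor_to_cell(descriptor, bins)
    incumbent = archive.get(cell)
    if incumbent is None or fitness > incumbent[1]:
        archive[cell] = (solution, fitness)
        return True
    return False


archive: Archive = {}
insert_elite(archive, "a", 1.0, (0.1, 0.9), bins=4)    # True: new cell (0, 3)
insert_elite(archive, "b", 0.5, (0.2, 0.8), bins=4)    # False: same cell, lower fitness
insert_elite(archive, "c", 2.0, (0.15, 0.95), bins=4)  # True: replaces "a"
insert_elite(archive, "d", 0.1, (1.0, 0.0), bins=4)    # True: new cell (3, 0)
```

The archive rewards diversity through its cells and quality through the per-cell comparison: a weak solution survives if it is the only one in its region of descriptor space.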