MPX: Mixed Precision Training for JAX
- URL: http://arxiv.org/abs/2507.03312v2
- Date: Tue, 08 Jul 2025 06:28:22 GMT
- Title: MPX: Mixed Precision Training for JAX
- Authors: Alexander Gräfe, Sebastian Trimpe
- Abstract summary: Mixed-precision training has emerged as an indispensable tool for enhancing the efficiency of neural network training. We propose MPX, a mixed-precision training toolbox for JAX that simplifies and accelerates the training of large-scale neural networks. MPX seamlessly integrates with popular toolboxes such as Equinox and Flax, allowing users to convert full-precision pipelines to mixed-precision versions.
- Score: 54.62458721568289
- License: http://creativecommons.org/licenses/by-nc-sa/4.0/
- Abstract: Mixed-precision training has emerged as an indispensable tool for enhancing the efficiency of neural network training in recent years. Concurrently, JAX has grown in popularity as a versatile machine learning toolbox. However, it currently lacks robust support for mixed-precision training. We propose MPX, a mixed-precision training toolbox for JAX that simplifies and accelerates the training of large-scale neural networks while preserving model accuracy. MPX seamlessly integrates with popular toolboxes such as Equinox and Flax, allowing users to convert full-precision pipelines to mixed-precision versions with minimal modifications. By casting both inputs and outputs to half precision, and introducing a dynamic loss-scaling mechanism, MPX alleviates issues like gradient underflow and overflow that commonly arise in half precision computations. Its design inherits critical features from JAX's type-promotion behavior, ensuring that operations take place in the correct precision and allowing for selective enforcement of full precision where needed (e.g., sums, means, or softmax). MPX further provides wrappers for automatic creation and management of mixed-precision gradients and optimizers, enabling straightforward integration into existing JAX training pipelines. MPX's source code, documentation, and usage examples are available at github.com/Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX.
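The casting and dynamic loss-scaling scheme described in the abstract can be illustrated in plain JAX. This is a minimal sketch, not MPX's API: the helper names (`cast_tree`, `LossScaleState`, `train_step`, `init_params`), the choice of float16 as the half-precision type, the SGD learning rate, the growth period, and the "halve on overflow, double after N finite steps" policy are all assumptions modeled on common dynamic loss-scaling schemes. Parameters are kept in float32, cast to float16 for compute, the softmax is explicitly forced back to float32, and an update is skipped (and the scale halved) whenever the unscaled gradients contain inf or NaN.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

HALF = jnp.float16     # assumed half-precision type (narrow exponent range)
LR = 0.1               # hypothetical SGD learning rate
GROWTH_PERIOD = 2000   # hypothetical: double the scale after this many finite steps


def cast_tree(tree, dtype):
    """Cast every floating-point leaf of a pytree to `dtype`; leave integers untouched."""
    return jax.tree_util.tree_map(
        lambda a: a.astype(dtype) if jnp.issubdtype(a.dtype, jnp.floating) else a,
        tree)


def model(params, x):
    # Matmuls and activations run in the dtype of params/x (float16 here).
    h = jnp.tanh(x @ params["w1"])
    logits = h @ params["w2"]
    # Selectively enforce full precision for the numerically sensitive softmax.
    return jax.nn.log_softmax(logits.astype(jnp.float32))


def scaled_loss(params_f32, x, y, scale):
    # Cast the float32 master parameters and the inputs to half precision for compute.
    logp = model(cast_tree(params_f32, HALF), cast_tree(x, HALF))
    nll = -jnp.mean(jnp.take_along_axis(logp, y[:, None], axis=1))  # mean in float32
    # Multiplying by `scale` lifts small half-precision gradients above underflow.
    return nll * scale


class LossScaleState(NamedTuple):
    scale: jax.Array    # current loss scale (float32 scalar)
    counter: jax.Array  # consecutive steps with finite gradients (int32 scalar)


def all_finite(tree):
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.stack([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]).all()


@jax.jit
def train_step(params, ls, x, y):
    # Gradients w.r.t. float32 params come back in float32; divide out the scale.
    grads = jax.grad(scaled_loss)(params, x, y, ls.scale)
    grads = jax.tree_util.tree_map(lambda g: g / ls.scale, grads)
    finite = all_finite(grads)
    # Apply the SGD update only if no gradient overflowed to inf/NaN.
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - LR * g, p), params, grads)
    grow = finite & (ls.counter + 1 >= GROWTH_PERIOD)
    new_scale = jnp.where(
        finite,
        jnp.where(grow, ls.scale * 2.0, ls.scale),  # grow after a run of finite steps
        jnp.maximum(ls.scale / 2.0, 1.0))           # overflow: halve, floor at 1
    new_counter = jnp.where(finite & ~grow, ls.counter + 1, 0)
    return new_params, LossScaleState(new_scale, new_counter)


def init_params(key, d_in=8, d_hidden=16, n_classes=4):
    k1, k2 = jax.random.split(key)
    return {"w1": 0.1 * jax.random.normal(k1, (d_in, d_hidden)),
            "w2": 0.1 * jax.random.normal(k2, (d_hidden, n_classes))}


# Usage: one step on a toy batch, starting from a moderate loss scale.
kp, kx, ky = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(kp)
x = jax.random.normal(kx, (8, 8))
y = jax.random.randint(ky, (8,), 0, 4)
ls = LossScaleState(jnp.float32(2.0**10), jnp.int32(0))
new_params, new_ls = train_step(params, ls, x, y)
```

Keeping a float32 master copy of the parameters means a small update `LR * g` is not lost to float16 rounding. bfloat16, whose exponent range matches float32, typically needs no loss scaling at all; the scaling logic above matters mainly for float16.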
Related papers
- MPQ-DMv2: Flexible Residual Mixed Precision Quantization for Low-Bit Diffusion Models with Temporal Distillation
We present MPQ-DMv2, an improved Mixed Precision Quantization framework for extremely low-bit Diffusion Models.
arXiv, 2025-07-06
- RefineX: Learning to Refine Pre-training Data at Scale from Expert-Guided Programs
RefineX is a novel framework for large-scale, surgical refinement of pre-training data through programmatic editing tasks. The core strength of RefineX lies in distilling high-quality, expert-guided end-to-end refinement results into minimal edit-based deletion programs. We evaluate RefineX on from-scratch pre-training at multiple model scales and find that models trained on RefineX-refined data consistently outperform models trained on raw, filtered, or alternatively refined data.
arXiv, 2025-07-04
- ESLM: Risk-Averse Selective Language Modeling for Efficient Pretraining
Large language model pretraining is compute-intensive, yet many tokens contribute only marginally to learning, which makes it inefficient. We introduce Selective Efficient Language Modeling, a risk-aware algorithm that improves training efficiency and distributional robustness by performing online token-level batch selection. Experiments on GPT-2 pretraining show that ESLM significantly reduces training FLOPs while maintaining or improving both perplexity and downstream performance compared to baselines.
arXiv, 2025-05-26
- FlexQuant: A Flexible and Efficient Dynamic Precision Switching Framework for LLM Quantization
We propose FlexQuant, a dynamic precision-switching framework to optimize the trade-off between inference speed and accuracy. Our work provides a comprehensive analysis of quantization strategies, introduces a precision requirement model for optimal switching, and implements efficient fine-grained precision management. Experimental results demonstrate that FlexQuant achieves a 1.3x end-to-end speedup across diverse language tasks with negligible accuracy loss.
arXiv, 2025-05-21
- Quartet: Native FP4 Training Can Be Optimal for Large Language Models
Training large language models (LLMs) directly in low precision offers a way to address computational costs. NVIDIA's recent Blackwell architecture facilitates very low-precision operations using FP4 variants. We introduce a new approach for accurate, end-to-end FP4 training with all the major computations in low precision.
arXiv, 2025-05-20
- Optimizing ML Training with Metagradient Descent
We introduce an algorithm for efficiently calculating metagradients (gradients through model training) at scale. We then introduce a "smooth model training" framework that enables effective optimization using metagradients.
arXiv, 2025-03-17
- MPAX: Mathematical Programming in JAX
MPAX is a versatile and efficient toolbox for integrating linear programming into machine learning. It provides native support for hardware acceleration along with features like batch solving, auto-differentiation, and device parallelism.
arXiv, 2024-12-12
- Scalify: scale propagation for efficient low-precision LLM training
Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve computational efficiency for large language model training and inference.
We present Scalify, an end-to-end scale propagation paradigm for computational graphs.
arXiv, 2024-07-24
- SliM-LLM: Salience-Driven Mixed-Precision Quantization for Large Language Models
Post-training quantization (PTQ) is an effective technique for compressing large language models (LLMs). We propose SliM-LLM, a salience-driven mixed-precision quantization framework that allocates bit-widths at the group level. Experiments show that SliM-LLM achieves superior performance across various LLMs at low bit-widths.
arXiv, 2024-05-23
- On-Chip Hardware-Aware Quantization for Mixed Precision Neural Networks
We propose an On-Chip Hardware-Aware Quantization framework, performing hardware-aware mixed-precision quantization on deployed edge devices.
For efficiency metrics, we built an On-Chip Quantization Aware pipeline, which allows the quantization process to perceive the actual hardware efficiency of the quantization operator.
For accuracy metrics, we propose Mask-Guided Quantization Estimation technology to effectively estimate the accuracy impact of operators in the on-chip scenario.
arXiv, 2023-09-05
- Precision-Recall Divergence Optimization for Generative Modeling with GANs and Normalizing Flows
We develop a novel training method for generative models, such as Generative Adversarial Networks and Normalizing Flows.
We show that achieving a specified precision-recall trade-off corresponds to minimizing a unique $f$-divergence from a family we call the PR-divergences.
Our approach improves the performance of existing state-of-the-art models like BigGAN in terms of either precision or recall when tested on datasets such as ImageNet.
arXiv, 2023-05-30
- Activation Density based Mixed-Precision Quantization for Energy Efficient Neural Networks
We propose an in-training quantization method for neural network models.
Our method calculates the bit-width of each layer while training a mixed-precision model with competitive accuracy.
We run experiments on benchmark datasets such as CIFAR-10, CIFAR-100, and TinyImageNet with VGG19/ResNet18 architectures.
arXiv, 2021-01-12
- Multi-Precision Policy Enforced Training (MuPPET): A precision-switching strategy for quantised fixed-point training of CNNs
Large-scale convolutional neural networks (CNNs) suffer from very long training times, spanning from hours to weeks.
This work pushes the boundary of quantised training by employing a multilevel approach that utilises multiple precisions.
MuPPET achieves the same accuracy as standard full-precision training with a training-time speedup of up to 1.84x and an average speedup of 1.58x across the networks.
arXiv, 2020-06-16