Adjoint sharding for very long context training of state space models
- URL: http://arxiv.org/abs/2501.00692v1
- Date: Wed, 01 Jan 2025 01:10:59 GMT
- Title: Adjoint sharding for very long context training of state space models
- Authors: Xingzi Xu, Amir Tavanaei, Kavosh Asadi, Karim Bouyarmane,
- Abstract summary: Adjoint sharding is a technique that shards the gradient calculation during training to reduce memory requirements by orders of magnitude. The proposed adjoint sharding algorithm reduces memory usage by up to 3X with a 1.27B-parameter large language model on 1M-context-length training. This makes it possible to increase the maximum context length during training or fine-tuning of a 1.27B-parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
- Score: 7.723642550918118
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Despite very fast progress, efficiently training large language models (LLMs) on very long contexts remains challenging. Existing methods fall back to training LLMs with short contexts (a maximum of a few thousand tokens in training) and use inference-time techniques when evaluating on long contexts (above a 1M-token context window at inference). Unlike long-context inference, training on very long context input prompts is quickly limited by GPU memory availability and by the prohibitively long training times it requires on state-of-the-art hardware. Meanwhile, many real-life applications require not only inference but also training/fine-tuning with long context on specific tasks. Such applications include, for example, augmenting the context with various sources of raw reference information for fact extraction, fact summarization, or fact reconciliation tasks. We propose adjoint sharding, a novel technique that shards the gradient calculation during training to reduce memory requirements by orders of magnitude, making training on very long contexts computationally tractable. Adjoint sharding is based on the adjoint method and computes gradients equivalent to backpropagation. We also propose truncated adjoint sharding to speed up the algorithm while maintaining performance. We provide a distributed version and a parallelized version of adjoint sharding to further speed up training. Empirical results show the proposed adjoint sharding algorithm reduces memory usage by up to 3X with a 1.27B-parameter large language model on 1M-context-length training. This makes it possible to increase the maximum context length during training or fine-tuning of a 1.27B-parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
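The abstract only describes the method at a high level, so the sketch below is a minimal, hedged illustration of the adjoint idea for a toy linear state space model: the forward pass stores only the hidden states, an adjoint (backward) recursion recovers dL/dh_t, and the per-step gradient contributions are then accumulated in independent time shards. The model, the mean-squared-error loss, and the shard layout are assumptions made for the example; this is not the authors' released algorithm.

```python
# A minimal, self-contained sketch of adjoint-style sharded gradients for a toy
# linear state space model. It is a reconstruction from the abstract alone, not
# the authors' algorithm: the model, MSE loss, and shard layout are assumptions.
import numpy as np

rng = np.random.default_rng(0)
d, m, T, n_shards = 8, 4, 1024, 4            # state dim, token dim, context length, shards

A = 0.9 * np.eye(d) + 0.01 * rng.standard_normal((d, d))   # state transition
B = 0.1 * rng.standard_normal((d, m))                       # input projection
C = 0.1 * rng.standard_normal((m, d))                       # readout
x = rng.standard_normal((T, m))                             # input token embeddings
y_true = rng.standard_normal((T, m))                        # regression targets

# Forward pass: keep only the hidden states h_t, no autograd graph over T steps.
h = np.zeros((T + 1, d))
for t in range(T):
    h[t + 1] = A @ h[t] + B @ x[t]
y = h[1:] @ C.T                                             # y_t = C h_t
dl_dy = 2.0 * (y - y_true) / T                              # dL/dy_t for mean-squared error

# Adjoint (backward) recursion: lam[t] = dL/dh_t, computed from the stored
# states without keeping any other intermediate activations.
lam = np.zeros((T + 2, d))
for t in range(T, 0, -1):
    lam[t] = C.T @ dl_dy[t - 1] + A.T @ lam[t + 1]

# Sharded gradient accumulation: each per-step contribution is an independent
# outer product, so the time axis can be split across devices and the partial
# gradients combined afterwards (simulated here by a plain Python loop).
grad_A, grad_B, grad_C = np.zeros_like(A), np.zeros_like(B), np.zeros_like(C)
for shard in np.array_split(np.arange(1, T + 1), n_shards):
    grad_A += sum(np.outer(lam[t], h[t - 1]) for t in shard)
    grad_B += sum(np.outer(lam[t], x[t - 1]) for t in shard)
    grad_C += sum(np.outer(dl_dy[t - 1], h[t]) for t in shard)

print(grad_A.shape, grad_B.shape, grad_C.shape)             # (8, 8) (8, 4) (4, 8)
```

In a real setup the per-shard sums would live on separate workers and be combined with an all-reduce, and the truncated variant mentioned in the abstract would presumably limit how far the per-step contributions are accumulated; neither is modeled in this sketch.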
Related papers
- From 128K to 4M: Efficient Training of Ultra-Long Context Large Language Models [54.44375226381814]
Long-context capabilities are essential for a wide range of applications, including document and video understanding, in-context learning, and inference-time scaling.
We introduce an efficient training recipe for building ultra-long-context LLMs from an aligned instruct model, pushing the boundaries of context lengths from 128K to 1M, 2M, and 4M tokens.
Our approach achieves state-of-the-art performance across a diverse set of long-context benchmarks.
arXiv Detail & Related papers (2025-04-08T16:58:58Z) - InfiniteHiP: Extending Language Model Context Up to 3 Million Tokens on a Single GPU [48.105361428245736]
We introduce InfiniteHiP, an inference framework for large language models (LLMs).
We dynamically eliminate irrelevant context tokens through a modular hierarchical token pruning algorithm.
Our framework achieves an 18.95x speedup in attention decoding for a 1 million token context without requiring additional training.
arXiv Detail & Related papers (2025-02-13T02:52:01Z) - How to Train Long-Context Language Models (Effectively) [75.5418485597276]
We study continued training and supervised fine-tuning (SFT) of a language model (LM) to make effective use of long-context information.
We find that code repositories and books are excellent sources of long data, but it is crucial to combine them with high-quality short-context data.
Our final model, ProLong-8B, demonstrates state-of-the-art long-context performance among similarly sized models at a length of 128K.
arXiv Detail & Related papers (2024-10-03T16:46:52Z) - A Little Goes a Long Way: Efficient Long Context Training and Inference with Partial Contexts [38.867323730365406]
LongGen finetunes a pretrained LLM into an efficient architecture during length extension.
LongGen achieves 1.55x training speedup and reduces wall-clock time by 36%, compared to a full-attention baseline.
During inference, LongGen reduces KV cache memory by 62%, achieving 1.67x prefilling speedup and 1.41x decoding speedup.
arXiv Detail & Related papers (2024-10-02T12:35:53Z) - Hierarchical Context Merging: Better Long Context Understanding for Pre-trained LLMs [61.40047491337793]
We present Hierarchical cOntext MERging (HOMER), a new training-free scheme designed to overcome the context limitations of pre-trained large language models.
HOMER uses a divide-and-conquer algorithm, dividing long inputs into manageable chunks.
A token reduction step precedes each merge, keeping memory usage bounded; a toy sketch of the idea follows below.
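As a rough illustration of the divide-and-conquer idea described in this summary, the sketch below chunks a toy sequence of embeddings, prunes each chunk with a placeholder importance score, and merges neighbouring chunks until one remains. The norm-based scoring, keep ratio, and chunk size are all assumptions for illustration; this is not HOMER itself.

```python
# A toy, training-free illustration of hierarchical chunk merging in the spirit
# of the HOMER summary above. The importance score (embedding norm) and the
# 50% keep ratio are placeholders, not the paper's actual token-reduction rule.
from typing import List
import numpy as np

def reduce_tokens(chunk: np.ndarray, keep_ratio: float = 0.5) -> np.ndarray:
    """Keep the highest-scoring tokens in a chunk, preserving their order."""
    scores = np.linalg.norm(chunk, axis=-1)
    k = max(1, int(len(chunk) * keep_ratio))
    keep = np.sort(np.argsort(scores)[-k:])
    return chunk[keep]

def hierarchical_merge(token_embs: np.ndarray, chunk_size: int = 256) -> np.ndarray:
    """Divide into chunks, then repeatedly reduce tokens and merge neighbours."""
    chunks: List[np.ndarray] = [token_embs[i:i + chunk_size]
                                for i in range(0, len(token_embs), chunk_size)]
    while len(chunks) > 1:
        chunks = [reduce_tokens(c) for c in chunks]          # prune before each merge
        chunks = [np.concatenate(chunks[i:i + 2])            # merge adjacent chunks
                  for i in range(0, len(chunks), 2)]
    return chunks[0]

toy_embeddings = np.random.default_rng(0).standard_normal((4096, 64))
print(hierarchical_merge(toy_embeddings).shape)              # far shorter than 4096 tokens
```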
arXiv Detail & Related papers (2024-04-16T06:34:08Z) - E^2-LLM: Efficient and Extreme Length Extension of Large Language Models [74.1254067728251]
We propose an Efficient and Extreme length extension method for Large Language Models, called E^2-LLM, with only one training procedure and dramatically reduced cost.
Comprehensive experimental results on multiple benchmark datasets demonstrate the effectiveness of our E^2-LLM on challenging long-context tasks.
arXiv Detail & Related papers (2024-01-13T02:11:20Z) - BTR: Binary Token Representations for Efficient Retrieval Augmented Language Models [77.0501668780182]
Retrieval augmentation addresses many critical problems in large language models.
Running retrieval-augmented language models (LMs) is slow and difficult to scale due to processing large amounts of retrieved text.
We introduce binary token representations (BTR), which use 1-bit vectors to precompute representations for every token in the retrieved passages (a toy illustration follows below).
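The following toy sketch illustrates the 1-bit precomputation idea at a high level: token vectors are binarized by sign, packed into bytes, and scored at query time by Hamming distance. These specific choices are assumptions made for the example; the actual BTR calibration and scoring scheme is not reproduced here.

```python
# A toy sketch of precomputing 1-bit token representations in the spirit of the
# BTR summary above. Sign binarization and Hamming-distance scoring are
# illustrative assumptions, not the paper's method.
import numpy as np

rng = np.random.default_rng(0)
passage_tokens = rng.standard_normal((10_000, 768)).astype(np.float32)  # toy encoder outputs

binary = passage_tokens > 0                      # one bit per dimension
packed = np.packbits(binary, axis=-1)            # 768 bits -> 96 bytes per token

print(f"fp32: {passage_tokens.nbytes / 1e6:.1f} MB, "
      f"1-bit: {packed.nbytes / 1e6:.1f} MB "
      f"({passage_tokens.nbytes / packed.nbytes:.0f}x smaller)")

# At query time the precomputed bits can be scored cheaply, e.g. by Hamming distance.
query_bits = np.packbits(rng.standard_normal(768) > 0)
hamming = np.unpackbits(packed ^ query_bits, axis=-1).sum(axis=-1)   # (10000,)
print(hamming.argmin())                          # index of the closest passage token
```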
arXiv Detail & Related papers (2023-10-02T16:48:47Z) - Layered gradient accumulation and modular pipeline parallelism: fast and efficient training of large language models [0.0]
We analyse the shortest possible training time for different configurations of distributed training.
We introduce two new methods, layered gradient accumulation and modular pipeline parallelism, which together cut the shortest training time by half.
arXiv Detail & Related papers (2021-06-04T19:21:49Z)