Adjoint sharding for very long context training of state space models
- URL: http://arxiv.org/abs/2501.00692v1
- Date: Wed, 01 Jan 2025 01:10:59 GMT
- Title: Adjoint sharding for very long context training of state space models
- Authors: Xingzi Xu, Amir Tavanaei, Kavosh Asadi, Karim Bouyarmane
- Abstract summary: Adjoint sharding is a technique that shards the gradient calculation during training, reducing memory requirements by orders of magnitude.
We show the proposed adjoint sharding algorithm reduces memory usage by up to 3X with a 1.27B parameter large language model on 1M context length training.
This makes it possible to increase the maximum context length during training or fine-tuning of a 1.27B parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
- Score: 7.723642550918118
- Abstract: Despite very fast progress, efficiently training large language models (LLMs) on very long contexts remains challenging. Existing methods fall back to training LLMs with short contexts (at most a few thousand tokens in training) and use inference-time techniques when evaluating on long contexts (above 1M-token context windows at inference). Unlike long-context inference, training on very long context input prompts is quickly limited by GPU memory availability and by the prohibitively long training times it requires on state-of-the-art hardware. Meanwhile, many real-life applications require not only inference but also training/fine-tuning with long context on specific tasks. Such applications include, for example, augmenting the context with various sources of raw reference information for fact extraction, fact summarization, or fact reconciliation tasks. We propose adjoint sharding, a novel technique that shards the gradient calculation during training to reduce memory requirements by orders of magnitude, making training on very long contexts computationally tractable. Adjoint sharding is based on the adjoint method and computes gradients equivalent to backpropagation. We also propose truncated adjoint sharding to speed up the algorithm while maintaining performance. We provide a distributed version and a parallelized version of adjoint sharding to further speed up training. Empirical results show that the proposed adjoint sharding algorithm reduces memory usage by up to 3X with a 1.27B parameter large language model on 1M context length training. This makes it possible to increase the maximum context length during training or fine-tuning of a 1.27B parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
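As a rough illustration of the idea (not the authors' implementation), the sketch below computes gradients for a toy linear recurrence h_t = A h_{t-1} + B x_t with a quadratic readout loss via the adjoint recursion, which reproduces what backpropagation would return, and then accumulates the per-timestep gradient contributions as independent per-shard partial sums. All names, shapes, and the chunking scheme are illustrative assumptions; truncated adjoint sharding would, in the spirit of truncated BPTT, presumably cut the backward recursion off after a fixed number of steps.

```python
# Minimal sketch of adjoint-style, sharded gradient accumulation for a toy
# linear state space layer. Illustrative only: shapes, the loss, and the
# chunking scheme are assumptions, not the paper's implementation.
import numpy as np

def forward_states(A, B, x):
    """Run h[t] = A h[t-1] + B x[t-1], keeping states only (no autograd graph)."""
    T, d = x.shape[0], A.shape[0]
    h = np.zeros((T + 1, d))
    for t in range(1, T + 1):
        h[t] = A @ h[t - 1] + B @ x[t - 1]
    return h

def adjoint_states(A, C, h, y):
    """Backward adjoint recursion lam[t] = C^T (C h[t] - y[t-1]) + A^T lam[t+1]."""
    T, d = y.shape[0], A.shape[0]
    lam = np.zeros((T + 2, d))          # lam[T + 1] is a zero sentinel
    for t in range(T, 0, -1):
        lam[t] = C.T @ (C @ h[t] - y[t - 1]) + A.T @ lam[t + 1]
    return lam

def sharded_param_grads(lam, h, x, num_shards=4):
    """Accumulate dL/dA and dL/dB as sums of independent per-shard partial sums.

    Each shard touches only its own slice of states, adjoints, and inputs, so
    each partial sum could be computed on a separate worker.
    """
    T = x.shape[0]
    bounds = np.linspace(1, T + 1, num_shards + 1, dtype=int)
    dA = sum(lam[lo:hi].T @ h[lo - 1:hi - 1] for lo, hi in zip(bounds[:-1], bounds[1:]))
    dB = sum(lam[lo:hi].T @ x[lo - 1:hi - 1] for lo, hi in zip(bounds[:-1], bounds[1:]))
    return dA, dB

# Usage on random data (toy sizes chosen arbitrarily).
rng = np.random.default_rng(0)
d, m, p, T = 8, 4, 2, 64
A = 0.1 * rng.normal(size=(d, d))       # state transition
B = rng.normal(size=(d, m))             # input projection
C = rng.normal(size=(p, d))             # readout
x = rng.normal(size=(T, m))
y = rng.normal(size=(T, p))

h = forward_states(A, B, x)
lam = adjoint_states(A, C, h, y)
dA, dB = sharded_param_grads(lam, h, x, num_shards=4)
```

Because each shard needs only its own slice of states, adjoints, and inputs, no full autograd graph over the whole sequence has to be held in memory, which is, roughly, the source of the memory savings the abstract describes.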
Related papers
- InfiniteHiP: Extending Language Model Context Up to 3 Million Tokens on a Single GPU [48.105361428245736]
We introduce InfiniteHiP, an inference framework for large language models (LLMs).
We dynamically eliminate irrelevant context tokens through a modular hierarchical token pruning algorithm.
Our framework achieves an 18.95x speedup in attention decoding for a 1 million token context without requiring additional training.
arXiv Detail & Related papers (2025-02-13T02:52:01Z) - A Little Goes a Long Way: Efficient Long Context Training and Inference with Partial Contexts [38.867323730365406]
LongGen finetunes a pretrained LLM into an efficient architecture during length extension.
LongGen achieves 1.55x training speedup and reduces wall-clock time by 36%, compared to a full-attention baseline.
During inference, LongGen reduces KV cache memory by 62%, achieving 1.67x prefilling speedup and 1.41x decoding speedup.
arXiv Detail & Related papers (2024-10-02T12:35:53Z) - KV Cache Compression, But What Must We Give in Return? A Comprehensive Benchmark of Long Context Capable Approaches [52.02764371205856]
Long context capability is a crucial competency for large language models (LLMs).
This work provides a taxonomy of current methods and evaluates 10+ state-of-the-art approaches across seven categories of long context tasks.
arXiv Detail & Related papers (2024-07-01T17:59:47Z) - Training-Free Exponential Context Extension via Cascading KV Cache [49.608367376911694]
We introduce a novel mechanism that leverages cascading sub-cache buffers to selectively retain the most relevant tokens.
Our method reduces prefill stage latency by a factor of 6.8 when compared to flash attention on 1M tokens.
arXiv Detail & Related papers (2024-06-24T03:59:17Z) - Hierarchical Context Merging: Better Long Context Understanding for Pre-trained LLMs [61.40047491337793]
We present Hierarchical cOntext MERging (HOMER), a new training-free scheme designed to overcome the limitations of large language models.
HOMER uses a divide-and-conquer algorithm, dividing long inputs into manageable chunks.
A token reduction technique precedes each merging, ensuring memory usage efficiency.
arXiv Detail & Related papers (2024-04-16T06:34:08Z) - E^2-LLM: Efficient and Extreme Length Extension of Large Language Models [74.1254067728251]
We propose an Efficient and Extreme length extension method for Large Language Models, called E^2-LLM, with only one training procedure and dramatically reduced cost.
Comprehensive experimental results on multiple benchmark datasets demonstrate the effectiveness of our E^2-LLM on challenging long-context tasks.
arXiv Detail & Related papers (2024-01-13T02:11:20Z) - BTR: Binary Token Representations for Efficient Retrieval Augmented Language Models [77.0501668780182]
Retrieval augmentation addresses many critical problems in large language models.
Running retrieval-augmented language models (LMs) is slow and difficult to scale due to processing large amounts of retrieved text.
We introduce binary token representations (BTR), which use 1-bit vectors to precompute every token in passages (a toy sketch of this binarization appears after this list).
arXiv Detail & Related papers (2023-10-02T16:48:47Z) - Layered gradient accumulation and modular pipeline parallelism: fast and efficient training of large language models [0.0]
We analyse the shortest possible training time for different configurations of distributed training.
We introduce two new methods, layered gradient accumulation and modular pipeline parallelism, which together cut the shortest training time by half.
arXiv Detail & Related papers (2021-06-04T19:21:49Z)
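As a rough illustration of the 1-bit token representations mentioned in the BTR entry above, the sketch below sign-binarizes precomputed token vectors and packs them into bytes. It does not reproduce the calibration or training details of the actual method; all names and shapes are assumptions.

```python
# Toy sketch of 1-bit token codes via sign binarization and bit-packing.
import numpy as np

def binarize_tokens(token_vectors: np.ndarray) -> np.ndarray:
    """Map float token vectors of shape (num_tokens, dim) to packed 1-bit codes."""
    bits = (token_vectors > 0).astype(np.uint8)   # 1 bit per dimension (sign)
    return np.packbits(bits, axis=1)              # dim/8 bytes per token

def matching_bits(code_a: np.ndarray, code_b: np.ndarray, dim: int) -> int:
    """Similarity of two packed codes = number of agreeing bits."""
    return dim - int(np.unpackbits(code_a ^ code_b)[:dim].sum())

# Usage: precompute and store codes for passage tokens once, reuse at inference.
rng = np.random.default_rng(0)
passage = rng.normal(size=(128, 768))             # 128 tokens, 768-dim vectors
codes = binarize_tokens(passage)                  # shape (128, 96), uint8
print(codes.nbytes, "bytes vs", passage.astype(np.float32).nbytes, "bytes")
print(matching_bits(codes[0], codes[1], dim=768))
```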