Efficient Long-Decoding Inference with Reasoning-Aware Attention Sparsity
- URL: http://arxiv.org/abs/2502.11147v1
- Date: Sun, 16 Feb 2025 14:28:52 GMT
- Title: Efficient Long-Decoding Inference with Reasoning-Aware Attention Sparsity
- Authors: Junhao Hu, Wenrui Huang, Weidong Wang, Zhenwen Li, Tiancheng Hu, Zhixia Liu, Xusheng Chen, Tao Xie, Yizhou Shan
- Abstract summary: Solving reasoning tasks often requires long decoding chains (of thoughts), which incur $O(N)$ time and memory consumption, where $N$ is the chain length.
We identify a new attention pattern during the decode stage of reasoning tasks, in which milestone tokens emerge, are utilized, and then become unimportant.
Based on this pattern, we propose a new algorithm named RaaS that identifies and retains milestone tokens only until they are no longer needed, achieving high accuracy with $O(L)$ time and $O(L)$ memory complexity, where $L$ is the cache budget.
- Score: 14.409253716114213
- License:
- Abstract: Large Language Models (LLMs) have demonstrated strong capabilities across various domains, with recent advancements in challenging reasoning tasks such as mathematics and programming. However, solving reasoning tasks often requires long decoding chains (of thoughts), which incur $O(N)$ time and memory consumption, where $N$ is the chain length. To mitigate $O(N)$ time and memory consumption, existing sparsity-based algorithms propose retaining only the most critical tokens' intermediate data (i.e., key-value cache) and discarding the rest. However, these existing algorithms struggle with the "impossible trinity" of accuracy, time, and memory. For example, the state-of-the-art algorithm, Quest, achieves high accuracy with $O(L)$ time but $O(N)$ memory ($L$ is the cache budget, $L \ll N$). To address this issue, in this paper, we identify a new attention pattern during the decode stage of reasoning tasks, where milestone tokens (analogous to lemmas in mathematical proofs) emerge, are utilized, and then become unimportant afterward. Based on this pattern, we propose a new algorithm named RaaS that identifies and retains milestone tokens only until they are no longer needed, achieving high accuracy with $O(L)$ time and $O(L)$ memory complexity.
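The idea of keeping a token's key-value entry only while it is still being attended to can be illustrated with a toy budgeted cache. The sketch below is not the RaaS implementation from the paper; it is a minimal least-recently-attended eviction policy written under stated assumptions. The class name `BudgetedKVCache`, the `threshold` parameter, and the timestamp-refresh rule are all hypothetical, and only key vectors are stored (values are omitted for brevity). It shows how a cache capped at `budget` entries (the $L$ of the abstract) keeps memory bounded regardless of how many tokens are decoded.

```python
import math
from dataclasses import dataclass, field


@dataclass
class BudgetedKVCache:
    """Toy cache holding at most `budget` token keys (hypothetical sketch)."""
    budget: int
    threshold: float = 0.1  # assumed cutoff for "this token was used"
    keys: dict = field(default_factory=dict)       # token_id -> key vector
    last_used: dict = field(default_factory=dict)  # token_id -> decode step

    def attend(self, query, step):
        """Return softmax weights over cached keys and refresh the
        timestamp of every token whose weight reaches the threshold."""
        if not self.keys:
            return {}
        ids = list(self.keys)
        scores = [sum(q * k for q, k in zip(query, self.keys[i])) for i in ids]
        top = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = {i: e / total for i, e in zip(ids, exps)}
        for token_id, weight in weights.items():
            if weight >= self.threshold:
                self.last_used[token_id] = step
        return weights

    def append(self, token_id, key, step):
        """Insert a new token, evicting the least recently attended
        token first if the cache is already at its budget."""
        if len(self.keys) >= self.budget:
            stale = min(self.last_used, key=self.last_used.get)
            del self.keys[stale]
            del self.last_used[stale]
        self.keys[token_id] = key
        self.last_used[token_id] = step


# Usage: with a budget of 2, token 1 is never attended to after insertion,
# so it is the one evicted when token 2 arrives.
cache = BudgetedKVCache(budget=2)
cache.append(0, [1.0, 0.0], step=0)
cache.append(1, [0.0, 1.0], step=1)
cache.attend([10.0, 0.0], step=2)   # strongly attends to token 0 only
cache.append(2, [1.0, 1.0], step=3)
print(sorted(cache.keys))  # [0, 2]
```

Because eviction happens before each insertion, the cache never exceeds `budget` entries, and each decode step scans at most `budget` keys. That mirrors the $O(L)$ time and memory target stated in the abstract, under the assumptions above.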
Related papers
- Algorithm Design for Continual Learning in IoT Networks [16.35495567193046]
Continual learning (CL) is a new online learning technique over sequentially generated streaming data from different tasks.
In practical IoT networks, an autonomous vehicle that samples data and learns different tasks can choose its route and alter the order of the task pattern.
arXiv Detail & Related papers (2024-12-22T02:36:09Z)
- HashAttention: Semantic Sparsity for Faster Inference [91.54218318798603]
HashAttention is a principled approach casting pivotal token identification as a recommendation problem.
It efficiently identifies pivotal tokens for a given query in this Hamming space using bitwise operations.
It can reduce the number of tokens used by a factor of $1/32\times$ for the Llama-3.1-8B model with LongBench.
arXiv Detail & Related papers (2024-12-19T02:34:15Z)
- A Training-free Sub-quadratic Cost Transformer Model Serving Framework With Hierarchically Pruned Attention [43.211427581302715]
We propose Hierarchically Pruned Attention (HiP) to increase context length in large language models.
HiP reduces the time complexity of the attention mechanism to $O(T \log T)$ and the space complexity to $O(T)$, where $T$ is the sequence length.
We show that HiP significantly reduces both prefill and decoding latencies, as well as memory usage, while maintaining high-quality generation with minimal degradation.
arXiv Detail & Related papers (2024-06-14T08:32:45Z)
- One Pass Streaming Algorithm for Super Long Token Attention Approximation in Sublinear Space [11.735802740426294]
Attention computation takes both $O(n^2)$ time and $O(n^2)$ space.
We introduce a new algorithm that only reads one pass of data in a streaming fashion.
Notably, our algorithm exhibits exceptional memory-efficient performance with super-long tokens.
arXiv Detail & Related papers (2023-11-24T18:35:00Z)
- Efficiently Learning One-Hidden-Layer ReLU Networks via Schur Polynomials [50.90125395570797]
We study the problem of PAC learning a linear combination of $k$ ReLU activations under the standard Gaussian distribution on $\mathbb{R}^d$ with respect to the square loss.
Our main result is an efficient algorithm for this learning task with sample and computational complexity $(dk/\epsilon)^{O(k)}$, where $\epsilon>0$ is the target accuracy.
arXiv Detail & Related papers (2023-07-24T14:37:22Z)
- Improved Algorithms for Allen's Interval Algebra by Dynamic Programming with Sublinear Partitioning [9.594432031144716]
Allen's interval algebra is one of the most well-known calculi in qualitative temporal reasoning.
We propose a novel framework for solving NP-hard qualitative reasoning problems.
We obtain a major improvement of $O^*((\frac{cn}{\log n})^n)$ for Allen's interval algebra.
arXiv Detail & Related papers (2023-05-25T11:45:12Z)
- Simplifying and Understanding State Space Models with Diagonal Linear RNNs [56.33053691749856]
This work disposes of the discretization step, and proposes a model based on vanilla Diagonal Linear RNNs.
We empirically show that, despite being conceptually much simpler, $\mathrm{DLR}$ is as performant as previously proposed SSMs.
We also characterize the expressivity of SSMs and attention-based models via a suite of $13$ synthetic sequence-to-sequence tasks.
arXiv Detail & Related papers (2022-12-01T18:53:06Z)
- Near-Optimal Regret Bounds for Multi-batch Reinforcement Learning [54.806166861456035]
We study the episodic reinforcement learning (RL) problem modeled by finite-horizon Markov Decision Processes (MDPs) with constraint on the number of batches.
We design a computationally efficient algorithm to achieve near-optimal regret of $\tilde{O}(\sqrt{SAH^3K\ln(1/\delta)})$ in $K$ episodes, where $\tilde{O}(\cdot)$ hides logarithmic terms of $(S,A,H,K)$.
Our technical contributions are two-fold: 1) a near-optimal design scheme to explore
arXiv Detail & Related papers (2022-10-15T09:22:22Z)
- Learning a Latent Simplex in Input-Sparsity Time [58.30321592603066]
We consider the problem of learning a latent $k$-vertex simplex $K\subset\mathbb{R}^{d\times n}$, given access to $A\in\mathbb{R}^{d\times n}$.
We show that the dependence on $k$ in the running time is unnecessary given a natural assumption about the mass of the top $k$ singular values of $A$.
arXiv Detail & Related papers (2021-05-17T16:40:48Z)
- RNNs can generate bounded hierarchical languages with optimal memory [113.73133308478612]
We show that RNNs can efficiently generate bounded hierarchical languages that reflect the scaffolding of natural language syntax.
We introduce Dyck-($k$,$m$), the language of well-nested brackets (of $k$ types) and $m$-bounded nesting depth.
We prove that an RNN with $O(m \log k)$ hidden units suffices, an exponential reduction in memory, by an explicit construction.
arXiv Detail & Related papers (2020-10-15T04:42:29Z)
This list is automatically generated from the titles and abstracts of the papers in this site.