DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference
- URL: http://arxiv.org/abs/2404.00242v2
- Date: Wed, 29 May 2024 18:46:41 GMT
- Title: DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference
- Authors: Jinwei Yao, Kaiqi Chen, Kexun Zhang, Jiaxuan You, Binhang Yuan, Zeke Wang, Tao Lin
- Abstract summary: We introduce an IO-aware tree attention algorithm tailored for tree-structured inference.
DeFT achieves up to 2.52x end-to-end and up to 3.82x attention-latency speedups across three practical tree-based workloads.
- Score: 22.684773338989007
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Given the increasing demand for tree-structured interactions with LLMs, we introduce DeFT (Decoding with Flash Tree-Attention), an IO-aware tree attention algorithm tailored for tree-structured inference. Unlike traditional sequence-based decoding, tree-structured decoding better accommodates modern task requirements, including self-consistency, few-shot prompting, multi-step reasoning, and multi-model/head coordination. However, existing sequence-based inference systems are ill-suited for tree-structured decoding, resulting in redundancy in computation, memory footprint, and memory access, thereby undermining inference efficiency. To address this challenge, DeFT maintains memory-efficient attention calculation with a low memory footprint through two key stages: (1) QKV Preparation: we propose a KV-Guided Grouping Strategy with Tree Split to intelligently group QKV, optimizing GPU resource utilization while minimizing memory reads/writes for the KV cache between GPU global memory and on-chip shared memory; (2) Attention Calculation: we compute the partial attention of each QKV group in a fused kernel and employ a Tree-topology-aware Global Reduction strategy to obtain the final attention. By reducing KV cache IO by 73-99% and nearly 100% of the IO for partial results during attention calculation (e.g., Softmax), DeFT achieves up to 2.52x end-to-end and 3.82x attention-latency speedups over state-of-the-art attention algorithms across three practical tree-based workloads: few-shot prompting, multi-step reasoning, and speculative decoding.
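The Tree-topology-aware Global Reduction is in the spirit of the log-sum-exp merging used by chunked/flash attention: each QKV group yields a partial result plus its softmax statistics, and the partials from the KV chunks on a query's root-to-leaf path are merged exactly. A minimal NumPy sketch of that merge step follows; DeFT's fused kernels and GPU memory management are not modeled, and the function names are illustrative:

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention over one KV chunk; returns output plus softmax stats."""
    s = K @ q                      # scores for this chunk, shape (n,)
    m = s.max()                    # running max for numerical stability
    p = np.exp(s - m)
    return p @ V, p.sum(), m       # unnormalized output, denominator, max

def tree_reduce(partials):
    """Merge (out, denom, max) partials from the KV chunks on a query's
    root-to-leaf path into the exact softmax attention."""
    m_glob = max(m for _, _, m in partials)
    num = sum(out * np.exp(m - m_glob) for out, _, m in partials)
    den = sum(d * np.exp(m - m_glob) for _, d, m in partials)
    return num / den

# Toy tree: a shared prefix chunk plus one branch-specific chunk.
d = 8
rng = np.random.default_rng(0)
q = rng.normal(size=d)
prefix_K, prefix_V = rng.normal(size=(16, d)), rng.normal(size=(16, d))
branch_K, branch_V = rng.normal(size=(4, d)), rng.normal(size=(4, d))

out = tree_reduce([partial_attention(q, prefix_K, prefix_V),
                   partial_attention(q, branch_K, branch_V)])

# Reference: ordinary attention over the concatenated path.
K = np.vstack([prefix_K, branch_K]); V = np.vstack([prefix_V, branch_V])
s = K @ q; p = np.exp(s - s.max())
assert np.allclose(out, (p / p.sum()) @ V)
```

Because the merge is exact, the KV cache of a shared prefix needs to be read from global memory only once per group rather than once per branch, which is where the IO savings come from.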
Related papers
- ChunkAttention: Efficient Self-Attention with Prefix-Aware KV Cache and Two-Phase Partition [3.659659889927316]
ChunkAttention is a prefix-aware self-attention module for large language models.
It can detect matching prompt prefixes across multiple requests and share their key/value tensors in memory at runtime.
Experiments show that ChunkAttention can speed up the self-attention kernel by 3.2-4.8x compared to the state-of-the-art implementation.
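A rough sketch of the prefix-sharing idea, assuming a chunk-granular trie keyed by prompt-token chunks; the chunk size, class names, and API below are illustrative, not ChunkAttention's actual implementation:

```python
import numpy as np

CHUNK = 4  # tokens per trie node (illustrative chunk size)

class KVTrieNode:
    def __init__(self):
        self.children = {}   # token-chunk tuple -> KVTrieNode
        self.kv = None       # cached (K, V) for this chunk, shared by reference

def get_or_build_kv(root, tokens, compute_kv):
    """Walk the trie chunk by chunk; reuse cached KV for matching
    prefixes and compute KV only for unseen chunks."""
    node, kvs = root, []
    for i in range(0, len(tokens), CHUNK):
        chunk = tuple(tokens[i:i + CHUNK])
        child = node.children.get(chunk)
        if child is None:                    # first time we see this chunk
            child = KVTrieNode()
            child.kv = compute_kv(chunk)     # expensive projection happens once
            node.children[chunk] = child
        node = child
        kvs.append(node.kv)                  # shared tensor, no copy
    return kvs

# Two requests with an identical system-prompt prefix share its KV.
d = 8
rng = np.random.default_rng(0)
fake_kv = lambda chunk: (rng.normal(size=(len(chunk), d)),
                         rng.normal(size=(len(chunk), d)))
root = KVTrieNode()
a = get_or_build_kv(root, [1, 2, 3, 4, 9, 9, 9, 9], fake_kv)
b = get_or_build_kv(root, [1, 2, 3, 4, 7, 7, 7, 7], fake_kv)
assert a[0] is b[0]   # prefix chunk KV is the same object in memory
```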
arXiv Detail & Related papers (2024-02-23T09:29:19Z) - Towards Model-Size Agnostic, Compute-Free, Memorization-based Inference of Deep Learning [5.41530201129053]
This paper proposes a novel memorization-based inference (MBI) approach that is compute-free and only requires lookups.
Specifically, our work capitalizes on the inference mechanism of the recurrent attention model (RAM).
By leveraging the low dimensionality of glimpses, our inference procedure stores key-value pairs comprising the glimpse location, patch vector, etc. in a table.
During inference, computation is obviated by reading the stored key-value pairs out of the table, i.e., inference by memorization.
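A toy sketch of memorization-based inference, assuming a simple quantization of the glimpse inputs into table keys; the binning scheme and names are hypothetical:

```python
import numpy as np

class MemorizationTable:
    """Toy memorization-based inference: quantize a (glimpse location,
    patch) pair into a key and look up the stored model output."""
    def __init__(self, bins=16):
        self.bins = bins
        self.table = {}

    def _key(self, loc, patch):
        # Low-dimensional glimpses keep the key space small enough to store.
        q = (np.asarray(patch) * self.bins).astype(int)
        return (tuple(loc), tuple(np.clip(q, 0, self.bins - 1)))

    def memorize(self, loc, patch, output):
        self.table[self._key(loc, patch)] = output

    def infer(self, loc, patch):
        # Inference is a single lookup: no multiplies, no activations.
        return self.table.get(self._key(loc, patch))

mbi = MemorizationTable()
mbi.memorize((3, 5), [0.10, 0.80], output="class_7")
assert mbi.infer((3, 5), [0.11, 0.79]) == "class_7"  # same quantized cell
```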
arXiv Detail & Related papers (2023-07-14T21:01:59Z) - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation [93.88170217725805]
We propose a 3D medical image segmentation approach, named UNETR++, that offers both high-quality segmentation masks and efficiency in terms of parameters, compute cost, and inference speed.
The core of our design is the introduction of a novel efficient paired attention (EPA) block that efficiently learns spatial and channel-wise discriminative features.
Our evaluations on five benchmarks, Synapse, BTCV, ACDC, BraTS, and Decathlon-Lung, reveal the effectiveness of our contributions in terms of both efficiency and accuracy.
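The sketch below is not the exact EPA block, only a generic illustration of the two attention axes such a design pairs: token-to-token (spatial) attention scales with the square of the number of voxel tokens N, while channel-to-channel attention scales only with the square of the channel count C:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def spatial_attention(X):
    """X: (N tokens, C channels). Tokens attend to tokens (N x N map)."""
    A = softmax(X @ X.T / np.sqrt(X.shape[1]))
    return A @ X

def channel_attention(X):
    """Channels attend to channels: the map is C x C, far cheaper than
    N x N when N (flattened 3D voxel tokens) is large."""
    A = softmax(X.T @ X / np.sqrt(X.shape[0]))
    return X @ A

X = np.random.default_rng(0).normal(size=(64, 32))
fused = spatial_attention(X) + channel_attention(X)  # paired branches
```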
arXiv Detail & Related papers (2022-12-08T18:59:57Z) - Single MCMC Chain Parallelisation on Decision Trees [0.9137554315375919]
We propose a method to parallelise a single MCMC decision tree chain on an average laptop or personal computer.
Experiments show an 18-fold reduction in running time, provided that the serial and parallel implementations are statistically identical.
arXiv Detail & Related papers (2022-07-26T07:07:51Z) - Group Fisher Pruning for Practical Network Compression [58.25776612812883]
We present a general channel pruning approach that can be applied to various complicated structures.
We derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels.
Our method can be used to prune any structures including those with coupled channels.
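A hedged sketch of a Fisher-style channel-importance score: attach a multiplicative gate m_c = 1 to each channel and approximate the loss increase from removing the channel by the empirical Fisher information of that gate. The paper's exact metric and its handling of coupled channels are not reproduced here:

```python
import numpy as np

def fisher_channel_importance(acts, grads):
    """acts, grads: (batch, C, H, W) activations of a layer and the
    gradients of the loss w.r.t. those activations. For a per-channel
    gate m_c fixed at 1, dL/dm_c = sum_{h,w} act * grad, and the Fisher
    approximation scores a channel by the average squared gate-gradient
    over the batch."""
    g_gate = (acts * grads).sum(axis=(2, 3))   # (batch, C): dL/dm_c per sample
    return (g_gate ** 2).mean(axis=0)          # (C,) importance per channel

rng = np.random.default_rng(0)
acts = rng.normal(size=(8, 16, 4, 4))
grads = rng.normal(size=(8, 16, 4, 4))
imp = fisher_channel_importance(acts, grads)
prune = np.argsort(imp)[:4]   # drop the 4 least important channels
```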
arXiv Detail & Related papers (2021-08-02T08:21:44Z) - Dynamic Probabilistic Pruning: A general framework for hardware-constrained pruning at different granularities [80.06422693778141]
We propose a flexible new pruning mechanism that facilitates pruning at different granularities (weights, kernels, filters/feature maps).
We refer to this algorithm as Dynamic Probabilistic Pruning (DPP).
We show that DPP achieves competitive compression rates and classification accuracy when pruning common deep learning models trained on different benchmark datasets for image classification.
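One way to realize probabilistic selection of k survivors is the Gumbel-top-k trick sketched below; DPP's differentiable relaxation and hardware-aware constraints are not modeled, so treat this purely as an illustration of stochastic top-k pruning:

```python
import numpy as np

def probabilistic_topk_mask(logits, k, rng):
    """Sample a k-hot pruning mask via the Gumbel-top-k trick: perturb
    learned logits with Gumbel noise and keep the k largest, so which
    elements survive is stochastic but biased toward high-logit entries."""
    g = rng.gumbel(size=logits.shape)
    keep = np.argsort(logits + g)[-k:]
    mask = np.zeros_like(logits)
    mask[keep] = 1.0
    return mask

rng = np.random.default_rng(0)
# Granularity is set by what the logits index: weights, kernels, or filters.
logits = np.array([2.0, 0.1, -1.0, 1.5, 0.0, 3.0])
mask = probabilistic_topk_mask(logits, k=3, rng=rng)
```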
arXiv Detail & Related papers (2021-05-26T17:01:52Z) - Improved Branch and Bound for Neural Network Verification via Lagrangian Decomposition [161.09660864941603]
We improve the scalability of Branch and Bound (BaB) algorithms for formally proving input-output properties of neural networks.
We present a novel activation-based branching strategy and a BaB framework, named Branch and Dual Network Bound (BaDNB).
BaDNB outperforms previous complete verification systems by a large margin, cutting average verification times by factors up to 50 on adversarial properties.
arXiv Detail & Related papers (2021-04-14T09:22:42Z) - PACSET (Packed Serialized Trees): Reducing Inference Latency for Tree Ensemble Deployment [4.314299343332365]
We present methods to serialize and deserialize tree ensembles that optimize inference latency when models are not already loaded into memory.
Our packed serialized trees (PACSET) encode reference locality in the layout of a tree ensemble using principles from external memory algorithms.
The result is that each I/O yields a higher fraction of useful data, leading to a 2-6 times reduction in classification latency for interactive workloads.
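A rough sketch of the locality principle, assuming a simple breadth-first layout: nodes near the root, which every inference path touches, land in the first (hottest) blocks of the serialized file. PACSET's block-aligned, interleaved ensemble layouts are more sophisticated than this:

```python
from collections import deque

class Node:
    def __init__(self, feat=None, thresh=None, left=None, right=None, leaf=None):
        self.feat, self.thresh = feat, thresh
        self.left, self.right, self.leaf = left, right, leaf

def pack_bfs(root):
    """Serialize a binary decision tree breadth-first into a flat array
    of records, so each sequential I/O block holds whole tree levels."""
    order, index = [], {}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        index[id(node)] = len(order)
        order.append(node)
        for child in (node.left, node.right):
            if child is not None:
                queue.append(child)
    # Flat records: (feat, thresh, left_idx, right_idx, leaf_value)
    return [(n.feat, n.thresh,
             index.get(id(n.left), -1), index.get(id(n.right), -1),
             n.leaf) for n in order]

def predict(packed, x):
    i = 0
    while packed[i][4] is None:                 # walk until we hit a leaf
        feat, thresh, left, right, _ = packed[i]
        i = left if x[feat] <= thresh else right
    return packed[i][4]

tree = Node(0, 0.5, Node(leaf="A"), Node(leaf="B"))
packed = pack_bfs(tree)
assert predict(packed, [0.2]) == "A" and predict(packed, [0.9]) == "B"
```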
arXiv Detail & Related papers (2020-11-10T20:32:11Z) - A Vertex Cut based Framework for Load Balancing and Parallelism Optimization in Multi-core Systems [15.913119724815733]
High-level applications, such as machine learning, are evolving from simple multilayer-perceptron models for basic image recognition to much deeper and more complex neural networks for self-driving vehicle control systems.
Parallel programs running on high-performance computers often suffer from data communication bottlenecks, limited memory bandwidth, and synchronization overhead due to irregular critical sections.
We propose a framework to reduce the data communication and improve the scalability and performance of these applications in multi-core systems.
arXiv Detail & Related papers (2020-10-09T07:54:28Z) - Efficient Computation of Expectations under Spanning Tree Distributions [67.71280539312536]
We propose unified algorithms for the important cases of first-order expectations and second-order expectations in edge-factored, non-projective spanning-tree models.
Our algorithms exploit a fundamental connection between gradients and expectations, which allows us to derive efficient algorithms.
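For first-order expectations the gradient connection is concrete: by the Matrix-Tree theorem the partition function Z is a determinant, and an edge's marginal probability is w_e * d(log Z)/d(w_e). A small NumPy sketch for undirected spanning trees follows; the paper's algorithms also cover non-projective dependency trees and second-order expectations:

```python
import numpy as np

def edge_marginals(W):
    """W: symmetric nonnegative edge-weight matrix with zero diagonal.
    Returns P with P[i, j] = Pr[edge (i, j) in a random spanning tree],
    via p(e) = w_e * d(log Z)/d(w_e), where Z = det of the grounded
    Laplacian (Matrix-Tree theorem, vertex 0 grounded)."""
    n = W.shape[0]
    L = np.diag(W.sum(axis=1)) - W        # graph Laplacian
    Linv = np.linalg.inv(L[1:, 1:])       # inverse of the grounded Laplacian
    P = np.zeros_like(W)
    for i in range(n):
        for j in range(i + 1, n):
            if W[i, j] == 0:
                continue
            if i == 0:                    # effective resistance to ground
                r = Linv[j - 1, j - 1]
            else:
                r = (Linv[i - 1, i - 1] + Linv[j - 1, j - 1]
                     - 2 * Linv[i - 1, j - 1])
            P[i, j] = P[j, i] = W[i, j] * r
    return P

# Unweighted triangle: 3 spanning trees, each edge appears in 2 of them.
W = np.array([[0., 1, 1], [1, 0, 1], [1, 1, 0]])
assert np.allclose(edge_marginals(W)[0, 1], 2 / 3)
```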
arXiv Detail & Related papers (2020-08-29T14:58:26Z) - OctSqueeze: Octree-Structured Entropy Model for LiDAR Compression [77.8842824702423]
We present a novel deep compression algorithm to reduce the memory footprint of LiDAR point clouds.
Our method exploits the sparsity and structural redundancy between points to achieve this reduction.
Our algorithm can be used to reduce the onboard and offboard storage of LiDAR points for applications such as self-driving cars.
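The octree stream that such an entropy model compresses can be sketched simply: split the bounding cube recursively and emit one 8-bit occupancy mask per node. OctSqueeze's contribution, a learned probability model for arithmetic coding of these bytes, is not reproduced here:

```python
import numpy as np

def octree_bytes(points, center, half, depth):
    """Recursively emit one 8-bit occupancy mask per octree node.
    Sparse point clouds yield few occupied octants, so this byte stream
    (plus an entropy model over it) is far smaller than raw coordinates."""
    if depth == 0 or len(points) == 0:
        return []
    octant = ((points[:, 0] > center[0]).astype(int)
              | ((points[:, 1] > center[1]).astype(int) << 1)
              | ((points[:, 2] > center[2]).astype(int) << 2))
    mask = 0
    for o in range(8):
        if np.any(octant == o):
            mask |= 1 << o
    stream = [mask]
    for o in range(8):
        sub = points[octant == o]
        if len(sub) == 0:
            continue
        off = half / 2 * np.array([(o & 1) * 2 - 1,
                                   ((o >> 1) & 1) * 2 - 1,
                                   ((o >> 2) & 1) * 2 - 1])
        stream += octree_bytes(sub, center + off, half / 2, depth - 1)
    return stream

pts = np.random.default_rng(0).uniform(-1, 1, size=(100, 3))
stream = octree_bytes(pts, center=np.zeros(3), half=1.0, depth=4)
```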
arXiv Detail & Related papers (2020-05-14T17:48:49Z)
This list is automatically generated from the titles and abstracts of the papers on this site.