FlashAttention-2: Faster Attention with Better Parallelism and Work
Partitioning
- URL: http://arxiv.org/abs/2307.08691v1
- Date: Mon, 17 Jul 2023 17:50:36 GMT
- Title: FlashAttention-2: Faster Attention with Better Parallelism and Work
Partitioning
- Authors: Tri Dao
- Abstract summary: FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving and runtime speedup.
FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s.
We propose FlashAttention-2, with better work partitioning to address these issues.
- Score: 11.508362885430133
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Scaling Transformers to longer sequence lengths has been a major problem in
the last several years, promising to improve performance in language modeling
and high-resolution image understanding, as well as to unlock new applications
in code, audio, and video generation. The attention layer is the main
bottleneck in scaling to longer sequences, as its runtime and memory increase
quadratically in the sequence length. FlashAttention exploits the asymmetric
GPU memory hierarchy to bring significant memory saving (linear instead of
quadratic) and runtime speedup (2-4× compared to optimized baselines),
with no approximation. However, FlashAttention is still not nearly as fast as
optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the
theoretical maximum FLOPs/s. We observe that the inefficiency is due to
suboptimal work partitioning between different thread blocks and warps on the
GPU, causing either low occupancy or unnecessary shared memory reads/writes. We
propose FlashAttention-2, with better work partitioning to address these
issues. In particular, we (1) tweak the algorithm to reduce the number of
non-matmul FLOPs, (2) parallelize the attention computation, even for a single
head, across different thread blocks to increase occupancy, and (3) within each
thread block, distribute the work between warps to reduce communication through
shared memory. These yield around 2× speedup compared to FlashAttention,
reaching 50-73% of the theoretical maximum FLOPs/s on A100 and getting close
to the efficiency of GEMM operations. We empirically validate that when used
end-to-end to train GPT-style models, FlashAttention-2 reaches training speed
of up to 225 TFLOPs/s per A100 GPU (72% model FLOPs utilization).
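The tiling idea behind the abstract can be illustrated with a minimal NumPy sketch. It is not the paper's CUDA kernel: the function names, block sizes, and the single-head, unmasked setting are assumptions made for illustration. It processes attention one tile at a time with an online softmax, keeps the output unnormalized inside the inner loop, and divides by the softmax denominator once per query block, which is one way to cut non-matmul work. The outer loop over query blocks has independent iterations, which is the kind of work that could be spread across thread blocks on a GPU.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: materializes the full (N, N) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v


def tiled_attention(q, k, v, block_q=16, block_k=16):
    """Blockwise attention with an online softmax (single head, no mask).

    Hypothetical helper for illustration; block sizes are arbitrary.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[-1]), dtype=np.float64)
    # Iterations over query blocks are independent of each other.
    for qs in range(0, n, block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = np.full(rows, -np.inf)            # running row maximum
        l = np.zeros(rows)                    # running softmax denominator
        acc = np.zeros((rows, v.shape[-1]))   # unnormalized output
        for ks in range(0, k.shape[0], block_k):
            s = (q_blk @ k[ks:ks + block_k].T) * scale
            m_new = np.maximum(m, s.max(axis=-1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)    # rescales earlier partial sums
            l = l * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[ks:ks + block_k]
            m = m_new
        # Normalize once per query block instead of once per tile.
        out[qs:qs + block_q] = acc / l[:, None]
    return out
```

Only one block_q by block_k score tile is live at a time, so the extra memory grows linearly with sequence length, while the result matches the naive computation up to floating-point error.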
Related papers
- Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss [59.835032408496545]
We propose a tile-based strategy that partitions the contrastive loss calculation into arbitrarily small blocks.
We also introduce a multi-level tiling strategy to leverage the hierarchical structure of distributed systems.
Compared to SOTA memory-efficient solutions, it achieves a two-order-of-magnitude reduction in memory while maintaining comparable speed.
arXiv Detail & Related papers (2024-10-22T17:59:30Z)
- FlashMask: Efficient and Rich Mask Extension of FlashAttention [22.810595298076866]
FlashMask is an extension of FlashAttention that introduces a column-wise sparse representation of attention masks.
By adopting this novel representation, FlashMask achieves linear memory complexity $O(N)$, suitable for modeling long-context sequences.
We evaluate FlashMask's performance in LLM fine-tuning and alignment training, such as SFT, LoRA, DPO, and RM.
arXiv Detail & Related papers (2024-10-02T09:17:26Z)
- vTensor: Flexible Virtual Tensor Management for Efficient LLM Serving [53.972175896814505]
Large Language Models (LLMs) are widely used across various domains, processing millions of daily requests.
arXiv Detail & Related papers (2024-07-22T14:37:58Z)
- Efficient Video Object Segmentation via Modulated Cross-Attention Memory [123.12273176475863]
We propose a transformer-based approach, named MAVOS, to model temporal smoothness without requiring frequent memory expansion.
Our MAVOS achieves a J&F score of 63.3% while operating at 37 frames per second (FPS) on a single V100 GPU.
arXiv Detail & Related papers (2024-03-26T17:59:58Z)
- A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library [0.7366405857677227]
We provide an optimized implementation of the forward pass of FlashAttention-2 as a custom fused kernel targeting NVIDIA Hopper architecture.
We observe 20-50% higher FLOPs/s over a version of FlashAttention-2 optimized for last-generation NVIDIA Ampere architecture.
arXiv Detail & Related papers (2023-12-19T07:56:25Z)
- DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training [82.06732962485754]
FlashAttention effectively reduces the quadratic peak memory usage to linear in training transformer-based large language models (LLMs) on a single GPU.
We introduce DISTFLASHATTN, a memory-efficient attention mechanism optimized for long-context LLM training.
It achieves 1.67x and 1.26-1.88x speedup compared to recent Ring Attention and DeepSpeed-Ulysses, respectively.
arXiv Detail & Related papers (2023-10-05T03:47:57Z)
- Simple Hardware-Efficient Long Convolutions for Sequence Modeling [18.3719016967593]
State space models (SSMs) have high performance on long sequence modeling.
We study whether a simple alternative can match SSMs in performance and efficiency.
We develop FlashButterfly, an IO-aware algorithm to improve the runtime performance of long convolutions.
arXiv Detail & Related papers (2023-02-13T19:19:23Z)
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness [80.3586155104237]
FlashAttention is an IO-aware exact attention algorithm for Transformers.
It reduces the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.
FlashAttention and block-sparse FlashAttention enable longer context in Transformers.
arXiv Detail & Related papers (2022-05-27T17:53:09Z)
- Efficient Video Semantic Segmentation with Labels Propagation and Refinement [138.55845680523908]
This paper tackles the problem of real-time semantic segmentation of high definition videos using a hybrid GPU / CPU approach.
We propose an Efficient Video Segmentation (EVS) pipeline that combines: (i) On the CPU, a very fast optical flow method that is used to exploit the temporal aspect of the video and propagate semantic information from one frame to the next.
On the popular Cityscapes dataset with high resolution frames (2048 x 1024), the proposed operating points range from 80 to 1000 Hz on a single GPU and CPU.
arXiv Detail & Related papers (2019-12-26T11:45:15Z)