FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- URL: http://arxiv.org/abs/2307.08691v1
- Date: Mon, 17 Jul 2023 17:50:36 GMT
- Title: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Authors: Tri Dao
- Abstract summary: FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory savings and a runtime speedup.
Even so, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s.
We propose FlashAttention-2, with better work partitioning to address these issues.
- Score: 11.508362885430133
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Scaling Transformers to longer sequence lengths has been a major problem in
the last several years, promising to improve performance in language modeling
and high-resolution image understanding, as well as to unlock new applications
in code, audio, and video generation. The attention layer is the main
bottleneck in scaling to longer sequences, as its runtime and memory increase
quadratically in the sequence length. FlashAttention exploits the asymmetric
GPU memory hierarchy to bring significant memory saving (linear instead of
quadratic) and runtime speedup (2-4$\times$ compared to optimized baselines),
with no approximation. However, FlashAttention is still not nearly as fast as
optimized matrix-multiply (GEMM) operations, reaching only 25-40\% of the
theoretical maximum FLOPs/s. We observe that the inefficiency is due to
suboptimal work partitioning between different thread blocks and warps on the
GPU, causing either low occupancy or unnecessary shared memory reads/writes. We
propose FlashAttention-2, with better work partitioning to address these
issues. In particular, we (1) tweak the algorithm to reduce the number of
non-matmul FLOPs, (2) parallelize the attention computation, even for a single
head, across different thread blocks to increase occupancy, and (3) within each
thread block, distribute the work between warps to reduce communication through
shared memory. These yield around 2$\times$ speedup compared to FlashAttention,
reaching 50-73\% of the theoretical maximum FLOPs/s on A100 and getting close
to the efficiency of GEMM operations. We empirically validate that when used
end-to-end to train GPT-style models, FlashAttention-2 reaches training speed
of up to 225 TFLOPs/s per A100 GPU (72\% model FLOPs utilization).
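To make the algorithmic changes concrete, here is a minimal single-head NumPy sketch of tiled attention with the online-softmax recurrence. It is a sketch under stated assumptions, not the CUDA implementation: the tile size, variable names, and the function `attention_forward_tiled` are illustrative. The FlashAttention-2 tweak it mirrors is deferring the division by the softmax normalizer until after the loop over key/value tiles, which removes non-matmul FLOPs from the inner iteration.

```python
import numpy as np

def attention_forward_tiled(Q, K, V, block=128):
    """Single-head tiled attention with the online-softmax recurrence.

    Illustrative sketch only: real FlashAttention-2 runs each query tile
    in a CUDA thread block, with K/V tiles streamed through shared memory.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    for i0 in range(0, N, block):            # query tiles: independent work,
        q = Q[i0:i0 + block]                 # parallelized across thread blocks
        m = np.full(len(q), -np.inf)         # running row-wise max
        l = np.zeros(len(q))                 # running softmax denominator
        acc = np.zeros_like(q)               # UNNORMALIZED output accumulator
        for j0 in range(0, N, block):        # stream K/V tiles through on-chip memory
            s = (q @ K[j0:j0 + block].T) * scale
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)        # rescales stale partial sums
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ V[j0:j0 + block]
            m = m_new
        O[i0:i0 + block] = acc / l[:, None]  # divide ONCE, after the tile loop
    return O
```

Comparing this against a direct softmax(QK^T/sqrt(d))V on random inputs confirms the recurrence is exact, not approximate; the paper's speedups come from how this same loop structure is mapped onto thread blocks, warps, and shared memory.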
Related papers
- vTensor: Flexible Virtual Tensor Management for Efficient LLM Serving [53.972175896814505]
Large Language Models (LLMs) are widely used across various domains, processing millions of daily requests.
arXiv Detail & Related papers (2024-07-22T14:37:58Z)
- Efficient Video Object Segmentation via Modulated Cross-Attention Memory [123.12273176475863]
We propose a transformer-based approach, named MAVOS, to model temporal smoothness without requiring frequent memory expansion.
Our MAVOS achieves a J&F score of 63.3% while operating at 37 frames per second (FPS) on a single V100 GPU.
arXiv Detail & Related papers (2024-03-26T17:59:58Z)
- A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library [0.7366405857677227]
We provide an optimized implementation of the forward pass of FlashAttention-2 as a custom fused kernel targeting NVIDIA Hopper architecture.
We observe 20-50% higher FLOPs/s over a version of FlashAttention-2 optimized for the previous-generation NVIDIA Ampere architecture.
arXiv Detail & Related papers (2023-12-19T07:56:25Z)
- FlashDecoding++: Faster Large Language Model Inference on GPUs [16.289377349637995]
We present FlashDecoding++, a fast inference engine supporting mainstream Large Language Model (LLM) inference.
FlashDecoding++ introduces a unified max value technique for different partial softmax computations, letting partial results be combined without per-chunk rescaling (a sketch follows this entry).
It achieves up to 4.86x and 2.18x speedups on NVIDIA and AMD GPUs, respectively.
arXiv Detail & Related papers (2023-11-02T14:57:03Z)
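A minimal NumPy sketch of the unified-max idea, assuming a fixed value phi is chosen a priori (e.g., from score statistics); the function name and chunking are illustrative, not FlashDecoding++'s kernel:

```python
import numpy as np

def softmax_unified_max(score_chunks, phi):
    """Softmax over concatenated score chunks using one fixed max value phi.

    Hypothetical sketch: with a shared phi, each chunk's exponentials and
    partial sum are computed independently, with no cross-chunk rescaling
    or synchronized max reduction.
    """
    parts = [np.exp(c - phi) for c in score_chunks]  # fully parallel per chunk
    denom = sum(p.sum() for p in parts)              # one final reduction
    return np.concatenate(parts) / denom
```

Mathematically, e^(x-phi) / sum(e^(x-phi)) equals the softmax for any phi; the practical constraint is picking phi close enough to typical score maxima that the exponentials neither overflow nor vanish, which is what makes a statistically chosen unified value workable.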
- DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training [82.06732962485754]
FlashAttention effectively reduces the quadratic peak memory usage to linear when training transformer-based large language models (LLMs) on a single GPU.
We introduce DISTFLASHATTN, a memory-efficient attention mechanism optimized for long-context LLMs training.
It achieves 1.67x and 1.26-1.88x speedups compared with the recent Ring Attention and DeepSpeed-Ulysses, respectively.
arXiv Detail & Related papers (2023-10-05T03:47:57Z)
- Simple Hardware-Efficient Long Convolutions for Sequence Modeling [18.3719016967593]
State space models (SSMs) have high performance on long sequence modeling.
We study whether a simple alternative can match SSMs in performance and efficiency.
We develop FlashButterfly, an IO-aware algorithm that improves the runtime performance of long convolutions (a baseline sketch follows this entry).
arXiv Detail & Related papers (2023-02-13T19:19:23Z)
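For context, a minimal NumPy version of the FFT-based long convolution that IO-aware algorithms in this vein accelerate; the function name and padding scheme are illustrative assumptions:

```python
import numpy as np

def long_conv_fft(u, k):
    """Causal 1-D convolution of a sequence u with a kernel k via FFT.

    The standard O(N log N) baseline: convolution in time is pointwise
    multiplication in frequency.
    """
    N = u.shape[-1]
    L = 2 * N                                # zero-pad to avoid circular wrap
    y = np.fft.irfft(np.fft.rfft(u, n=L) * np.fft.rfft(k, n=L), n=L)
    return y[..., :N]                        # keep the first N (causal) outputs
```

The arithmetic is cheap relative to a direct O(N^2) convolution; the remaining cost is memory movement across the FFT stages, which is what an IO-aware reformulation targets.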
- Adaptable Butterfly Accelerator for Attention-based NNs via Hardware and Algorithm Co-design [66.39546326221176]
Attention-based neural networks have become pervasive in many AI tasks.
The use of the attention mechanism and feed-forward network (FFN) demands excessive computational and memory resources.
This paper proposes a hardware-friendly variant that adopts a unified butterfly sparsity pattern to approximate both the attention mechanism and the FFNs (the butterfly structure is sketched after this entry).
arXiv Detail & Related papers (2022-09-20T09:28:26Z)
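To show what butterfly sparsity buys, a hedged NumPy sketch of a butterfly-factored matrix-vector product; the layer layout and function name are illustrative assumptions, not the paper's hardware mapping:

```python
import numpy as np

def butterfly_matvec(layers, x):
    """Apply a butterfly-factored linear map to x (length n = 2**L).

    layers[s] has shape (n//2, 2, 2): one 2x2 mixing block per index pair
    at stage s, so a dense n x n matrix is replaced by log2(n) sparse
    stages costing O(n) each.
    """
    n = x.size
    y = x.astype(float)
    for s, layer in enumerate(layers):
        b = 2 ** (s + 1)                       # block size; pairs sit b//2 apart
        yb = y.reshape(n // b, 2, b // 2)      # (blocks, pair element, offset)
        W = layer.reshape(n // b, b // 2, 2, 2)
        top = W[..., 0, 0] * yb[:, 0, :] + W[..., 0, 1] * yb[:, 1, :]
        bot = W[..., 1, 0] * yb[:, 0, :] + W[..., 1, 1] * yb[:, 1, :]
        y = np.stack([top, bot], axis=1).reshape(n)
    return y
```

The FFT uses exactly this data flow with fixed twiddle factors; learning the 2x2 blocks instead gives one regular sparsity structure that can stand in for both attention and FFN weight matrices, which is what makes it hardware-friendly.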
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness [80.3586155104237]
FlashAttention is an IO-aware exact attention algorithm for Transformers.
It reduces the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.
FlashAttention and block-sparse FlashAttention enable longer context in Transformers.
arXiv Detail & Related papers (2022-05-27T17:53:09Z)
- Efficient Video Semantic Segmentation with Labels Propagation and Refinement [138.55845680523908]
This paper tackles the problem of real-time semantic segmentation of high definition videos using a hybrid GPU / CPU approach.
We propose an Efficient Video Segmentation (EVS) pipeline that combines: (i) on the CPU, a very fast optical flow method that exploits the temporal aspect of the video and propagates semantic information from one frame to the next (a warping sketch follows this entry).
On the popular Cityscapes dataset with high resolution frames (2048 x 1024), the proposed operating points range from 80 to 1000 Hz on a single GPU and CPU.
arXiv Detail & Related papers (2019-12-26T11:45:15Z)
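To illustrate the propagation step, a hedged sketch of flow-based label warping; the function, the flow convention, and nearest-neighbor sampling are assumptions, not the paper's exact method:

```python
import numpy as np

def propagate_labels(prev_labels, flow):
    """Warp a per-pixel label map from frame t-1 onto frame t.

    Nearest-neighbor backward warping: assumes flow[y, x] = (dy, dx)
    points from each pixel of frame t back to its source in frame t-1.
    """
    H, W = prev_labels.shape
    ys, xs = np.mgrid[0:H, 0:W]
    src_y = np.clip(np.rint(ys + flow[..., 0]).astype(int), 0, H - 1)
    src_x = np.clip(np.rint(xs + flow[..., 1]).astype(int), 0, W - 1)
    return prev_labels[src_y, src_x]
```

Warping labels is far cheaper than segmenting from scratch, which is why it can run on the CPU while the GPU handles the frames that need full inference.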
This list is automatically generated from the titles and abstracts of the papers on this site.
This site does not guarantee the quality of its content (including all information) and is not responsible for any consequences of its use.