FlashMask: Efficient and Rich Mask Extension of FlashAttention
- URL: http://arxiv.org/abs/2410.01359v1
- Date: Wed, 2 Oct 2024 09:17:26 GMT
- Title: FlashMask: Efficient and Rich Mask Extension of FlashAttention
- Authors: Guoxia Wang, Jinle Zeng, Xiyuan Xiao, Siming Wu, Jiabin Yang, Lujing Zheng, Zeyu Chen, Jiang Bian, Dianhai Yu, Haifeng Wang
- Abstract summary: FlashMask is an extension of FlashAttention that introduces a column-wise sparse representation of attention masks.
By adopting this novel representation, FlashMask achieves linear memory complexity $O(N)$, suitable for modeling long-context sequences.
We evaluate FlashMask's performance in fine-tuning and alignment training of LLMs such as SFT, LoRA, DPO, and RM.
- Score: 22.810595298076866
- License: http://creativecommons.org/licenses/by-nc-sa/4.0/
- Abstract: The computational and memory demands of vanilla attention scale quadratically with the sequence length $N$, posing significant challenges for processing long sequences in Transformer models. FlashAttention alleviates these challenges by eliminating the $O(N^2)$ memory dependency and reducing attention latency through IO-aware memory optimizations. However, its native support for certain attention mask types is limited, and it does not inherently accommodate more complex masking requirements. Previous approaches resort to using dense masks with $O(N^2)$ memory complexity, leading to inefficiencies. In this paper, we propose FlashMask, an extension of FlashAttention that introduces a column-wise sparse representation of attention masks. This approach efficiently represents a wide range of mask types and facilitates the development of optimized kernel implementations. By adopting this novel representation, FlashMask achieves linear memory complexity $O(N)$, suitable for modeling long-context sequences. Moreover, this representation enables kernel optimizations that eliminate unnecessary computations by leveraging sparsity in the attention mask, without sacrificing computational accuracy, resulting in higher computational efficiency. We evaluate FlashMask's performance in fine-tuning and alignment training of LLMs such as SFT, LoRA, DPO, and RM. FlashMask achieves significant throughput improvements, with end-to-end speedups ranging from 1.65x to 3.22x compared to existing FlashAttention dense method. Additionally, our kernel-level comparisons demonstrate that FlashMask surpasses the latest counterpart, FlexAttention, by 12.1% to 60.7% in terms of kernel TFLOPs/s, achieving 37.8% to 62.3% of the theoretical maximum FLOPs/s on the A100 GPU. The code is open-sourced on PaddlePaddle and integrated into PaddleNLP, supporting models with over 100 billion parameters for contexts up to 128K tokens.
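The abstract's central idea, storing the mask per key column instead of as a dense $N \times N$ matrix, can be illustrated with a small sketch. The layout below is an assumption made for illustration only: each key column `j` stores a single half-open interval `[start[j], end[j])` of query rows that are masked out, on top of ordinary causal masking. The function names are hypothetical, and the actual FlashMask format and kernels may differ.

```python
# Hypothetical sketch of a column-wise interval encoding of an attention mask.
# Assumption: for key column j, query rows in [start[j], end[j]) are masked out,
# in addition to standard causal masking. This is not FlashMask's exact format.

def causal_document_mask(doc_lens):
    """Build the interval encoding for causal attention within packed documents."""
    n = sum(doc_lens)
    start, end = [], []
    doc_end = 0
    for length in doc_lens:
        doc_end += length
        for _ in range(length):
            # Query rows that belong to later documents must not see this column.
            start.append(doc_end)
            end.append(n)
    return start, end


def is_visible(i, j, start, end):
    """Return True if query row i may attend to key column j."""
    causal = j <= i
    return causal and not (start[j] <= i < end[j])


def to_dense(start, end):
    """Expand to a dense N x N boolean mask (for checking only; costs O(N^2))."""
    n = len(start)
    return [[is_visible(i, j, start, end) for j in range(n)] for i in range(n)]


def can_skip_tile(row_lo, row_hi, col_lo, col_hi, start, end):
    """Conservative test: True only if the whole tile is certainly masked.

    The tile covers rows [row_lo, row_hi) and columns [col_lo, col_hi).
    """
    if row_hi <= col_lo:
        # Every row index is smaller than every column index: causally masked.
        return True
    # Every column's masked interval must cover all rows of the tile.
    return (max(start[col_lo:col_hi]) <= row_lo
            and row_hi <= min(end[col_lo:col_hi]))


if __name__ == "__main__":
    start, end = causal_document_mask([2, 3])  # two packed documents, N = 5
    for row in to_dense(start, end):
        print("".join("1" if visible else "." for visible in row))
    print("stored integers:", len(start) + len(end), "vs dense entries:", len(start) ** 2)
```

Under this assumed layout the mask costs two integers per column, which is the kind of $O(N)$ storage the abstract describes, and a tile-level check such as `can_skip_tile` shows how a kernel could skip blocks that are entirely masked without inspecting individual entries.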
Related papers
- Efficiently Dispatching Flash Attention For Partially Filled Attention Masks [29.36452085947087]
Transformers are widely used across various applications, many of which yield sparse or partially filled attention matrices.
We introduce Binary Block Masking, a highly efficient modification that enhances Flash Attention by making it mask-aware.
Our experiments on attention masks derived from real-world scenarios demonstrate up to a 9x runtime improvement.
arXiv Detail & Related papers (2024-09-23T15:11:07Z)
- Mask Propagation for Efficient Video Semantic Segmentation [63.09523058489429]
Video Semantic Segmentation (VSS) involves assigning a semantic label to each pixel in a video sequence.
We propose an efficient mask propagation framework for VSS, called MPVSS.
Our framework reduces up to 4x FLOPs compared to the per-frame Mask2Former with only up to 2% mIoU degradation on the Cityscapes validation set.
arXiv Detail & Related papers (2023-10-29T09:55:28Z)
- DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training [82.06732962485754]
FlashAttention effectively reduces the quadratic peak memory usage to linear in training transformer-based large language models (LLMs) on a single GPU.
We introduce DISTFLASHATTN, a memory-efficient attention mechanism optimized for long-context LLMs training.
It achieves 1.67x and 1.26 - 1.88x speedup compared to recent Ring Attention and DeepSpeed-Ulysses.
arXiv Detail & Related papers (2023-10-05T03:47:57Z)
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning [11.508362885430133]
FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving and runtime speedup.
FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s.
We propose FlashAttention-2, with better work partitioning to address these issues.
arXiv Detail & Related papers (2023-07-17T17:50:36Z)
- Faster Causal Attention Over Large Sequences Through Sparse Flash Attention [45.18552512844457]
We extend FlashAttention to accommodate a large class of attention sparsity patterns.
We increase the training speed of a transformer language model by $2.0\times$ and $3.3\times$ for sequences of respectively $8k$ and $16k$ tokens.
arXiv Detail & Related papers (2023-06-01T21:33:59Z)
- DynaMask: Dynamic Mask Selection for Instance Segmentation [21.50329070835023]
We develop a Mask Switch Module (MSM) with negligible computational cost to select the most suitable mask resolution for each instance.
The proposed method, namely DynaMask, brings consistent and noticeable performance improvements over other state-of-the-arts at a moderate computation overhead.
arXiv Detail & Related papers (2023-03-14T13:01:25Z)
- Adaptable Butterfly Accelerator for Attention-based NNs via Hardware and Algorithm Co-design [66.39546326221176]
Attention-based neural networks have become pervasive in many AI tasks.
The use of the attention mechanism and feed-forward network (FFN) demands excessive computational and memory resources.
This paper proposes a hardware-friendly variant that adopts a unified butterfly sparsity pattern to approximate both the attention mechanism and the FFNs.
arXiv Detail & Related papers (2022-09-20T09:28:26Z)
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness [80.3586155104237]
FlashAttention is an IO-aware exact attention algorithm for Transformers.
It reduces the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.
FlashAttention and block-sparse FlashAttention enable longer context in Transformers.
arXiv Detail & Related papers (2022-05-27T17:53:09Z)
- Mask Transfiner for High-Quality Instance Segmentation [95.74244714914052]
We present Mask Transfiner for high-quality and efficient instance segmentation.
Our approach only processes detected error-prone tree nodes and self-corrects their errors in parallel.
Our code and trained models will be available at http://vis.xyz/pub/transfiner.
arXiv Detail & Related papers (2021-11-26T18:58:22Z)
- BlendMask: Top-Down Meets Bottom-Up for Instance Segmentation [103.74690082121079]
In this work, we achieve improved mask prediction by effectively combining instance-level information with semantic information with lower-level fine-granularity.
Our main contribution is a blender module which draws inspiration from both top-down and bottom-up instance segmentation approaches.
BlendMask can effectively predict dense per-pixel position-sensitive instance features with very few channels, and learn attention maps for each instance with merely one convolution layer.
arXiv Detail & Related papers (2020-01-02T03:30:17Z)