Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching
- URL: http://arxiv.org/abs/2406.01733v2
- Date: Sat, 16 Nov 2024 07:43:28 GMT
- Title: Learning-to-Cache: Accelerating Diffusion Transformer via Layer Caching
- Authors: Xinyin Ma, Gongfan Fang, Michael Bi Mi, Xinchao Wang
- Abstract summary: We make an interesting and somehow surprising observation: the computation of a large proportion of layers in the diffusion transformer, through a caching mechanism, can be readily removed even without updating the model parameters.
We introduce a novel scheme, named Learning-to-Cache (L2C), that learns to conduct caching in a dynamic manner for diffusion transformers.
Experimental results show that L2C largely outperforms samplers such as DDIM and DPM-Solver, alongside prior cache-based methods at the same inference speed.
- Score: 56.286064975443026
- License:
- Abstract: Diffusion Transformers have recently demonstrated unprecedented generative capabilities for various tasks. The encouraging results, however, come with the cost of slow inference, since each denoising step requires inference on a transformer model with a large scale of parameters. In this study, we make an interesting and somehow surprising observation: the computation of a large proportion of layers in the diffusion transformer, through introducing a caching mechanism, can be readily removed even without updating the model parameters. In the case of U-ViT-H/2, for example, we may remove up to 93.68% of the computation in the cache steps (46.84% for all steps), with less than 0.01 drop in FID. To achieve this, we introduce a novel scheme, named Learning-to-Cache (L2C), that learns to conduct caching in a dynamic manner for diffusion transformers. Specifically, by leveraging the identical structure of layers in transformers and the sequential nature of diffusion, we explore redundant computations between timesteps by treating each layer as the fundamental unit for caching. To address the challenge of the exponential search space in deep models for identifying layers to cache and remove, we propose a novel differentiable optimization objective. An input-invariant yet timestep-variant router is then optimized, which can finally produce a static computation graph. Experimental results show that L2C largely outperforms samplers such as DDIM and DPM-Solver, alongside prior cache-based methods at the same inference speed. Code is available at https://github.com/horseee/learning-to-cache
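The layer-caching idea described in the abstract can be illustrated with a toy sketch. This is not the authors' implementation (see their repository for that): the scalar "layers", the `denoise_with_layer_cache` and `skipped_fraction` helpers, the update rule, and the hand-written schedule are all assumptions made for illustration. In L2C the per-timestep schedule is learned by the router; here it is simply fixed by hand to show how a static, input-invariant mask lets cached layer outputs from an earlier step replace recomputation.

```python
from typing import Callable, List, Optional, Tuple

# Toy stand-in for a transformer block: maps a hidden value to a residual update.
Layer = Callable[[float], float]


def denoise_with_layer_cache(
    layers: List[Layer],
    schedule: List[List[bool]],
    x: float,
    step_size: float = 0.1,
) -> Tuple[float, int]:
    """Run a toy denoising loop with per-layer caching.

    schedule[s][i] is True when layer i is recomputed at step s and False
    when its most recently computed output is reused. The schedule depends
    only on the step index, never on the input (a hypothetical, hand-set
    analogue of a timestep-variant, input-invariant router).
    """
    cache: List[Optional[float]] = [None] * len(layers)
    layer_calls = 0
    for step_mask in schedule:
        h = x
        for i, layer in enumerate(layers):
            # Recompute when the schedule says so, or when nothing is cached yet.
            if step_mask[i] or cache[i] is None:
                cache[i] = layer(h)
                layer_calls += 1
            h = h + cache[i]  # residual connection using fresh or cached output
        x = x - step_size * h  # toy sampler update, an assumption of this sketch
    return x, layer_calls


def skipped_fraction(schedule: List[List[bool]]) -> float:
    """Fraction of layer evaluations the schedule skips across all steps."""
    total = sum(len(mask) for mask in schedule)
    skipped = sum(1 for mask in schedule for keep in mask if not keep)
    return skipped / total


# Four toy layers; every second step is a cache step that recomputes only layer 0.
layers = [lambda h, k=k: 0.1 * (k + 1) * h for k in range(4)]
full_step = [True, True, True, True]
cache_step = [True, False, False, False]
schedule = [full_step, cache_step, full_step, cache_step]

x_cached, calls = denoise_with_layer_cache(layers, schedule, x=1.0)
x_full, full_calls = denoise_with_layer_cache(layers, [full_step] * 4, x=1.0)
print(calls, full_calls, skipped_fraction(schedule))
```

In this toy schedule 75% of the layers are skipped in each cache step, and because cache steps are half of all steps, 37.5% of the total layer computation is skipped. The same halving relates the two figures quoted in the abstract for U-ViT-H/2: 93.68% within cache steps corresponds to 46.84% over all steps.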
Related papers
- SmoothCache: A Universal Inference Acceleration Technique for Diffusion Transformers [4.7170474122879575]
Diffusion Transformers (DiT) have emerged as powerful generative models for various tasks, including image, video, and speech synthesis.
We introduce SmoothCache, a model-agnostic inference acceleration technique for DiT architectures.
Our experiments demonstrate that SmoothCache achieves 8% to 71% speed up while maintaining or even improving generation quality across diverse modalities.
arXiv Detail & Related papers (2024-11-15T16:24:02Z)
- Dynamic Diffusion Transformer [67.13876021157887]
Diffusion Transformer (DiT) has demonstrated superior performance but suffers from substantial computational costs.
We propose Dynamic Diffusion Transformer (DyDiT), an architecture that dynamically adjusts its computation along both timestep and spatial dimensions during generation.
With 3% additional fine-tuning, our method reduces the FLOPs of DiT-XL by 51%, accelerates generation by 1.73×, and achieves a competitive FID score of 2.07 on ImageNet.
arXiv Detail & Related papers (2024-10-04T14:14:28Z)
- Token Caching for Diffusion Transformer Acceleration [30.437462937127773]
TokenCache is a novel post-training acceleration method for diffusion transformers.
It reduces redundant computations among tokens across inference steps.
We show that TokenCache achieves an effective trade-off between generation quality and inference speed for diffusion transformers.
arXiv Detail & Related papers (2024-09-27T08:05:34Z)
- FORA: Fast-Forward Caching in Diffusion Transformer Acceleration [39.51519525071639]
Diffusion transformers (DiT) have become the de facto choice for generating high-quality images and videos.
Fast-FORward CAching (FORA) is designed to accelerate DiT by exploiting the repetitive nature of the diffusion process.
arXiv Detail & Related papers (2024-07-01T16:14:37Z)
- Cache Me if You Can: Accelerating Diffusion Models through Block Caching [67.54820800003375]
A large image-to-image network has to be applied many times to iteratively refine an image from random noise.
We investigate the behavior of the layers within the network and find that 1) the layers' output changes smoothly over time, 2) the layers show distinct patterns of change, and 3) the change from step to step is often very small.
We propose a technique to automatically determine caching schedules based on each block's changes over timesteps.
arXiv Detail & Related papers (2023-12-06T00:51:38Z)
- DeepCache: Accelerating Diffusion Models for Free [65.02607075556742]
DeepCache is a training-free paradigm that accelerates diffusion models from the perspective of model architecture.
DeepCache capitalizes on the inherent temporal redundancy observed in the sequential denoising steps of diffusion models.
Under the same throughput, DeepCache effectively achieves comparable or even marginally improved results with DDIM or PLMS.
arXiv Detail & Related papers (2023-12-01T17:01:06Z)
- Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model [89.8764435351222]
We propose a new family of unbiased estimators called WTA-CRS, for matrix production with reduced variance.
Our work provides both theoretical and experimental evidence that, in the context of tuning transformers, our proposed estimators exhibit lower variance compared to existing ones.
arXiv Detail & Related papers (2023-05-24T15:52:08Z)
- Layer Pruning on Demand with Intermediate CTC [50.509073206630994]
We present a training and pruning method for ASR based on connectionist temporal classification (CTC).
We show that a Transformer-CTC model can be pruned to various depths on demand, improving real-time factor from 0.005 to 0.002 on GPU.
arXiv Detail & Related papers (2021-06-17T02:40:18Z)