Provably learning a multi-head attention layer
- URL: http://arxiv.org/abs/2402.04084v1
- Date: Tue, 6 Feb 2024 15:39:09 GMT
- Title: Provably learning a multi-head attention layer
- Authors: Sitan Chen, Yuanzhi Li
- Abstract summary: The multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models.
In this work, we initiate the study of provably learning a multi-head attention layer from random examples.
We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable.
- Score: 55.2904547651831
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: The multi-head attention layer is one of the key components of the
transformer architecture that sets it apart from traditional feed-forward
models. Given a sequence length $k$, attention matrices
$\mathbf{\Theta}_1,\ldots,\mathbf{\Theta}_m\in\mathbb{R}^{d\times d}$, and
projection matrices $\mathbf{W}_1,\ldots,\mathbf{W}_m\in\mathbb{R}^{d\times d}$,
the corresponding multi-head attention layer
$F: \mathbb{R}^{k\times d}\to\mathbb{R}^{k\times d}$ transforms length-$k$
sequences of $d$-dimensional tokens $\mathbf{X}\in\mathbb{R}^{k\times d}$ via
$F(\mathbf{X}) \triangleq \sum^m_{i=1}\mathrm{softmax}(\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top)\mathbf{X}\mathbf{W}_i$.
In this work, we initiate the study of provably learning a multi-head attention
layer from random examples and give the first nontrivial upper and lower bounds
for this problem:
- Provided $\{\mathbf{W}_i, \mathbf{\Theta}_i\}$ satisfy certain
non-degeneracy conditions, we give a $(dk)^{O(m^3)}$-time algorithm that learns
$F$ to small error given random labeled examples drawn uniformly from
$\{\pm 1\}^{k\times d}$.
- We prove computational lower bounds showing that in the worst case,
exponential dependence on $m$ is unavoidable.
We focus on Boolean $\mathbf{X}$ to mimic the discrete nature of tokens in
large language models, though our techniques naturally extend to standard
continuous settings, e.g. Gaussian. Our algorithm, which is centered around
using examples to sculpt a convex body containing the unknown parameters, is a
significant departure from existing provable algorithms for learning
feedforward networks, which predominantly exploit algebraic and rotation
invariance properties of the Gaussian distribution. In contrast, our analysis
is more flexible as it primarily relies on various upper and lower tail bounds
for the input distribution and "slices" thereof.
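To make the definition of $F$ concrete, here is a minimal plain-Python sketch, not the authors' code. It evaluates $F(\mathbf{X})=\sum_{i=1}^m \mathrm{softmax}(\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top)\mathbf{X}\mathbf{W}_i$ on inputs drawn uniformly from $\{\pm 1\}^{k\times d}$, which is the example distribution in the learning setup. Several choices here are assumptions: softmax is applied row-wise (the usual convention), "error" is measured as mean squared entrywise error, and the function names (`multi_head_attention`, `random_boolean_sequence`, `empirical_error`) are hypothetical.

```python
import math
import random


def matmul(A, B):
    """Plain-Python matrix product A @ B (matrices as lists of rows)."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def row_softmax(S):
    """Softmax over each row; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in S:
        top = max(row)
        exps = [math.exp(s - top) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def multi_head_attention(X, thetas, Ws):
    """F(X) = sum_i softmax(X Theta_i X^T) X W_i for a k x d input X (hypothetical helper)."""
    k, d = len(X), len(X[0])
    XT = transpose(X)
    out = [[0.0] * d for _ in range(k)]
    for Theta, W in zip(thetas, Ws):
        attn = row_softmax(matmul(matmul(X, Theta), XT))  # k x k attention pattern of head i
        head = matmul(attn, matmul(X, W))                 # k x d output of head i
        out = [[o + h for o, h in zip(ro, rh)] for ro, rh in zip(out, head)]
    return out


def random_boolean_sequence(k, d, rng):
    """Draw X uniformly from {-1, +1}^{k x d}."""
    return [[rng.choice((-1.0, 1.0)) for _ in range(d)] for _ in range(k)]


def empirical_error(f_target, f_hyp, k, d, n, rng):
    """Mean squared entrywise gap between two layers on n fresh Boolean inputs (assumed error measure)."""
    total = 0.0
    for _ in range(n):
        X = random_boolean_sequence(k, d, rng)
        Y, Y_hat = f_target(X), f_hyp(X)
        total += sum((y - yh) ** 2
                     for ry, ryh in zip(Y, Y_hat) for y, yh in zip(ry, ryh))
    return total / (n * k * d)


# Usage: a random 2-head layer on length-4 sequences of 3-dimensional tokens.
rng = random.Random(0)
k, d, m = 4, 3, 2
thetas = [[[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(m)]
Ws = [[[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(m)]


def target(X):
    return multi_head_attention(X, thetas, Ws)


Y = target(random_boolean_sequence(k, d, rng))
print(len(Y), len(Y[0]))                               # 4 3
print(empirical_error(target, target, k, d, 10, rng))  # 0.0
```

Plain lists keep the sketch dependency-free; a practical version would use an array library. The sketch only evaluates the target layer and a squared-error proxy for "small error". It does not implement the paper's convex-body learning algorithm.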
Related papers
- The Communication Complexity of Approximating Matrix Rank [50.6867896228563]
We show that this problem has randomized communication complexity $\Omega(\frac{1}{k}\cdot n^2\log|\mathbb{F}|)$.
As an application, we obtain an $\Omega(\frac{1}{k}\cdot n^2\log|\mathbb{F}|)$ space lower bound for any streaming algorithm with $k$ passes.
Date: 2024-10-26
- Provable Acceleration of Nesterov's Accelerated Gradient for Rectangular Matrix Factorization and Linear Neural Networks [46.04785603483612]
We prove that Nesterov's accelerated gradient attains a complexity of $O(\kappa\log\frac{1}{\epsilon})$.
In particular, we prove that NAG can also attain an accelerated linear convergence rate.
Date: 2024-10-12
- In-depth Analysis of Low-rank Matrix Factorisation in a Federated Setting [21.002519159190538]
We analyze a distributed algorithm to compute a low-rank matrix factorization on $N$ clients.
We obtain a global $\mathbf{V}\in\mathbb{R}^{d\times r}$ common to all clients and a local $\mathbf{U}_i\in\mathbb{R}^{n_i\times r}$.
Date: 2024-09-13
- Optimal Estimator for Linear Regression with Shuffled Labels [17.99906229036223]
This paper considers the task of linear regression with shuffled labels.
The matrices involved are $\mathbf{Y}\in\mathbb{R}^{n\times m}$, $\mathbf{\Pi}\in\mathbb{R}^{n\times p}$, $\mathbf{B}\in\mathbb{R}^{p\times m}$, and $\mathbf{W}\in\mathbb{R}^{n\times m}$.
Date: 2023-10-02
- Fast $(1+\varepsilon)$-Approximation Algorithms for Binary Matrix Factorization [54.29685789885059]
We introduce efficient $(1+\varepsilon)$-approximation algorithms for the binary matrix factorization (BMF) problem.
The goal is to approximate $\mathbf{A}$ as a product of low-rank factors.
Our techniques generalize to other common variants of the BMF problem.
Date: 2023-06-02
- Learning a Single Neuron with Adversarial Label Noise via Gradient Descent [50.659479930171585]
We study learning a single neuron, i.e., a function of the form $\mathbf{x}\mapsto\sigma(\mathbf{w}\cdot\mathbf{x})$ for monotone activations $\sigma$.
The goal of the learner is to output a hypothesis vector $\mathbf{w}$ such that $F(\mathbf{w}) = C\,\epsilon$ with high probability.
Date: 2022-06-17
- Beyond Independent Measurements: General Compressed Sensing with GNN Application [4.924126492174801]
We consider the problem of recovering a structured signal $\mathbf{x}\in\mathbb{R}^n$ from noisy cone observations.
We show that the effective rank of $\mathbf{B}$ may be used as a surrogate for the number of measurements.
Date: 2021-10-30
- Fast Graph Sampling for Short Video Summarization using Gershgorin Disc Alignment [52.577757919003844]
We study the problem of efficiently summarizing a short video into several paragraphs, leveraging recent progress in fast graph sampling.
Experimental results show that our algorithm achieves video summarization comparable to state-of-the-art methods at a substantially reduced complexity.
Date: 2021-10-21
- Random matrices in service of ML footprint: ternary random features with no performance loss [55.30329197651178]
We show that the eigenspectrum of $\mathbf{K}$ is independent of the distribution of the i.i.d. entries of $\mathbf{w}$.
We propose a novel random technique, called Ternary Random Feature (TRF).
The computation of the proposed random features requires no multiplication and $b$ times fewer bits for storage than classical random features.
Date: 2021-10-05