Attention with Trained Embeddings Provably Selects Important Tokens
- URL: http://arxiv.org/abs/2505.17282v3
- Date: Wed, 25 Jun 2025 15:19:05 GMT
- Title: Attention with Trained Embeddings Provably Selects Important Tokens
- Authors: Diyuan Wu, Aleksandr Shevchenko, Samet Oymak, Marco Mondelli
- Abstract summary: Token embeddings play a crucial role in language modeling but, despite this practical relevance, their theoretical understanding remains limited. Our paper addresses the gap by characterizing the structure of embeddings obtained via gradient descent. Experiments on real-world datasets (IMDB, Yelp) exhibit a phenomenology close to that unveiled by our theory.
- Score: 73.77633297039097
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: Token embeddings play a crucial role in language modeling but, despite this practical relevance, their theoretical understanding remains limited. Our paper addresses the gap by characterizing the structure of embeddings obtained via gradient descent. Specifically, we consider a one-layer softmax attention model with a linear head for binary classification, i.e., $\texttt{Softmax}( p^\top E_X^\top ) E_X v = \frac{ \sum_{i=1}^T \exp(p^\top E_{x_i}) E_{x_i}^\top v}{\sum_{j=1}^T \exp(p^\top E_{x_{j}}) }$, where $E_X = [ E_{x_1} , \dots, E_{x_T} ]^\top$ contains the embeddings of the input sequence, $p$ is the embedding of the $\mathrm{\langle cls \rangle}$ token and $v$ the output vector. First, we show that, already after a single step of gradient training with the logistic loss, the embeddings $E_X$ capture the importance of tokens in the dataset by aligning with the output vector $v$ proportionally to the frequency with which the corresponding tokens appear in the dataset. Then, after training $p$ via gradient flow until convergence, the softmax selects the important tokens in the sentence (i.e., those that are predictive of the label), and the resulting $\mathrm{\langle cls \rangle}$ embedding maximizes the margin for such a selection. Experiments on real-world datasets (IMDB, Yelp) exhibit a phenomenology close to that unveiled by our theory.
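The model in the abstract can be sketched numerically as follows. This is a minimal illustration, not the authors' code: the vocabulary size, embedding dimension, random initialization, the example token sequence, and the label convention $y \in \{-1, +1\}$ for the logistic loss are all assumptions made for the example, and the names `attention_score` and `logistic_loss` are hypothetical.

```python
import numpy as np

def attention_score(E, p, v, tokens):
    """Compute Softmax(p^T E_X^T) E_X v for one token sequence.

    E: (vocab_size, d) embedding table; p: (d,) <cls> embedding;
    v: (d,) output vector; tokens: (T,) integer token ids.
    Returns the scalar model output and the (T,) attention weights.
    """
    E_X = E[tokens]                    # (T, d): one embedding row per position
    logits = E_X @ p                   # (T,): p^T E_{x_i}
    logits = logits - logits.max()     # shift for numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum()  # softmax over the T positions
    return float(weights @ (E_X @ v)), weights

def logistic_loss(score, label):
    """log(1 + exp(-y * f)), assuming labels y in {-1, +1}."""
    return float(np.logaddexp(0.0, -label * score))

# Hypothetical sizes and random parameters, for illustration only.
rng = np.random.default_rng(0)
vocab_size, d = 10, 4
E = rng.normal(size=(vocab_size, d))
p = rng.normal(size=d)
v = rng.normal(size=d)

tokens = np.array([1, 3, 3, 7])        # a toy input sequence of length T = 4
score, weights = attention_score(E, p, v, tokens)
loss = logistic_loss(score, label=1)
```

The attention weights indicate which positions the softmax emphasizes; the abstract's claim is that, after training $p$, these weights concentrate on the tokens that are predictive of the label.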
Related papers
- Learning Networks from Wide-Sense Stationary Stochastic Processes [7.59499154221528]
A key inference problem here is to learn edge connectivity from node outputs (potentials). We use a Whittle's maximum likelihood estimator (MLE) to learn the support of $L^\ast$ from temporally correlated samples. We show that the MLE problem is strictly convex, admitting a unique solution.
(2024-12-04T23:14:00Z)
- On Understanding Attention-Based In-Context Learning for Categorical Data [49.40350941996942]
We develop a network composed of attention blocks, with each block employing a self-attention layer followed by a cross-attention layer, with associated skip connections. This model can exactly perform multi-step functional GD inference for in-context inference with categorical observations.
(2024-05-27T15:03:21Z)
- Mechanics of Next Token Prediction with Self-Attention [41.82477691012942]
Transformer-based language models are trained on large datasets to predict the next token given an input sequence.
We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps.
We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
(2024-03-12T21:15:38Z)
- Provably learning a multi-head attention layer [55.2904547651831]
Multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models.
In this work, we initiate the study of provably learning a multi-head attention layer from random examples.
We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable.
(2024-02-06T15:39:09Z)
- Solving Quadratic Systems with Full-Rank Matrices Using Sparse or Generative Priors [33.0212223058894]
The problem of recovering a signal from a quadratic system $y_i=\boldsymbol{x}^\top\boldsymbol{A}_i\boldsymbol{x},\ i=1,\ldots,m$ with full-rank matrices $\boldsymbol{A}_i$ frequently arises in applications such as unassigned distance geometry and sub-wavelength imaging.
This paper addresses the high-dimensional case where $m\ll n$ by incorporating prior knowledge of $\boldsymbol{x}$.
(2023-09-16T16:00:07Z)
- Max-Margin Token Selection in Attention Mechanism [34.11955962489916]
We prove that running gradient descent on $\boldsymbol{p}$, or equivalently $\boldsymbol{W}$, converges in direction to a max-margin solution.
Remarkably, our results are applicable to general data and precisely characterize optimal token selection.
(2023-06-23T16:35:46Z)
- High-dimensional Asymptotics of Feature Learning: How One Gradient Step Improves the Representation [89.21686761957383]
We study the first gradient descent step on the first-layer parameters $\boldsymbol{W}$ in a two-layer network.
Our results demonstrate that even one step can lead to a considerable advantage over random features.
(2022-05-03T12:09:59Z)
- Self-training Converts Weak Learners to Strong Learners in Mixture Models [86.7137362125503]
We show that a pseudolabeler $\boldsymbol{\beta}_{\mathrm{pl}}$ can achieve classification error at most $C_{\mathrm{err}}$.
We additionally show that by running gradient descent on the logistic loss one can obtain a pseudolabeler $\boldsymbol{\beta}_{\mathrm{pl}}$ with classification error $C_{\mathrm{err}}$ using only $O(d)$ labeled examples.
(2021-06-25T17:59:16Z)
- Learning a Latent Simplex in Input-Sparsity Time [58.30321592603066]
We consider the problem of learning a latent $k$-vertex simplex $K\subset\mathbb{R}^{d}$, given access to $A\in\mathbb{R}^{d\times n}$.
We show that the dependence on $k$ in the running time is unnecessary given a natural assumption about the mass of the top $k$ singular values of $A$.
(2021-05-17T16:40:48Z)
- Nearly Horizon-Free Offline Reinforcement Learning [97.36751930393245]
We revisit offline reinforcement learning on episodic time-homogeneous Markov Decision Processes with $S$ states, $A$ actions and planning horizon $H$.
We obtain the first set of nearly $H$-free sample complexity bounds for evaluation and planning using the empirical MDPs.
(2021-03-25T18:52:17Z)