Mechanics of Next Token Prediction with Self-Attention
- URL: http://arxiv.org/abs/2403.08081v1
- Date: Tue, 12 Mar 2024 21:15:38 GMT
- Title: Mechanics of Next Token Prediction with Self-Attention
- Authors: Yingcong Li, Yixiao Huang, M. Emrullah Ildiz, Ankit Singh Rawat, Samet
Oymak
- Abstract summary: Transformer-based language models are trained on large datasets to predict the next token given an input sequence.
We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps.
We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
- Score: 41.82477691012942
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: Transformer-based language models are trained on large datasets to predict
the next token given an input sequence. Despite this simple training objective,
they have led to revolutionary advances in natural language processing.
Underlying this success is the self-attention mechanism. In this work, we ask:
$\textit{What does a single self-attention layer learn from next-token prediction?}$
We show that training self-attention with gradient descent learns an automaton
which generates the next token in two distinct steps:
$\textbf{(1) Hard retrieval:}$ Given an input sequence, self-attention precisely
selects the $\textit{high-priority input tokens}$ associated with the last
input token. $\textbf{(2) Soft composition:}$ It then creates a convex
combination of the high-priority tokens from which the
next token can be sampled. Under suitable conditions, we rigorously
characterize these mechanics through a directed graph over tokens extracted
from the training data. We prove that gradient descent implicitly discovers the
strongly-connected components (SCC) of this graph and self-attention learns to
retrieve the tokens that belong to the highest-priority SCC available in the
context window. Our theory relies on decomposing the model weights into a
directional component and a finite component that correspond to hard retrieval
and soft composition steps respectively. This also formalizes a related
implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that
these findings shed light on how self-attention processes sequential data and
pave the path toward demystifying more complex architectures.
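
To make the two-step picture above concrete, here is a minimal Python sketch of such an automaton. It is not the paper's actual construction or learned weights: the toy graph, the priority ordering (topological order of the SCC condensation), the dot-product scores, and the random embeddings are all assumptions made for illustration. The sketch hard-retrieves the context tokens belonging to the highest-priority strongly-connected component present in the window, then soft-composes them with a softmax into a convex combination.

```python
# Illustrative sketch of the two-step automaton described in the abstract.
# Assumptions (not from the paper): toy priority graph, priority = topological
# rank of the SCC condensation, random embeddings, dot-product attention scores.
import numpy as np
import networkx as nx

rng = np.random.default_rng(0)

# Toy vocabulary and a directed "priority" graph over tokens, standing in for
# the graph extracted from (last token, next token) pairs in training data.
vocab = ["a", "b", "c", "d", "e"]
edges = [("a", "b"), ("b", "a"),   # SCC {a, b}
         ("c", "d"), ("d", "c"),   # SCC {c, d}
         ("a", "c"), ("e", "a")]   # cross-SCC edges induce a priority ordering
G = nx.DiGraph(edges)

# Condense the graph so each node is an SCC; a topological order of the
# condensation serves as the priority ranking (earlier = higher priority,
# an assumption made for this sketch).
C = nx.condensation(G)
order = list(nx.topological_sort(C))
priority = {tok: order.index(scc_id)
            for scc_id, data in C.nodes(data=True)
            for tok in data["members"]}

# Random vectors standing in for learned token embeddings.
emb = {tok: rng.normal(size=8) for tok in vocab}


def next_token_automaton(context):
    """(1) Hard retrieval: keep only the context tokens that lie in the
    highest-priority SCC present in the window. (2) Soft composition: a
    softmax over attention-style scores with the last token yields a convex
    combination of the retrieved embeddings, from which a next token could
    be sampled."""
    last = context[-1]
    best = min(priority[tok] for tok in context)
    retrieved = [tok for tok in context if priority[tok] == best]
    scores = np.array([emb[last] @ emb[tok] for tok in retrieved])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    composed = sum(w * emb[tok] for w, tok in zip(weights, retrieved))
    return retrieved, weights, composed


retrieved, weights, composed = next_token_automaton(["c", "a", "d", "b"])
print("retrieved tokens:", retrieved)            # top-priority SCC in the window
print("convex weights:", np.round(weights, 3))   # nonnegative, sum to 1
```

The softmax scores here are plain dot products with the last token's embedding; in the trained model the analogous weights would come from the learned attention parameters (the "finite component"), while the hard selection corresponds to the directional component.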
Related papers
- Attention with Trained Embeddings Provably Selects Important Tokens [73.77633297039097]
Token embeddings play a crucial role in language modeling but, despite this practical relevance, their theoretical understanding remains limited. Our paper addresses the gap by characterizing the structure of embeddings obtained via gradient descent. Experiments on real-world datasets (IMDB, Yelp) exhibit a phenomenology close to that unveiled by our theory.
arXiv Detail & Related papers (2025-05-22T21:00:09Z) - $\text{M}^{\text{3}}$: A Modular World Model over Streams of Tokens [51.65485693709418]
Token-based world models emerged as a promising modular framework, modeling dynamics over token streams while optimizing tokenization separately.
In this paper, we introduce $\text{M}^{\text{3}}$, a $\textbf{m}$odular $\textbf{w}$orld $\textbf{m}$odel that extends this framework.
$\text{M}^{\text{3}}$ incorporates several improvements over the existing literature to enhance agent performance.
arXiv Detail & Related papers (2025-02-17T08:06:10Z) - ZETA: Leveraging Z-order Curves for Efficient Top-k Attention [22.90397380324185]
We propose ZETA to enable parallel querying of past tokens for entire sequences.
ZETA matches the performance of standard attention on the synthetic $\textsc{Multi-Query Associative Recall}$ task.
arXiv Detail & Related papers (2025-01-24T15:33:05Z) - Reasoning to Attend: Try to Understand How <SEG> Token Works [44.33848900059659]
We show that the $\texttt{<SEG>}$ token contributes to semantic similarity within image-text pairs.
We present READ, which facilitates LMMs' resilient $\textbf{REA}$soning capability of where to atten$\textbf{D}$ under the guidance of highly activated points.
arXiv Detail & Related papers (2024-12-23T17:44:05Z) - Towards Understanding the Universality of Transformers for Next-Token Prediction [20.300660057193017]
Causal Transformers are trained to predict the next token for a given context.
We take a step towards understanding this phenomenon by studying the approximation ability of Transformers for next-token prediction.
arXiv Detail & Related papers (2024-10-03T21:42:21Z) - Inertial Confinement Fusion Forecasting via Large Language Models [48.76222320245404]
In this study, we introduce $\textbf{LPI-LLM}$, a novel integration of Large Language Models (LLMs) with classical reservoir computing paradigms.
We propose the $\textit{LLM-anchored Reservoir}$, augmented with a $\textit{Fusion-specific Prompt}$, enabling accurate forecasting of $\texttt{LPI}$-generated hot-electron dynamics during implosion.
We also present $\textbf{LPI4AI}$, the first $\texttt{LPI}$ benchmark based
arXiv Detail & Related papers (2024-07-15T05:46:44Z) - Creating an AI Observer: Generative Semantic Workspaces [4.031100721019478]
We introduce the $\textbf{[G]}$enerative $\textbf{[S]}$emantic $\textbf{[W]}$orkspace (GSW).
GSW creates a generative-style Semantic framework, as opposed to a traditionally predefined set of lexicon labels.
arXiv Detail & Related papers (2024-06-07T00:09:13Z) - Transformer In-Context Learning for Categorical Data [51.23121284812406]
We extend research on understanding Transformers through the lens of in-context learning with functional data by considering categorical outcomes, nonlinear underlying models, and nonlinear attention.
We present what is believed to be the first real-world demonstration of this few-shot-learning methodology, using the ImageNet dataset.
arXiv Detail & Related papers (2024-05-27T15:03:21Z) - Provably learning a multi-head attention layer [55.2904547651831]
The multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models.
In this work, we initiate the study of provably learning a multi-head attention layer from random examples.
We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable.
arXiv Detail & Related papers (2024-02-06T15:39:09Z) - Object Recognition as Next Token Prediction [99.40793702627396]
We present an approach to pose object recognition as next token prediction.
The idea is to apply a language decoder that auto-regressively predicts the text tokens from image embeddings to form labels.
arXiv Detail & Related papers (2023-12-04T18:58:40Z) - Think before you speak: Training Language Models With Pause Tokens [73.61375226378712]
Language models generate responses by producing a series of tokens in immediate succession.
What if instead we were to let the model manipulate, say, $K+10$ hidden vectors before it outputs the $(K+1)^{\text{th}}$ token?
We operationalize this idea by performing training and inference on language models with a (learnable) $\textit{pause}$ token.
arXiv Detail & Related papers (2023-10-03T17:32:41Z) - Blessing of Class Diversity in Pre-training [54.335530406959435]
We prove that when the classes of the pre-training task are sufficiently diverse, pre-training can significantly improve the sample efficiency of downstream tasks.
Our proof relies on a vector-form Rademacher complexity chain rule for composite function classes and a modified self-concordance condition.
arXiv Detail & Related papers (2022-09-07T20:10:12Z) - High-dimensional Asymptotics of Feature Learning: How One Gradient Step
Improves the Representation [89.21686761957383]
We study the first gradient descent step on the first-layer parameters $\boldsymbol{W}$ in a two-layer network.
Our results demonstrate that even one step can lead to a considerable advantage over random features.
arXiv Detail & Related papers (2022-05-03T12:09:59Z) - Categorical Representation Learning: Morphism is All You Need [0.0]
We provide a construction for categorical representation learning and introduce the foundations of the "$\textit{categorifier}$".
Every object in a dataset $\mathcal{S}$ can be represented as a vector in $\mathbb{R}^n$ by an $\textit{encoding map}$ $E: \mathcal{Obj}(\mathcal{S}) \to \mathbb{R}^n$.
As a proof of concept, we provide an example of a text translator equipped with our technology, showing that our categorical learning model outperforms the
arXiv Detail & Related papers (2021-03-26T23:47:15Z) - Two-way kernel matrix puncturing: towards resource-efficient PCA and
spectral clustering [43.50783459690612]
The method consists in randomly "puncturing" both the data matrix $X \in \mathbb{C}^{p \times n}$ and its corresponding kernel (Gram) matrix $K$ through Bernoulli masks.
We empirically confirm on GAN-generated image databases, that it is possible to drastically puncture the data, thereby providing possibly huge computational and storage gains.
arXiv Detail & Related papers (2021-02-24T14:01:58Z)