Mechanics of Next Token Prediction with Self-Attention
- URL: http://arxiv.org/abs/2403.08081v1
- Date: Tue, 12 Mar 2024 21:15:38 GMT
- Title: Mechanics of Next Token Prediction with Self-Attention
- Authors: Yingcong Li, Yixiao Huang, M. Emrullah Ildiz, Ankit Singh Rawat, Samet Oymak
- Abstract summary: Transformer-based language models are trained on large datasets to predict the next token given an input sequence.
We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps.
We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
- Score: 41.82477691012942
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: Transformer-based language models are trained on large datasets to predict
the next token given an input sequence. Despite this simple training objective,
they have led to revolutionary advances in natural language processing.
Underlying this success is the self-attention mechanism. In this work, we ask:
$\textit{What does a single self-attention layer learn from next-token prediction?}$
We show that training
self-attention with gradient descent learns an automaton which generates the
next token in two distinct steps: $\textbf{(1) Hard retrieval:}$
Given input sequence, self-attention precisely selects
the $\textit{high-priority input tokens}$ associated with
the last input token. $\textbf{(2) Soft composition:}$ It
then creates a convex combination of the high-priority tokens from which the
next token can be sampled. Under suitable conditions, we rigorously
characterize these mechanics through a directed graph over tokens extracted
from the training data. We prove that gradient descent implicitly discovers the
strongly-connected components (SCC) of this graph and self-attention learns to
retrieve the tokens that belong to the highest-priority SCC available in the
context window. Our theory relies on decomposing the model weights into a
directional component and a finite component that correspond to hard retrieval
and soft composition steps respectively. This also formalizes a related
implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that
these findings shed light on how self-attention processes sequential data and
pave the path toward demystifying more complex architectures.
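
The two steps described in the abstract can be illustrated with a small numerical sketch. It is not the authors' code: the token scores below are made up, and the names `directional`, `finite`, and `scale` are hypothetical stand-ins for the directional weight component, the finite weight component, and the growing norm of the directional part. The sketch only assumes that attention scores take the form `scale * directional + finite` and shows that, as `scale` grows, softmax attention puts vanishing mass on lower-priority tokens (hard retrieval) while the finite part fixes the mixing weights among the retained tokens (soft composition).

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(directional, finite, scale):
    """Softmax over scores of the assumed form scale * directional + finite."""
    return softmax([scale * d + f for d, f in zip(directional, finite)])


# Hypothetical context of four tokens. Tokens 0 and 1 are treated as
# "high-priority" (directional score 0); tokens 2 and 3 are lower priority
# (negative directional score). All values are illustrative assumptions.
directional = [0.0, 0.0, -1.0, -2.0]
finite = [0.5, -0.5, 1.0, 0.0]

for scale in (1.0, 10.0, 100.0):
    weights = attention_weights(directional, finite, scale)
    print(scale, [round(w, 4) for w in weights])

# Limiting behaviour: only the high-priority tokens survive, and their
# convex-combination weights are set by the finite component alone.
limit = softmax(finite[:2])
print("limit over high-priority tokens:", [round(w, 4) for w in limit])
```

At small `scale` every token receives noticeable weight; at large `scale` the weights on tokens 2 and 3 shrink toward zero and the remaining weights approach the softmax of the finite scores restricted to tokens 0 and 1.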
Related papers
- Creating an AI Observer: Generative Semantic Workspaces [4.031100721019478]
We introduce the $\textbf{G}$enerative $\textbf{S}$emantic $\textbf{W}$orkspace (GSW).
GSW creates a generative-style Semantic framework, as opposed to a traditionally predefined set of lexicon labels.
arXiv Detail & Related papers (2024-06-07T00:09:13Z)
- Transformer In-Context Learning for Categorical Data [51.23121284812406]
We extend research on understanding Transformers through the lens of in-context learning with functional data by considering categorical outcomes, nonlinear underlying models, and nonlinear attention.
We present what is believed to be the first real-world demonstration of this few-shot-learning methodology, using the ImageNet dataset.
arXiv Detail & Related papers (2024-05-27T15:03:21Z)
- Object Recognition as Next Token Prediction [99.40793702627396]
We present an approach to pose object recognition as next token prediction.
The idea is to apply a language decoder that auto-regressively predicts the text tokens from image embeddings to form labels.
arXiv Detail & Related papers (2023-12-04T18:58:40Z)
- Think before you speak: Training Language Models With Pause Tokens [73.61375226378712]
Language models generate responses by producing a series of tokens in immediate succession.
What if instead we were to let the model manipulate say, $K+10$ hidden vectors, before it outputs the $(K+1)$th token?
We operationalize this idea by performing training and inference on language models with a (learnable) $\textit{pause}$ token.
arXiv Detail & Related papers (2023-10-03T17:32:41Z)
- Scan and Snap: Understanding Training Dynamics and Token Composition in 1-layer Transformer [37.37547759817417]
Transformer architecture has shown impressive performance in multiple research domains.
We analyze its SGD training dynamics for the task of next token prediction.
We prove that self-attention acts as a discriminative scanning algorithm.
arXiv Detail & Related papers (2023-05-25T15:59:13Z)
- High-dimensional Asymptotics of Feature Learning: How One Gradient Step Improves the Representation [89.21686761957383]
We study the first gradient descent step on the first-layer parameters $\boldsymbol{W}$ in a two-layer network.
Our results demonstrate that even one step can lead to a considerable advantage over random features.
arXiv Detail & Related papers (2022-05-03T12:09:59Z)
- Accurately Modeling Biased Random Walks on Weighted Graphs Using $\textit{Node2vec+}$ [0.0]
We extend $\textit{node2vec}$ to $\textit{node2vec+}$ in a way that accounts for edge weights when calculating walk biases.
We show that $\textit{node2vec+}$ is more robust to additive noise than $\textit{node2vec}$ in weighted graphs.
We also demonstrate that $\textit{node2vec+}$ significantly outperforms $\textit{node2vec}$ on a commonly benchmarked multi-label dataset.
arXiv Detail & Related papers (2021-09-15T17:59:25Z)
- Categorical Representation Learning: Morphism is All You Need [0.0]
We provide a construction for categorical representation learning and introduce the foundations of "$\textit{categorifier}$".
Every object in a dataset $\mathcal{S}$ can be represented as a vector in $\mathbb{R}^n$ by an $\textit{encoding map}$ $E: \mathcal{O}bj(\mathcal{S}) \to \mathbb{R}^n$.
As a proof of concept, we provide an example of a text translator equipped with our technology, showing that our categorical learning model outperforms the
arXiv Detail & Related papers (2021-03-26T23:47:15Z)
- Adversarial Linear Contextual Bandits with Graph-Structured Side Observations [80.95090605985042]
A learning agent repeatedly chooses from a set of $K$ actions after being presented with a $d$-dimensional context vector.
The agent incurs and observes the loss of the chosen action, but also observes the losses of its neighboring actions in the observation structures.
Two efficient algorithms are developed based on $\texttt{EXP3}$.
arXiv Detail & Related papers (2020-12-10T15:40:07Z)