Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics
- URL: http://arxiv.org/abs/2405.04669v2
- Date: Mon, 28 Oct 2024 05:05:04 GMT
- Title: Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics
- Authors: Hanlin Zhu, Baihe Huang, Shaolun Zhang, Michael Jordan, Jiantao Jiao, Yuandong Tian, Stuart Russell
- Abstract summary: Auto-regressive large language models (LLMs) show impressive capacities to solve many complex reasoning tasks.
When trained on '$A \to B$', LLMs fail to conclude '$B \gets A$' during inference even if the two sentences are semantically identical.
We theoretically analyze the reversal curse via the training dynamics of gradient descent for two auto-regressive models.
- Score: 45.69328374321502
- License:
- Abstract: Auto-regressive large language models (LLMs) show impressive capacities to solve many complex reasoning tasks while struggling with some simple logical reasoning tasks such as inverse search: when trained on '$A \to B$' (e.g., 'Tom is the parent of John'), LLM fails to directly conclude '$B \gets A$' (e.g., 'John is the child of Tom') during inference even if the two sentences are semantically identical, which is known as the 'reversal curse'. In this paper, we theoretically analyze the reversal curse via the training dynamics of (stochastic) gradient descent for two auto-regressive models: (1) a bilinear model that can be viewed as a simplification of a one-layer transformer; (2) one-layer transformers under certain assumptions. Our analysis reveals that for both models, the reversal curse is a consequence of the (effective) model weights 'asymmetry', i.e., the increase of weights from a token $A$ to token $B$ during training does not necessarily cause the increase of the weights from $B$ to $A$, which is caused by the training dynamics under certain choice of loss function and the optimization space of model parameters. Moreover, our analysis can be naturally applied to other logical reasoning tasks such as chain-of-thought (COT), which provides a new perspective different from previous work that focuses on expressivity. Finally, we conduct experiments to validate our theory on multi-layer transformers under different settings. Our code is available at https://github.com/marlo-z/reversal_curse_analysis/.
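The weight asymmetry described in the abstract can be illustrated with a toy sketch. It is not the paper's exact setup: it assumes one-hot token embeddings, so the bilinear logit from token `i` to token `j` reduces to a single matrix entry `theta[i, j]`, and the names `train_bilinear`, `next_token_probs`, `A`, `B`, and `VOCAB` are hypothetical. Under these assumptions, gradient descent on the cross-entropy loss for the pair (A, B) only updates row A of the weight matrix, so the weight from B to A never changes.

```python
import numpy as np

def train_bilinear(pairs, vocab_size, steps=200, lr=0.5):
    """Gradient descent on cross-entropy with logits[i, j] = theta[i, j].

    Assumption: one-hot embeddings, so the bilinear form x_i^T Theta x_j
    collapses to the single entry theta[i, j].
    """
    theta = np.zeros((vocab_size, vocab_size))
    for _ in range(steps):
        grad = np.zeros_like(theta)
        for src, tgt in pairs:
            logits = theta[src]
            probs = np.exp(logits - logits.max())
            probs /= probs.sum()
            probs[tgt] -= 1.0          # d(loss)/d(logits) = softmax - one-hot
            grad[src] += probs         # only the source token's row gets gradient
        theta -= lr * grad / len(pairs)
    return theta

def next_token_probs(theta, src):
    """Softmax over the next token given the source token."""
    logits = theta[src]
    p = np.exp(logits - logits.max())
    return p / p.sum()

A, B = 0, 1   # hypothetical token ids
VOCAB = 4     # hypothetical vocabulary size

theta = train_bilinear([(A, B)], VOCAB)
forward = next_token_probs(theta, A)[B]   # P(B | A): learned during training
reverse = next_token_probs(theta, B)[A]   # P(A | B): row B was never updated

print(f"P(B|A) = {forward:.3f}, P(A|B) = {reverse:.3f}")
```

In this sketch the forward probability approaches 1 while the reverse probability stays at the uniform value `1 / VOCAB`, because no gradient ever reaches row B. Training on the reversed pair as well, `[(A, B), (B, A)]`, would be one way to make both directions learnable in this toy model.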
Related papers
- On the Role of Depth and Looping for In-Context Learning with Task Diversity [69.4145579827826]
We study in-context learning for linear regression with diverse tasks.
We show that multilayer Transformers are not robust to even distributional shifts as small as $O(e^{-L})$ in Wasserstein distance.
arXiv Detail & Related papers (2024-10-29T03:27:56Z)
- Unveiling Induction Heads: Provable Training Dynamics and Feature Learning in Transformers [54.20763128054692]
We study how a two-attention-layer transformer is trained to perform ICL on $n$-gram Markov chain data.
We prove that the gradient flow with respect to a cross-entropy ICL loss converges to a limiting model.
arXiv Detail & Related papers (2024-09-09T18:10:26Z)
- Towards Better Understanding of In-Context Learning Ability from In-Context Uncertainty Quantification [7.869708570399577]
We consider a bi-objective prediction task of predicting both the conditional expectation $\mathbb{E}[Y|X]$ and the conditional variance $\mathrm{Var}(Y|X)$.
Theoretically, we show that the trained Transformer reaches near Bayes-optimum, suggesting the usage of the information of the training distribution.
arXiv Detail & Related papers (2024-05-24T00:08:55Z)
- How do Transformers perform In-Context Autoregressive Learning? [76.18489638049545]
We train a Transformer model on a simple next token prediction task.
We show how a trained Transformer predicts the next token by first learning $W$ in-context, then applying a prediction mapping.
arXiv Detail & Related papers (2024-02-08T16:24:44Z)
- Local Convergence of Approximate Newton Method for Two Layer Nonlinear Regression [21.849997443967705]
The two-layer regression problem has been well studied in prior work.
The first layer is activated by a ReLU unit, and the second layer by a softmax unit.
We prove that the Hessian matrix of the loss function is positive definite and Lipschitz continuous under certain assumptions.
arXiv Detail & Related papers (2023-11-26T19:19:02Z)
- A Unified Scheme of ResNet and Softmax [8.556540804058203]
We provide a theoretical analysis of the regression problem: $\| \langle \exp(Ax) + Ax, \mathbf{1}_n \rangle^{-1} ( \exp(Ax) + Ax ) \ldots$
This regression problem is a unified scheme that combines softmax regression and ResNet, which has never been done before.
arXiv Detail & Related papers (2023-09-23T21:41:01Z)
- Characterizing Datapoints via Second-Split Forgetting [93.99363547536392]
We propose second-split forgetting time (SSFT), a complementary metric that tracks the epoch (if any) after which an original training example is forgotten.
We demonstrate that mislabeled examples are forgotten quickly, and seemingly rare examples are forgotten comparatively slowly.
SSFT can (i) help to identify mislabeled samples, the removal of which improves generalization; and (ii) provide insights about failure modes.
arXiv Detail & Related papers (2022-10-26T21:03:46Z)
- Gradient-based Counterfactual Explanations using Tractable Probabilistic Models [20.770168600755877]
Counterfactual examples are an appealing class of explanations for machine learning models.
Current approaches primarily solve this task by a complex optimization.
We propose a novel approach to deal with these problems using only two gradient computations.
arXiv Detail & Related papers (2022-05-16T16:05:43Z)
- Understanding the Difficulty of Training Transformers [120.99980924577787]
We show that unbalanced gradients are not the root cause of the instability of training.
We propose Admin to stabilize the early stage's training and unleash its full potential in the late stage.
arXiv Detail & Related papers (2020-04-17T13:59:07Z)