Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought
- URL: http://arxiv.org/abs/2502.21212v1
- Date: Fri, 28 Feb 2025 16:40:38 GMT
- Title: Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought
- Authors: Jianhao Huang, Zixuan Wang, Jason D. Lee
- Abstract summary: Chain of Thought (CoT) prompting has been shown to significantly improve the performance of large language models (LLMs). We study the training dynamics of transformers over a CoT objective on an in-context weight prediction task for linear regression.
- Score: 46.71030329872635
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Chain of Thought (CoT) prompting has been shown to significantly improve the performance of large language models (LLMs), particularly in arithmetic and reasoning tasks, by instructing the model to produce intermediate reasoning steps. Despite the remarkable empirical success of CoT and its theoretical advantages in enhancing expressivity, the mechanisms underlying CoT training remain largely unexplored. In this paper, we study the training dynamics of transformers over a CoT objective on an in-context weight prediction task for linear regression. We prove that while a one-layer linear transformer without CoT can only implement a single step of gradient descent (GD) and fails to recover the ground-truth weight vector, a transformer with CoT prompting can learn to perform multi-step GD autoregressively, achieving near-exact recovery. Furthermore, we show that the trained transformer effectively generalizes on the unseen data. With our technique, we also show that looped transformers significantly improve final performance compared to transformers without looping in the in-context learning of linear regression. Empirically, we demonstrate that CoT prompting yields substantial performance improvements.
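A minimal numerical sketch of the separation the abstract describes: on in-context linear regression data, a single gradient-descent (GD) step from zero initialization (the most a one-layer linear transformer can implement, per the paper) leaves substantial error, while many GD steps (the multi-step update that CoT lets the model unroll autoregressively) recover the ground-truth weights nearly exactly. The dimensions, step size, and use of NumPy below are illustrative assumptions, not the paper's construction.
```python
import numpy as np

# In-context linear regression: n example pairs (x_i, y_i) with y_i = <w*, x_i>.
# Dimensions, sample count, and step size are illustrative choices.
rng = np.random.default_rng(0)
d, n = 8, 64
w_star = rng.normal(size=d)           # ground-truth weight vector
X = rng.normal(size=(n, d))
y = X @ w_star                        # noiseless labels

def gd_steps(k, lr=0.2):
    """Run k gradient-descent steps on the least-squares loss, starting from w = 0."""
    w = np.zeros(d)
    for _ in range(k):
        grad = X.T @ (X @ w - y) / n  # gradient of (1/2n) * ||Xw - y||^2
        w = w - lr * grad
    return w

for k in (1, 10, 100):
    err = np.linalg.norm(gd_steps(k) - w_star)
    print(f"{k:>3} GD step(s): ||w - w*|| = {err:.5f}")

# One step leaves a large residual; many steps drive the recovery error close
# to zero, mirroring the single-step vs. multi-step (CoT) separation above.
```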
Related papers
- Enhancing Auto-regressive Chain-of-Thought through Loop-Aligned Reasoning [47.06427150903487]
Chain-of-Thought (CoT) prompting has emerged as a powerful technique for enhancing language models' reasoning capabilities. Looped Transformers possess remarkable length generalization capabilities, but their limited generality and adaptability prevent them from serving as an alternative to auto-regressive solutions. We propose RELAY to better leverage the strengths of Looped Transformers.
arXiv Detail & Related papers (2025-02-12T15:17:04Z)
- Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning? [69.4145579827826]
We show fast convergence of gradient flow on the regression loss despite the non-convexity of the loss landscape.
This is the first theoretical analysis for multi-layer Transformer in this setting.
arXiv Detail & Related papers (2024-10-10T18:29:05Z)
- Training Nonlinear Transformers for Chain-of-Thought Inference: A Theoretical Generalization Analysis [82.51626700527837]
Chain-of-Thought (CoT) is an efficient method that enables the reasoning ability of large language models by augmenting the query using examples with multiple intermediate steps.
We also show that in-context learning without CoT may fail to provide accurate generalization in cases where CoT succeeds.
arXiv Detail & Related papers (2024-10-03T03:12:51Z)
- Transformers Handle Endogeneity in In-Context Linear Regression [34.458004744956334]
We show that transformers inherently possess a mechanism to handle endogeneity effectively using instrumental variables (IV). We propose an in-context pretraining scheme and provide theoretical guarantees showing that the global minimizer of the pre-training loss achieves a small excess loss.
arXiv Detail & Related papers (2024-10-02T06:21:04Z)
- Emergent Agentic Transformer from Chain of Hindsight Experience [96.56164427726203]
We show, for the first time, that a simple transformer-based model performs competitively with both temporal-difference and imitation-learning-based approaches.
arXiv Detail & Related papers (2023-05-26T00:43:02Z)
- Transformers learn in-context by gradient descent [58.24152335931036]
Training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations.
We show how trained Transformers become mesa-optimizers, i.e., learn models by gradient descent in their forward pass; a minimal sketch of this one-step identity appears after this list.
arXiv Detail & Related papers (2022-12-15T09:21:21Z)
- Finetuning Pretrained Transformers into RNNs [81.72974646901136]
Transformers have outperformed recurrent neural networks (RNNs) in natural language generation.
A linear-complexity recurrent variant has proven well suited for autoregressive generation.
This work aims to convert a pretrained transformer into its efficient recurrent counterpart.
arXiv Detail & Related papers (2021-03-24T10:50:43Z)
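Returning to the mesa-optimizer entry above, here is a minimal sketch of the identity it rests on: a single gradient-descent step on the in-context least-squares loss can be rewritten as a sum of per-token outer products, the kind of quantity a linear self-attention head can accumulate in its forward pass. The constants, token layout, and NumPy usage are illustrative assumptions, not the weights of a trained model.
```python
import numpy as np

# In-context linear regression data; sizes and step size are illustrative.
rng = np.random.default_rng(1)
d, n, lr = 4, 32, 0.1
w_star = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_star

# (a) Explicit GD step on L(w) = (1/2n) * ||Xw - y||^2, starting from w = 0.
w0 = np.zeros(d)
w1 = w0 - lr * (X.T @ (X @ w0 - y)) / n

# (b) Token-wise form of the same update:
#     Delta = (lr/n) * sum_i (y_i - <w0, x_i>) x_i,
#     a sum over context tokens of residual-times-input terms, the form a
#     linear self-attention head can accumulate across the context.
delta = (lr / n) * sum((y[i] - w0 @ X[i]) * X[i] for i in range(n))

print(np.allclose(w0 + delta, w1))   # True: the two updates coincide
```
Starting from w = 0 keeps the sketch to a single step; looping or stacking the same update corresponds to the multi-step behavior analyzed in the main abstract.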