State-space models can learn in-context by gradient descent
- URL: http://arxiv.org/abs/2410.11687v2
- Date: Tue, 18 Feb 2025 18:55:39 GMT
- Title: State-space models can learn in-context by gradient descent
- Authors: Neeraj Mohan Sushma, Yudou Tian, Harshvardhan Mestha, Nicolo Colombo, David Kappel, Anand Subramoney
- Abstract summary: We show that state-space models can perform gradient-based learning and use it for in-context learning in much the same way as transformers.
Specifically, we prove that a single structured state-space model layer, augmented with multiplicative input and output gating, can reproduce the outputs of an implicit linear model.
We also provide novel insights into the relationship between state-space models and linear self-attention, and their ability to learn in-context.
- Score: 1.3087858009942543
- License:
- Abstract: Deep state-space models (Deep SSMs) are becoming popular as effective approaches to model sequence data. They have also been shown to be capable of in-context learning, much like transformers. However, a complete picture of how SSMs might be able to do in-context learning has been missing. In this study, we provide a direct and explicit construction to show that state-space models can perform gradient-based learning and use it for in-context learning in much the same way as transformers. Specifically, we prove that a single structured state-space model layer, augmented with multiplicative input and output gating, can reproduce the outputs of an implicit linear model with least squares loss after one step of gradient descent. We then show a straightforward extension to multi-step linear and non-linear regression tasks. We validate our construction by training randomly initialized augmented SSMs on linear and non-linear regression tasks. The parameters obtained empirically through optimization match those predicted analytically by the theoretical construction. Overall, we elucidate the role of input- and output-gating in recurrent architectures as the key inductive biases for enabling the expressive power typical of foundation models. We also provide novel insights into the relationship between state-space models and linear self-attention, and their ability to learn in-context.
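The core claim of the abstract, that a gated recurrence can reproduce the output of a linear model after one gradient step on a least-squares loss, can be checked numerically in a simplified setting. The sketch below is an illustration under stated assumptions, not the authors' construction: it assumes scalar targets, a zero-initialized linear model, an identity state transition (no decay), an input gate that scales each x_t by its target y_t, and an output gate that takes the inner product of the final state with the query. All function names are hypothetical. It also computes the same quantity as unnormalized linear self-attention, which is one way to see the SSM/linear-attention relationship the abstract mentions.

```python
from typing import List, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def gd_one_step_prediction(xs: List[Vector], ys: Vector, x_query: Vector, lr: float) -> float:
    """Explicit route: one GD step on L(w) = 1/(2N) * sum_i (w . x_i - y_i)^2
    starting from w0 = 0, then predict w1 . x_query."""
    n, d = len(xs), len(x_query)
    w0 = [0.0] * d
    grad = [0.0] * d
    for x, y in zip(xs, ys):
        err = dot(w0, x) - y  # residual of the initial linear model
        for j in range(d):
            grad[j] += err * x[j] / n
    w1 = [w0[j] - lr * grad[j] for j in range(d)]
    return dot(w1, x_query)


def gated_ssm_prediction(xs: List[Vector], ys: Vector, x_query: Vector, lr: float) -> float:
    """Recurrent route (hypothetical simplified form): h_t = A h_{t-1} + y_t * x_t.
    ASSUMPTION: A is the identity (no decay), the input gate multiplies x_t by y_t,
    and the output gate takes the inner product of the final state with the query."""
    n, d = len(xs), len(x_query)
    h = [0.0] * d
    for x, y in zip(xs, ys):
        h = [h[j] + y * x[j] for j in range(d)]  # multiplicative input gating
    return (lr / n) * dot(h, x_query)  # multiplicative output gating with the query


def linear_attention_prediction(xs: List[Vector], ys: Vector, x_query: Vector, lr: float) -> float:
    """Unnormalized linear self-attention: keys x_i, values y_i, query x_query."""
    n = len(xs)
    return (lr / n) * sum(y * dot(x, x_query) for x, y in zip(xs, ys))


if __name__ == "__main__":
    xs = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
    ys = [2.0, -1.0, 0.5]
    xq = [0.5, 1.0]
    for f in (gd_one_step_prediction, gated_ssm_prediction, linear_attention_prediction):
        print(f.__name__, f(xs, ys, xq, lr=0.3))
```

With w0 = 0 the gradient step gives w1 = (lr/N) * sum_i y_i x_i, so all three routes compute (lr/N) * sum_i y_i (x_i . x_query) and agree up to floating-point rounding (about -0.025 for the toy data above). The recurrent form only needs a running sum, which is why the multiplicative input and output gates, rather than the state transition, carry the learning step in this simplified picture.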
Related papers
- Scaling Law for Stochastic Gradient Descent in Quadratically Parameterized Linear Regression [5.801904710149222]
In machine learning, the scaling law describes how model performance improves as model and data size scale up.
This paper studies the scaling law over a linear regression with the model being quadratically parameterized.
As a result, for canonical linear regression, we provide explicit separations between the generalization curves with and without feature learning, and an information-theoretic lower bound that is agnostic to the parametrization method and the algorithm.
arXiv Detail & Related papers (2025-02-13T09:29:04Z)
- Re-examining learning linear functions in context [1.8843687952462742]
In-context learning (ICL) has emerged as a powerful paradigm for easily adapting Large Language Models (LLMs) to various tasks.
We explore a simple model of ICL in a controlled setup with synthetic training data.
Our findings challenge the prevailing narrative that transformers adopt algorithmic approaches to learn a linear function in-context.
arXiv Detail & Related papers (2024-11-18T10:58:46Z)
- Theoretical Foundations of Deep Selective State-Space Models [13.971499161967083]
Deep SSMs demonstrate outstanding performance across a diverse set of domains.
Recent developments highlight SSMs whose linear recurrence allows for multiplicative interactions between inputs and hidden states.
We show that when random linear recurrences are equipped with simple input-controlled transitions, then the hidden state is provably a low-dimensional projection of a powerful mathematical object.
arXiv Detail & Related papers (2024-02-29T11:20:16Z)
- SIP: Injecting a Structural Inductive Bias into a Seq2Seq Model by Simulation [75.14793516745374]
We show how a structural inductive bias can be efficiently injected into a seq2seq model by pre-training it to simulate structural transformations on synthetic data.
Our experiments show that our method imparts the desired inductive bias, resulting in better few-shot learning for FST-like tasks.
arXiv Detail & Related papers (2023-10-01T21:19:12Z)
- Latent Traversals in Generative Models as Potential Flows [113.4232528843775]
We propose to model latent structures with a learned dynamic potential landscape.
Inspired by physics, optimal transport, and neuroscience, these potential landscapes are learned as physically realistic partial differential equations.
Our method produces trajectories that are more disentangled, both qualitatively and quantitatively, than those of state-of-the-art baselines.
arXiv Detail & Related papers (2023-04-25T15:53:45Z)
- Transformers learn in-context by gradient descent [58.24152335931036]
Training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations.
We show how trained Transformers become mesa-optimizers, i.e., they learn models by gradient descent in their forward pass.
arXiv Detail & Related papers (2022-12-15T09:21:21Z)
- What learning algorithm is in-context learning? Investigations with linear models [87.91612418166464]
We investigate the hypothesis that transformer-based in-context learners implement standard learning algorithms implicitly.
We show that trained in-context learners closely match the predictors computed by gradient descent, ridge regression, and exact least-squares regression.
We also present preliminary evidence that in-context learners share algorithmic features with these predictors.
arXiv Detail & Related papers (2022-11-28T18:59:51Z)
- Merging Two Cultures: Deep and Statistical Learning [3.15863303008255]
Merging the two cultures of deep and statistical learning provides insights into structured high-dimensional data.
We show that prediction, optimisation and uncertainty can be achieved using probabilistic methods at the output layer of the model.
arXiv Detail & Related papers (2021-10-22T02:57:21Z)
- Kernel and Rich Regimes in Overparametrized Models [69.40899443842443]
We show that gradient descent on overparametrized multilayer networks can induce rich implicit biases that are not RKHS norms.
We also demonstrate the transition between the kernel and rich regimes empirically for more complex matrix factorization models and multilayer non-linear networks.
arXiv Detail & Related papers (2020-02-20T15:43:02Z)
- Learning Bijective Feature Maps for Linear ICA [73.85904548374575]
We show that existing probabilistic deep generative models (DGMs), which are tailor-made for image data, underperform on non-linear ICA tasks.
To address this, we propose a DGM which combines bijective feature maps with a linear ICA model to learn interpretable latent structures for high-dimensional data.
We create models that converge quickly, are easy to train, and achieve better unsupervised latent factor discovery than flow-based models, linear ICA, and Variational Autoencoders on images.
arXiv Detail & Related papers (2020-02-18T17:58:07Z)
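The related entry "What learning algorithm is in-context learning?" compares trained in-context learners against predictors computed by gradient descent, ridge regression, and exact least-squares regression. A minimal pure-Python sketch of those three reference predictors on a toy noiseless linear-regression prompt might look like the following. The data, learning rate, step counts, regularization strength, and helper names (`solve`, `ridge`, `gradient_descent`) are illustrative assumptions, not that paper's experimental setup.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def solve(a: Matrix, b: Vector) -> Vector:
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix [a | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        s = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - s) / m[r][r]
    return x


def ridge(xs: List[Vector], ys: Vector, lam: float) -> Vector:
    """Minimize sum_i (w . x_i - y_i)^2 + lam * |w|^2 via the normal equations
    (X^T X + lam I) w = X^T y; lam = 0 gives exact least squares."""
    d = len(xs[0])
    gram = [[sum(x[i] * x[j] for x in xs) + (lam if i == j else 0.0) for j in range(d)]
            for i in range(d)]
    rhs = [sum(y * x[i] for x, y in zip(xs, ys)) for i in range(d)]
    return solve(gram, rhs)


def gradient_descent(xs: List[Vector], ys: Vector, lr: float, steps: int) -> Vector:
    """Full-batch GD from w = 0 on L(w) = 1/(2N) * sum_i (w . x_i - y_i)^2."""
    n, d = len(xs), len(xs[0])
    w = [0.0] * d
    for _ in range(steps):
        grad = [0.0] * d
        for x, y in zip(xs, ys):
            err = sum(wj * xj for wj, xj in zip(w, x)) - y
            for j in range(d):
                grad[j] += err * x[j] / n
        w = [w[j] - lr * grad[j] for j in range(d)]
    return w


if __name__ == "__main__":
    # Hypothetical prompt generated by the true weights [1.5, -0.5] without noise.
    xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
    ys = [1.5, -0.5, 1.0, 3.5]
    print("least squares:", ridge(xs, ys, 0.0))
    print("ridge (lam=1):", ridge(xs, ys, 1.0))
    print("GD, 1 step:   ", gradient_descent(xs, ys, lr=0.5, steps=1))
    print("GD, 200 steps:", gradient_descent(xs, ys, lr=0.5, steps=200))
```

On this prompt, a single gradient step lands noticeably away from the least-squares weights, ridge shrinks them toward zero, and many gradient steps with a stable learning rate converge to the least-squares solution. Comparing a trained model's in-context predictions against each of these reference predictors is one way to ask which algorithm it most resembles.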