In-context Learning for Mixture of Linear Regressions: Existence, Generalization and Training Dynamics
- URL: http://arxiv.org/abs/2410.14183v2
- Date: Sun, 09 Feb 2025 03:40:52 GMT
- Title: In-context Learning for Mixture of Linear Regressions: Existence, Generalization and Training Dynamics
- Authors: Yanhao Jin, Krishnakumar Balasubramanian, Lifeng Lai
- Abstract summary: We prove that there exists a transformer capable of achieving a prediction error of order $\mathcal{O}(\sqrt{d/n})$ with high probability. We also analyze the training dynamics of transformers with single linear self-attention layers, demonstrating that, with appropriately initialized parameters, gradient flow optimization over the population mean square loss converges to a global optimum.
- Score: 34.458004744956334
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: We investigate the in-context learning capabilities of transformers for the $d$-dimensional mixture of linear regression model, providing theoretical insights into their existence, generalization bounds, and training dynamics. Specifically, we prove that there exists a transformer capable of achieving a prediction error of order $\mathcal{O}(\sqrt{d/n})$ with high probability, where $n$ represents the training prompt size in the high signal-to-noise ratio (SNR) regime. Moreover, we derive in-context excess risk bounds of order $\mathcal{O}(L/\sqrt{B})$ for the case of two mixtures, where $B$ denotes the number of training prompts, and $L$ represents the number of attention layers. The dependence of $L$ on the SNR is explicitly characterized, differing between low and high SNR settings. We further analyze the training dynamics of transformers with single linear self-attention layers, demonstrating that, with appropriately initialized parameters, gradient flow optimization over the population mean square loss converges to a global optimum. Extensive simulations suggest that transformers perform well on this task, potentially outperforming other baselines, such as the Expectation-Maximization algorithm.
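Below is a minimal, illustrative simulation of the two-component mixture-of-linear-regressions prompt distribution described above, together with a plain EM baseline of the kind the abstract compares against. The transformer itself is omitted; the dimensions, sample sizes, variable names, and the assumption that every example in a prompt shares a single regression vector drawn from the mixture are illustrative choices, not the authors' exact experimental protocol.
```python
# Minimal sketch (not the authors' code): two-component mixture of linear
# regressions, in-context prompts, and a plain EM baseline. Dimensions,
# sample sizes, and the prompt construction are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
d, n, sigma = 8, 64, 0.1            # dimension, prompt length, noise level (high SNR)
betas = rng.normal(size=(2, d))     # the two regression vectors of the mixture

def sample_prompt(component):
    """n labeled examples sharing one regression vector, plus a query point."""
    X = rng.normal(size=(n, d))
    y = X @ betas[component] + sigma * rng.normal(size=n)
    x_query = rng.normal(size=d)
    return X, y, x_query, x_query @ betas[component]   # noiseless query target

def em_fit(X, y, n_iter=50):
    """Plain EM for a two-component mixture of linear regressions (noise level known)."""
    B = rng.normal(size=(2, d))                  # initial component estimates
    pi = np.array([0.5, 0.5])                    # mixing weights
    for _ in range(n_iter):
        # E-step: responsibilities from per-component Gaussian likelihoods
        resid = y[None, :] - B @ X.T             # shape (2, N)
        logp = np.log(pi)[:, None] - 0.5 * (resid / sigma) ** 2
        logp -= logp.max(axis=0, keepdims=True)
        r = np.exp(logp)
        r /= r.sum(axis=0, keepdims=True)
        # M-step: responsibility-weighted least squares for each component
        for k in range(2):
            w = r[k]
            B[k] = np.linalg.solve(X.T @ (w[:, None] * X) + 1e-6 * np.eye(d),
                                   X.T @ (w * y))
        pi = r.mean(axis=1)
    return B, pi

# Fit EM on examples pooled across many training prompts.
X_pool, y_pool = [], []
for _ in range(200):
    Xp, yp, _, _ = sample_prompt(rng.integers(2))
    X_pool.append(Xp)
    y_pool.append(yp)
B_hat, _ = em_fit(np.vstack(X_pool), np.concatenate(y_pool))

# Test: pick the estimated component that best explains the prompt, then predict the query.
errs = []
for _ in range(100):
    X, y, xq, yq = sample_prompt(rng.integers(2))
    k = int(np.argmin(((y[None, :] - B_hat @ X.T) ** 2).sum(axis=1)))
    errs.append((xq @ B_hat[k] - yq) ** 2)
print("EM-baseline mean squared prediction error:", float(np.mean(errs)))
```
In this sketch, a small `sigma` corresponds to the high-SNR regime, where the prompt pins down the active component; this is the regime in which the abstract's $\mathcal{O}(\sqrt{d/n})$ prediction-error guarantee for the transformer applies.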
Related papers
- In-Context Linear Regression Demystified: Training Dynamics and Mechanistic Interpretability of Multi-Head Softmax Attention [52.159541540613915]
We study how multi-head softmax attention models are trained to perform in-context learning on linear data.
Our results reveal that in-context learning ability emerges from the trained transformer as an aggregated effect of its architecture and the underlying data distribution.
arXiv Detail & Related papers (2025-03-17T02:00:49Z) - The Sample Complexity of Online Reinforcement Learning: A Multi-model Perspective [55.15192437680943]
We study the sample complexity of online reinforcement learning for nonlinear dynamical systems with continuous state and action spaces.
Our algorithms are likely to be useful in practice, due to their simplicity, the ability to incorporate prior knowledge, and their benign transient behavior.
arXiv Detail & Related papers (2025-01-27T10:01:28Z) - Training Dynamics of Transformers to Recognize Word Co-occurrence via Gradient Flow Analysis [97.54180451650122]
We study the dynamics of training a shallow transformer on a task of recognizing co-occurrence of two designated words.
We analyze the gradient flow dynamics of simultaneously training three attention matrices and a linear layer.
We prove a novel property of the gradient flow, termed \textit{automatic balancing of gradients}, which enables the loss values of different samples to decrease almost at the same rate and further facilitates the proof of near minimum training loss.
arXiv Detail & Related papers (2024-10-12T17:50:58Z) - Pruning is Optimal for Learning Sparse Features in High-Dimensions [15.967123173054535]
We show that a class of statistical models can be optimally learned using pruned neural networks trained with gradient descent.
We show that pruning neural networks proportional to the sparsity level of $\boldsymbol{V}$ improves their sample complexity compared to unpruned networks.
arXiv Detail & Related papers (2024-06-12T21:43:12Z) - On Mesa-Optimization in Autoregressively Trained Transformers: Emergence and Capability [34.43255978863601]
Several works suggest that transformers learn a mesa-optimizer during autoregressive training.
We show that a stronger assumption related to the moments of the data is a sufficient and necessary condition for the learned mesa-optimizer to perform in-context learning.
arXiv Detail & Related papers (2024-05-27T05:41:06Z) - Adaptive Federated Learning Over the Air [108.62635460744109]
We propose a federated version of adaptive gradient methods, particularly AdaGrad and Adam, within the framework of over-the-air model training.
Our analysis shows that the AdaGrad-based training algorithm converges to a stationary point at the rate of $\mathcal{O}(\ln(T)/T^{1-\frac{1}{\alpha}})$.
arXiv Detail & Related papers (2024-03-11T09:10:37Z) - Improved Algorithm for Adversarial Linear Mixture MDPs with Bandit
Feedback and Unknown Transition [71.33787410075577]
We study reinforcement learning with linear function approximation, unknown transition, and adversarial losses.
We propose a new algorithm that attains an $\widetilde{O}(d\sqrt{HS^3K} + \sqrt{HSAK})$ regret with high probability.
arXiv Detail & Related papers (2024-03-07T15:03:50Z) - Globally Convergent Accelerated Algorithms for Multilinear Sparse
Logistic Regression with $\ell_0$-constraints [2.323238724742687]
Multilinear logistic regression serves as a powerful tool for the analysis of multidimensional data.
We propose an accelerated proximal alternating minimization method, APALM$+$, to solve the $\ell_0$-MLSR model.
We also demonstrate that APALM$+$ converges globally to a first-order critical point and establish its convergence rate using the Kurdyka-Łojasiewicz property.
arXiv Detail & Related papers (2023-09-17T11:05:08Z) - PROMISE: Preconditioned Stochastic Optimization Methods by Incorporating Scalable Curvature Estimates [17.777466668123886]
We introduce PROMISE ($\textbf{Pr}$econditioned $\textbf{O}$ptimization $\textbf{M}$ethods by $\textbf{I}$ncorporating $\textbf{S}$calable Curvature $\textbf{E}$stimates), a suite of sketching-based preconditioned gradient algorithms.
PROMISE includes preconditioned versions of SVRG, SAGA, and Katyusha.
arXiv Detail & Related papers (2023-09-05T07:49:10Z) - Transformers as Support Vector Machines [54.642793677472724]
We establish a formal equivalence between the optimization geometry of self-attention and a hard-margin SVM problem.
We characterize the implicit bias of 1-layer transformers optimized with gradient descent.
We believe these findings inspire the interpretation of transformers as a hierarchy of SVMs that separates and selects optimal tokens.
arXiv Detail & Related papers (2023-08-31T17:57:50Z) - Trained Transformers Learn Linear Models In-Context [39.56636898650966]
Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL).
We show that when transformers are trained over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares; see the sketch after this list.
arXiv Detail & Related papers (2023-06-16T15:50:03Z) - Contextual Combinatorial Bandits with Probabilistically Triggered Arms [55.9237004478033]
We study contextual combinatorial bandits with probabilistically triggered arms (C$^2$MAB-T) under a variety of smoothness conditions.
Under the triggering probability modulated (TPM) condition, we devise the C$^2$-UCB-T algorithm and derive a regret bound $\tilde{O}(d\sqrt{T})$.
arXiv Detail & Related papers (2023-03-30T02:51:00Z) - Approximate Function Evaluation via Multi-Armed Bandits [51.146684847667125]
We study the problem of estimating the value of a known smooth function $f$ at an unknown point $\boldsymbol{\mu} \in \mathbb{R}^n$, where each component $\mu_i$ can be sampled via a noisy oracle.
We design an instance-adaptive algorithm that learns to sample according to the importance of each coordinate, and with probability at least $1-\delta$ returns an $\epsilon$-accurate estimate of $f(\boldsymbol{\mu})$.
arXiv Detail & Related papers (2022-03-18T18:50:52Z) - Single Trajectory Nonparametric Learning of Nonlinear Dynamics [8.438421942654292]
Given a single trajectory of a dynamical system, we analyze the performance of the nonparametric least squares estimator (LSE).
We leverage recently developed information-theoretic methods to establish the optimality of the LSE for nonparametric hypothesis classes.
We specialize our results to a number of scenarios of practical interest, such as Lipschitz dynamics, generalized linear models, and dynamics described by functions in certain classes of Reproducing Kernel Hilbert Spaces (RKHS).
arXiv Detail & Related papers (2022-02-16T19:38:54Z) - Policy Optimization Using Semiparametric Models for Dynamic Pricing [1.3428344011390776]
We study the contextual dynamic pricing problem where the market value of a product is linear in its observed features plus some market noise.
We propose a dynamic statistical learning and decision-making policy that combines semiparametric estimation from a generalized linear model with an unknown link and online decision-making.
arXiv Detail & Related papers (2021-09-13T23:50:01Z) - Variance-Aware Confidence Set: Variance-Dependent Bound for Linear
Bandits and Horizon-Free Bound for Linear Mixture MDP [76.94328400919836]
We show how to construct variance-aware confidence sets for linear bandits and linear mixture Markov decision processes (MDPs).
For linear bandits, we obtain an $\widetilde{O}(\mathrm{poly}(d)\sqrt{1 + \sum_{i=1}^{K}\sigma_i^2})$ regret bound, where $d$ is the feature dimension.
For linear mixture MDPs, we obtain an $\widetilde{O}(\mathrm{poly}(d)\sqrt{K})$ regret bound, where $K$ is the number of episodes.
arXiv Detail & Related papers (2021-01-29T18:57:52Z) - Nearly Minimax Optimal Reinforcement Learning for Linear Mixture Markov
Decision Processes [91.38793800392108]
We study reinforcement learning with linear function approximation where the underlying transition probability kernel of the Markov decision process (MDP) is a linear mixture model.
We propose a new, computationally efficient algorithm with linear function approximation named $\text{UCRL-VTR}^{+}$ for the aforementioned linear mixture MDPs.
To the best of our knowledge, these are the first computationally efficient, nearly minimax optimal algorithms for RL with linear function approximation.
arXiv Detail & Related papers (2020-12-15T18:56:46Z) - Convergence of Sparse Variational Inference in Gaussian Processes
Regression [29.636483122130027]
We show that a method with an overall computational cost of $\mathcal{O}(N(\log N)^{2D}(\log\log N)^2)$ can be used to perform inference.
arXiv Detail & Related papers (2020-08-01T19:23:34Z) - Optimal Robust Linear Regression in Nearly Linear Time [97.11565882347772]
We study the problem of high-dimensional robust linear regression where a learner is given access to $n$ samples from the generative model $Y = \langle X, w^* \rangle + \epsilon$.
We propose estimators for this problem under two settings: (i) $X$ is $L_4$-$L_2$ hypercontractive, $\mathbb{E}[XX^\top]$ has bounded condition number and $\epsilon$ has bounded variance, and (ii) $X$ is sub-Gaussian with identity second moment and $\epsilon$ is sub-Gaussian.
arXiv Detail & Related papers (2020-07-16T06:44:44Z) - On the Global Convergence of Training Deep Linear ResNets [104.76256863926629]
We study the convergence of gradient descent (GD) and stochastic gradient descent (SGD) for training $L$-hidden-layer linear residual networks (ResNets).
We prove that for training deep residual networks with certain linear transformations at input and output layers, both GD and SGD can converge to the global minimum of the training loss.
arXiv Detail & Related papers (2020-03-02T18:34:49Z)
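As a companion to the "Trained Transformers Learn Linear Models In-Context" entry above, here is a toy numeric check of the standard construction from the in-context linear regression literature: a single linear self-attention layer with hand-set weights computes $\hat{y} = x_q^\top \Gamma X^\top y / n$, which for isotropic covariates tracks the ordinary least squares prediction as the prompt length grows. The weight layout, the choice $\Gamma = I$, and the variable names are assumptions made for illustration, not the cited paper's exact parameterization.
```python
# Toy numeric check (an assumption-laden sketch, not the cited paper's exact
# parameterization): a single linear self-attention layer with hand-set weights
# computes yhat = x_q^T Gamma X^T y / n, which for isotropic covariates tracks
# the ordinary-least-squares prediction on the same prompt as n grows.
import numpy as np

rng = np.random.default_rng(1)
d, n = 8, 512
beta = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ beta + 0.1 * rng.normal(size=n)
x_q = rng.normal(size=d)

# Tokens: each in-context example is (x_i, y_i); the query token is (x_q, 0).
E = np.hstack([X, y[:, None]])                 # (n, d+1)
e_q = np.concatenate([x_q, [0.0]])             # (d+1,)

# Hand-constructed weights: W_QK carries Gamma in its top-left block, W_V reads off y.
Gamma = np.eye(d)                              # a suitable preconditioner for isotropic x
W_QK = np.zeros((d + 1, d + 1)); W_QK[:d, :d] = Gamma
W_V = np.zeros((d + 1, d + 1)); W_V[d, d] = 1.0

# Linear self-attention update of the query token; its last coordinate is the prediction.
scores = E @ W_QK.T @ e_q                      # score_i = e_q^T W_QK e_i = x_q^T Gamma x_i
values = E @ W_V.T                             # value_i = (0, ..., 0, y_i)
y_lsa = (e_q + (scores[:, None] * values).sum(axis=0) / n)[d]

y_ols = x_q @ np.linalg.lstsq(X, y, rcond=None)[0]
print(f"LSA prediction {y_lsa:.3f}  vs  OLS prediction {y_ols:.3f}")
```
For large prompts the two predictions agree up to fluctuations of order $\sqrt{d/n}$, which is the same scaling that governs the prediction-error bound in the main paper's high-SNR regime.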
This list is automatically generated from the titles and abstracts of the papers on this site.
This site does not guarantee the quality of its content (including all information) and is not responsible for any consequences of its use.