Provable In-context Learning for Mixture of Linear Regressions using Transformers
- URL: http://arxiv.org/abs/2410.14183v1
- Date: Fri, 18 Oct 2024 05:28:47 GMT
- Title: Provable In-context Learning for Mixture of Linear Regressions using Transformers
- Authors: Yanhao Jin, Krishnakumar Balasubramanian, Lifeng Lai
- Abstract summary: We theoretically investigate the in-context learning capabilities of transformers in the context of learning mixtures of linear regression models.
For the case of two mixtures, we demonstrate the existence of transformers that can achieve an accuracy, relative to the oracle predictor, of order $\mathcal{\tilde{O}}((d/n)^{1/4})$ in the low signal-to-noise ratio (SNR) regime and $\mathcal{\tilde{O}}(\sqrt{d/n})$ in the high SNR regime.
- Score: 34.458004744956334
- Abstract: We theoretically investigate the in-context learning capabilities of transformers in the context of learning mixtures of linear regression models. For the case of two mixtures, we demonstrate the existence of transformers that can achieve an accuracy, relative to the oracle predictor, of order $\mathcal{\tilde{O}}((d/n)^{1/4})$ in the low signal-to-noise ratio (SNR) regime and $\mathcal{\tilde{O}}(\sqrt{d/n})$ in the high SNR regime, where $n$ is the length of the prompt, and $d$ is the dimension of the problem. Additionally, we derive in-context excess risk bounds of order $\mathcal{O}(L/\sqrt{B})$, where $B$ denotes the number of (training) prompts, and $L$ represents the number of attention layers. The order of $L$ depends on whether the SNR is low or high. In the high SNR regime, we extend the results to $K$-component mixture models for finite $K$. Extensive simulations also highlight the advantages of transformers for this task, outperforming other baselines such as the Expectation-Maximization algorithm.
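As a rough illustration of the setting described in the abstract, the sketch below samples one prompt of length $n$ from a two-component mixture of linear regressions and fits it with a textbook Expectation-Maximization iteration, the kind of baseline the abstract mentions. The symmetric $\pm\beta$ parameterization, Gaussian covariates, known noise level `sigma`, and the function names `sample_prompt` and `em_symmetric_mlr` are all assumptions made for illustration; they are not taken from the paper.

```python
import numpy as np


def sample_prompt(beta, n, sigma, rng):
    """Draw n pairs (x_i, y_i) with y_i = s_i * <beta, x_i> + noise, s_i = +/-1.

    Assumed model: two equally likely components beta and -beta.
    """
    d = beta.shape[0]
    X = rng.standard_normal((n, d))
    signs = rng.choice([-1.0, 1.0], size=n)
    y = signs * (X @ beta) + sigma * rng.standard_normal(n)
    return X, y


def em_symmetric_mlr(X, y, sigma, beta_init, n_iters=200):
    """Standard EM for the symmetric two-component mixture (sigma assumed known)."""
    beta = beta_init.copy()
    gram = X.T @ X
    for _ in range(n_iters):
        # E-step: tanh(.) equals 2 * P(component = +beta | x_i, y_i) - 1.
        weights = np.tanh(y * (X @ beta) / sigma**2)
        # M-step: weighted least squares with signed responses.
        beta = np.linalg.solve(gram, X.T @ (weights * y))
    return beta


rng = np.random.default_rng(0)
d, n, sigma = 5, 2000, 0.5  # a high-SNR example: ||beta|| is large relative to sigma
beta_star = np.ones(d)
X, y = sample_prompt(beta_star, n, sigma, rng)
beta_hat = em_symmetric_mlr(X, y, sigma, beta_init=rng.standard_normal(d))

# The mixture is identifiable only up to sign, so compare against both +/- beta_star.
err = min(np.linalg.norm(beta_hat - beta_star), np.linalg.norm(beta_hat + beta_star))
print(f"estimation error (up to sign): {err:.4f}")
```

The error is measured up to a global sign because the two components $\beta$ and $-\beta$ are interchangeable in this assumed model.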
Related papers
- Pruning is Optimal for Learning Sparse Features in High-Dimensions [15.967123173054535]
We show that a class of statistical models can be optimally learned using pruned neural networks trained with gradient descent.
We show that pruning neural networks proportional to the sparsity level of $\boldsymbol{V}$ improves their sample complexity compared to unpruned networks.
arXiv Detail & Related papers (2024-06-12T21:43:12Z) - Improved Algorithm for Adversarial Linear Mixture MDPs with Bandit
Feedback and Unknown Transition [71.33787410075577]
We study reinforcement learning with linear function approximation, unknown transition, and adversarial losses.
We propose a new algorithm that attains an $\widetilde{O}(d\sqrt{HS^3K} + \sqrt{HSAK})$ regret with high probability.
arXiv Detail & Related papers (2024-03-07T15:03:50Z) - Globally Convergent Accelerated Algorithms for Multilinear Sparse
Logistic Regression with $\ell_0$-constraints [2.323238724742687]
Multilinear logistic regression serves as a powerful tool for the analysis of multidimensional data.
We propose an Accelerated Proximal Alternating Linearized Minimization method (APALM$^+$) to solve the $\ell_0$-MLSR model.
We also demonstrate that APALM$^+$ is globally convergent to a first-order critical point and establish convergence using the Kurdyka-Łojasiewicz property.
arXiv Detail & Related papers (2023-09-17T11:05:08Z) - Computationally Efficient Horizon-Free Reinforcement Learning for Linear
Mixture MDPs [111.75736569611159]
We propose the first computationally efficient horizon-free algorithm for linear mixture MDPs.
Our algorithm adapts a weighted least squares estimator for the unknown transition dynamics.
This also improves upon the best-known algorithms in this setting when the $\sigma_k^2$'s are known.
arXiv Detail & Related papers (2022-05-23T17:59:18Z) - Approximate Function Evaluation via Multi-Armed Bandits [51.146684847667125]
We study the problem of estimating the value of a known smooth function $f$ at an unknown point $\boldsymbol{\mu} \in \mathbb{R}^n$, where each component $\mu_i$ can be sampled via a noisy oracle.
We design an instance-adaptive algorithm that learns to sample according to the importance of each coordinate, and with probability at least $1-\delta$ returns an $\epsilon$-accurate estimate of $f(\boldsymbol{\mu})$.
arXiv Detail & Related papers (2022-03-18T18:50:52Z) - Policy Optimization Using Semiparametric Models for Dynamic Pricing [1.3428344011390776]
We study the contextual dynamic pricing problem where the market value of a product is linear in its observed features plus some market noise.
We propose a dynamic statistical learning and decision-making policy that combines semiparametric estimation from a generalized linear model with an unknown link and online decision-making.
arXiv Detail & Related papers (2021-09-13T23:50:01Z) - Variance-Aware Confidence Set: Variance-Dependent Bound for Linear
Bandits and Horizon-Free Bound for Linear Mixture MDP [76.94328400919836]
We show how to construct variance-aware confidence sets for linear bandits and linear mixture Markov Decision Processes (MDPs).
For linear bandits, we obtain an $\widetilde{O}(\mathrm{poly}(d)\sqrt{1 + \sum_{i=1}^{K}\sigma_i^2})$ regret bound, where $d$ is the feature dimension.
For linear mixture MDP, we obtain an $\widetilde{O}(\mathrm{poly}(d)\sqrt{K})$ regret bound, where
arXiv Detail & Related papers (2021-01-29T18:57:52Z) - Nearly Minimax Optimal Reinforcement Learning for Linear Mixture Markov
Decision Processes [91.38793800392108]
We study reinforcement learning with linear function approximation where the underlying transition probability kernel of the Markov decision process (MDP) is a linear mixture model.
We propose a new, computationally efficient algorithm with linear function approximation named $\text{UCRL-VTR}^{+}$ for the aforementioned linear mixture MDPs.
To the best of our knowledge, these are the first computationally efficient, nearly minimax optimal algorithms for RL with linear function approximation.
arXiv Detail & Related papers (2020-12-15T18:56:46Z) - Convergence of Sparse Variational Inference in Gaussian Processes
Regression [29.636483122130027]
We show that a method with an overall computational cost of $\mathcal{O}(N(\log N)^{2D}(\log\log N)^2)$ can be used to perform inference.
arXiv Detail & Related papers (2020-08-01T19:23:34Z) - Optimal Robust Linear Regression in Nearly Linear Time [97.11565882347772]
We study the problem of high-dimensional robust linear regression where a learner is given access to $n$ samples from the generative model $Y = \langle X, w^* \rangle + \epsilon$.
We propose estimators for this problem under two settings: (i) $X$ is L4-L2 hypercontractive, $\mathbb{E}[XX^\top]$ has bounded condition number and $\epsilon$ has bounded variance and (ii) $X$ is sub-Gaussian with identity second moment and $\epsilon$ is
arXiv Detail & Related papers (2020-07-16T06:44:44Z)