Learning to (Learn at Test Time): RNNs with Expressive Hidden States
- URL: http://arxiv.org/abs/2407.04620v2
- Date: Sun, 11 Aug 2024 00:42:18 GMT
- Title: Learning to (Learn at Test Time): RNNs with Expressive Hidden States
- Authors: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
- Abstract summary: We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state.
Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training layers.
- Score: 69.78469963604063
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.
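The key idea in the abstract (the hidden state is a model, and each update is one step of self-supervised learning) can be sketched as a TTT-Linear-style loop in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation. It assumes a reconstruction loss ||W k_t - v_t||^2 on projected views k_t = θ_K x_t and v_t = θ_V x_t, an output z_t = W_t θ_Q x_t, a zero initial state, one plain gradient step per token, and no mini-batching or normalization. The names `ttt_linear`, `theta_k`, `theta_v` and `theta_q` are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def ttt_linear(
    xs: List[Vector],
    theta_k: Matrix,
    theta_v: Matrix,
    theta_q: Matrix,
    lr: float,
) -> Tuple[List[Vector], Matrix]:
    """Hypothetical TTT-Linear-style layer (illustration only).

    The hidden state W is a linear model. For each token it takes one
    gradient step on the self-supervised loss ||W k - v||^2, then produces
    the output W q with the updated state.
    """
    d = len(theta_k)
    w: Matrix = [[0.0] * d for _ in range(d)]  # assumed initial state W_0 = 0
    outputs: List[Vector] = []
    for x in xs:
        k = matvec(theta_k, x)  # training view: input to the inner model
        v = matvec(theta_v, x)  # label view: reconstruction target
        q = matvec(theta_q, x)  # test view: used for the layer output
        err = [p - t for p, t in zip(matvec(w, k), v)]  # residual W k - v
        # grad of ||W k - v||^2 w.r.t. W is 2 * err k^T; one plain GD step
        w = [[w[i][j] - lr * 2.0 * err[i] * k[j] for j in range(d)]
             for i in range(d)]
        outputs.append(matvec(w, q))  # output uses the updated state W_t
    return outputs, w


identity = [[1.0, 0.0], [0.0, 1.0]]
outs, state = ttt_linear([[1.0, 0.0], [1.0, 0.0]], identity, identity, identity, lr=0.5)
print(outs)   # [[1.0, 0.0], [1.0, 0.0]]
print(state)  # [[1.0, 0.0], [0.0, 0.0]]
```

With `lr=0.5` and a unit-length key, a single step fits the first token exactly, so the second, identical token produces a zero residual and leaves the state unchanged. Replacing the inner linear model with a two-layer MLP corresponds to the TTT-MLP instantiation named in the abstract.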
Related papers
- Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling [69.36377985746878]
We study the cause of the inability to process long context for RNNs and suggest critical mitigations.
We first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training.
We train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval.
Date: 2024-10-09T17:54:28Z
- Were RNNs All We Needed? [53.393497486332]
We revisit traditional recurrent neural networks (RNNs) from over a decade ago.
We show that by removing the hidden-state dependencies from their input, forget, and update gates, LSTMs and GRUs no longer need backpropagation through time (BPTT) and can be trained efficiently in parallel.
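Once the gates depend only on the current input, a GRU-style update reduces to a linear recurrence h_t = (1 - z_t) h_{t-1} + z_t h̃_t, and linear recurrences can be evaluated with an associative (prefix) scan instead of a strictly sequential loop. The sketch below is a scalar toy under stated assumptions, not the paper's code: hypothetical weights `wz`, `bz`, `wh`, `bh` stand in for the learned linear layers, and the scan is a Hillis-Steele-style inclusive scan written serially, where each round's combines are independent and could run in parallel.

```python
import math
from typing import List, Tuple

# (a, b) encodes the affine map h -> a * h + b for one time step.
Pair = Tuple[float, float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gates(xs: List[float], wz: float, bz: float, wh: float, bh: float) -> List[Pair]:
    """Input-only gates (hypothetical scalar weights): no dependence on h_{t-1}."""
    pairs = []
    for x in xs:
        z = sigmoid(wz * x + bz)  # update gate from the input alone
        h_tilde = wh * x + bh     # candidate state from the input alone
        pairs.append((1.0 - z, z * h_tilde))  # h_t = (1 - z_t) h_{t-1} + z_t h~_t
    return pairs


def sequential(pairs: List[Pair], h0: float) -> List[float]:
    """Reference: evaluate the recurrence one step at a time."""
    hs, h = [], h0
    for a, b in pairs:
        h = a * h + b
        hs.append(h)
    return hs


def combine(first: Pair, second: Pair) -> Pair:
    """Compose two affine maps (apply `first`, then `second`); associative."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)


def scan(pairs: List[Pair], h0: float) -> List[float]:
    """Hillis-Steele inclusive scan: ceil(log2 n) rounds, each made of
    independent combines that could run in parallel (run serially here)."""
    acc = list(pairs)
    offset = 1
    while offset < len(acc):
        acc = [combine(acc[i - offset], acc[i]) if i >= offset else acc[i]
               for i in range(len(acc))]
        offset *= 2
    return [a * h0 + b for a, b in acc]  # apply each prefix map to h0


pairs = gates([0.5, -1.0, 2.0, 0.0, 1.5], wz=1.0, bz=0.0, wh=2.0, bh=0.1)
seq, par = sequential(pairs, h0=0.0), scan(pairs, h0=0.0)
print(all(abs(s - p) < 1e-9 for s, p in zip(seq, par)))  # True
```

The scan works only because `combine` is associative; a gate that read h_{t-1} would break this, which is why removing that dependency is what enables parallel training.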
Date: 2024-10-02T03:06:49Z
- Attention as an RNN [66.5420926480473]
We show that attention can be viewed as a special Recurrent Neural Network (RNN) with the ability to compute its *many-to-one* RNN output efficiently.
We introduce a new efficient method of computing attention's *many-to-many* RNN output based on the parallel prefix scan algorithm.
We show that Aarens (attention-based modules built on this formulation) achieve performance comparable to Transformers on 38 datasets spread across four popular sequential problem settings.
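The many-to-one view can be illustrated with the standard streaming (online) softmax: one query's attention output over keys k_1..k_N is accumulated token by token with a fixed-size state (running max m, denominator u, numerator vector w), and it matches the direct softmax computation. This is a minimal sketch assuming scaled dot-product scores and a single query; `attention_rnn` and `attention_direct` are hypothetical names, and the paper's parallel prefix-scan method for the many-to-many output is not reproduced here.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def attention_direct(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference softmax attention for one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def attention_rnn(q: Vector, keys: List[Vector], values: List[Vector]) -> List[Vector]:
    """The same attention, computed as an RNN over tokens with state (m, u, w).

    Returns the output after every token (attention over each prefix);
    the last element is the many-to-one output for the full sequence."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    m = -math.inf     # running max of scores (numerical stability)
    u = 0.0           # running sum of exp(score - m)
    w = [0.0] * dim   # running sum of exp(score - m) * value
    outputs = []
    for k, v in zip(keys, values):
        s = dot(q, k) * scale
        m_new = max(m, s)
        old = math.exp(m - m_new)  # rescale the previous state to the new max
        new = math.exp(s - m_new)
        u = u * old + new
        w = [wi * old + new * vi for wi, vi in zip(w, v)]
        m = m_new
        outputs.append([wi / u for wi in w])
    return outputs


q = [1.0, 0.5]
keys = [[0.2, 1.0], [1.0, -0.3], [0.0, 0.7]]
values = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
rnn_out = attention_rnn(q, keys, values)[-1]
ref = attention_direct(q, keys, values)
print(all(abs(a - b) < 1e-12 for a, b in zip(rnn_out, ref)))  # True
```

Because the state (m, u, w) has a fixed size, each new token is absorbed with work independent of how many tokens came before; this fixed-size state is what makes the rewriting RNN-like.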
Date: 2024-05-22T19:45:01Z
- Is Mamba Effective for Time Series Forecasting? [30.85990093479062]
We propose a Mamba-based model named Simple-Mamba (S-Mamba) for time series forecasting.
Specifically, we tokenize the time points of each variate autonomously via a linear layer.
Experiments on thirteen public datasets show that S-Mamba maintains low computational overhead and achieves leading performance.
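One way to read "tokenize the time points of each variate via a linear layer": take the full length-L history of a single variate as a vector and project it to a D-dimensional token with weights shared across variates, so an input of L time steps by N variates becomes N tokens. This reading and the names `tokenize_variates`, `weight` and `bias` are assumptions for illustration, not S-Mamba's code.

```python
from typing import List

Matrix = List[List[float]]


def tokenize_variates(series: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """Hypothetical variate tokenizer: (L time steps x N variates) -> N tokens of size D.

    `weight` has shape (D x L) and is shared by all variates; each variate's
    column of L time points is projected independently of the others."""
    num_steps = len(series)
    num_vars = len(series[0])
    tokens = []
    for n in range(num_vars):
        column = [series[t][n] for t in range(num_steps)]  # one variate's history
        token = [sum(w * x for w, x in zip(row, column)) + b
                 for row, b in zip(weight, bias)]
        tokens.append(token)
    return tokens


series = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # L=3 time steps, N=2 variates
weight = [[1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]       # D=2: sum of history, last value
bias = [0.0, 0.0]
print(tokenize_variates(series, weight, bias))     # [[6.0, 3.0], [60.0, 30.0]]
```

Under this reading, the sequence model that follows sees one token per variate, so its sequence length is N rather than L.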
Date: 2024-03-17T08:50:44Z
- Test-Time Training on Video Streams [54.07009446207442]
Prior work has established test-time training (TTT) as a general framework to further improve a trained model at test time.
We extend TTT to the streaming setting, where multiple test instances arrive in temporal order.
Online TTT significantly outperforms the fixed-model baseline for four tasks, on three real-world datasets.
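What distinguishes the streaming setting can be shown with a toy: each arriving instance triggers one self-supervised gradient step, and online TTT carries the updated parameters forward to the next instance, whereas a per-instance variant resets them every time. Everything here is a hypothetical stand-in (a scalar parameter, the quadratic self-supervised loss (θ - x)^2, and the names `online_ttt`, `quad_grad`, `lr`, `carry`); real TTT uses a neural network and a task-specific self-supervised objective.

```python
from typing import Callable, List


def online_ttt(
    stream: List[float],
    theta0: float,
    grad: Callable[[float, float], float],
    lr: float,
    carry: bool = True,
) -> List[float]:
    """Hypothetical toy: return the parameter used to predict on each instance.

    carry=True: parameters persist across the stream (online TTT).
    carry=False: each instance starts again from theta0 (no carry-over)."""
    theta = theta0
    used = []
    for x in stream:
        start = theta if carry else theta0
        theta = start - lr * grad(start, x)  # one self-supervised step on x
        used.append(theta)                   # predict with the adapted model
    return used


def quad_grad(theta: float, x: float) -> float:
    """Gradient of the toy self-supervised loss (theta - x)^2."""
    return 2.0 * (theta - x)


stream = [1.0, 1.0, 1.0, 1.0]
print(online_ttt(stream, 0.0, quad_grad, lr=0.25, carry=True))   # [0.5, 0.75, 0.875, 0.9375]
print(online_ttt(stream, 0.0, quad_grad, lr=0.25, carry=False))  # [0.5, 0.5, 0.5, 0.5]
```

With carry-over, adaptation accumulates across instances that arrive in temporal order; with resets, every instance gets only one step of adaptation from the original parameters.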
Date: 2023-07-11T05:17:42Z
- SpikeGPT: Generative Pre-trained Language Model with Spiking Neural Networks [21.616328837090396]
Spiking Neural Networks (SNNs) leverage sparse and event-driven activations to reduce the computational overhead associated with model inference.
We implement a generative language model with binary, event-driven spiking activation units.
SpikeGPT is the largest backpropagation-trained SNN model to date, rendering it suitable for both the generation and comprehension of natural language.
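Binary, event-driven activations are commonly produced by a leaky integrate-and-fire (LIF) neuron: a membrane potential leaks, integrates its input, emits a 1 when it crosses a threshold, and is then reset. The sketch below is a generic LIF unit with hypothetical parameters `beta` (decay) and `threshold` and a reset-by-subtraction rule; it is not SpikeGPT's exact neuron, and the surrogate gradient needed to train such units with backpropagation is omitted.

```python
from typing import List, Tuple


def lif(inputs: List[float], beta: float = 0.5,
        threshold: float = 1.0) -> Tuple[List[int], List[float]]:
    """Generic leaky integrate-and-fire over a sequence of input currents.

    Returns binary spikes and the membrane potential after each step."""
    u = 0.0
    spikes, potentials = [], []
    for current in inputs:
        u = beta * u + current          # leak, then integrate the input
        s = 1 if u >= threshold else 0  # binary, event-driven output
        u -= s * threshold              # reset by subtraction after a spike
        spikes.append(s)
        potentials.append(u)
    return spikes, potentials


spikes, _ = lif([0.6, 0.6, 0.6, 0.0, 1.2])
print(spikes)  # [0, 0, 1, 0, 1]
```

Downstream computation only needs to act on the steps where a spike occurs, which is the source of the sparsity that SNNs exploit.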
Date: 2023-02-27T16:43:04Z
- Improving Representational Continuity via Continued Pretraining [76.29171039601948]
A method from the transfer learning community, LP-FT, outperforms naive training and other continual learning methods.
LP-FT also reduces forgetting on a real-world satellite remote sensing dataset (FMoW).
A variant of LP-FT achieves state-of-the-art accuracy on an NLP continual learning benchmark.
Date: 2023-02-26T10:39:38Z
- Large Scale Time-Series Representation Learning via Simultaneous Low and High Frequency Feature Bootstrapping [7.0064929761691745]
We propose a non-contrastive self-supervised learning approach that efficiently captures low- and high-frequency time-varying features.
Our method takes raw time series data as input and creates two different augmented views for two branches of the model.
To demonstrate the robustness of our model, we performed extensive experiments and ablation studies on five real-world time-series datasets.
Date: 2022-04-24T14:39:47Z
This list is automatically generated from the titles and abstracts of the papers on this site.