Latent Recurrent Transformer: Architecture Exploration, Training Strategies, and Scaling Behavior
Abstract Overview
This paper studies the Latent Recurrent Transformer (LRT), a lightweight modification of autoregressive transformers that reuses a high-level hidden state from the previous token as recurrent memory for the next token. The design keeps the standard decoder-only backbone, attention mechanism, and KV-cache interface, and adds a cross-token, cross-layer latent pathway that requires no extra decoding steps in the default inference setup. Because this recurrent dependency would otherwise force token-by-token training, the authors introduce interleaved parallel training: a full-sequence initialization pass is followed by refinement of disjoint token subsets in parallel, using a shared buffer, at roughly 2× baseline training compute. Experiments on nanochat-style 1.3B and 2.1B backbones trained on FineWeb-Edu 100BT show consistent improvements in bits per byte (BPB) and CORE few-shot evaluation under matched effective compute, with the default shared-projection variant adding only about 0.3% parameters.
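To make the training scheme concrete, here is a minimal Python sketch on a toy scalar recurrence. It is not the authors' implementation: `step` is a hypothetical stand-in for a full forward pass that returns a token's source-layer state, and the round-robin subset assignment (position t goes to subset t mod K) and the in-order refinement of subsets are assumptions, since the summary does not say how subsets are formed or scheduled. It illustrates why the cost lands near 2×: the initialization pass touches every position once, and the refinement passes together touch every position once more.

```python
import math
from typing import List, Tuple


def step(x: float, memory: float) -> float:
    """Hypothetical stand-in for one forward pass: returns the token's
    'source-layer' state given its input and the previous token's memory."""
    return math.tanh(x + 0.5 * memory)


def exact_recurrence(xs: List[float]) -> List[float]:
    """Strict token-by-token recurrence, as used at inference time."""
    states: List[float] = []
    memory = 0.0  # no memory before the first token
    for x in xs:
        memory = step(x, memory)
        states.append(memory)
    return states


def interleaved_parallel(xs: List[float], num_subsets: int) -> Tuple[List[float], int]:
    """Initialization pass with zero memory, then refinement of round-robin
    subsets in order (both choices are assumptions of this sketch).
    Returns the final shared buffer and the number of step() evaluations."""
    if num_subsets < 1:
        raise ValueError("num_subsets must be at least 1")
    calls = 0
    # Initialization pass: every position sees zero memory, so it is fully parallel.
    buffer: List[float] = []
    for x in xs:
        buffer.append(step(x, 0.0))
        calls += 1
    # Refinement: subset k holds positions t with t % num_subsets == k.
    for k in range(num_subsets):
        # Read every memory from the buffer first, then write, to mimic one parallel pass.
        updates = {}
        for t in range(k, len(xs), num_subsets):
            memory = buffer[t - 1] if t > 0 else 0.0
            updates[t] = step(xs[t], memory)
            calls += 1
        for t, h in updates.items():
            buffer[t] = h
    return buffer, calls


if __name__ == "__main__":
    xs = [0.3, -1.2, 0.8, 0.1, -0.4, 0.9]
    approx, calls = interleaved_parallel(xs, num_subsets=2)
    print(calls)  # 12 step evaluations = 2 x 6 positions
```

With K equal to the sequence length, each subset holds one position and the schedule collapses into the exact sequential recurrence; smaller K keeps more parallelism but lets some positions read memory that has not yet been refined.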
Novelty
The paper's main novelty is a recurrent memory mechanism that reuses a source-layer hidden state already computed for the previous token, rather than adding pause tokens or extra depth recurrence, or changing the KV-cache format. It also introduces interleaved parallel training, a practical approximation to token-level recurrence that preserves much of the training parallelism of a standard transformer.
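The inference-side property (one ordinary forward pass per token, with a state that is computed anyway being carried to the next token) can be sketched with a toy layer stack. Everything concrete here is an assumption: `toy_block`, the square projection `proj`, injecting the memory into the residual stream at one lower layer (`inject_layer`), and the choice of `source_layer` are hypothetical; attention and the KV cache are omitted, and the KV Projection pathway mentioned in the ablations is not modeled.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product on plain lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def toy_block(h: Vector, scale: float) -> Vector:
    """Hypothetical stand-in for one transformer block (attention omitted)."""
    return [v + math.tanh(scale * v) for v in h]


def forward_token(x: Vector, memory: Vector, proj: Matrix, num_layers: int,
                  inject_layer: int, source_layer: int) -> Tuple[Vector, Vector]:
    """One forward pass for one token; returns (output state, source-layer state)."""
    if not (0 <= inject_layer < num_layers and 0 <= source_layer < num_layers):
        raise ValueError("inject_layer and source_layer must index existing layers")
    h = list(x)
    source_state: Vector = []
    for layer in range(num_layers):
        if layer == inject_layer:
            # Assumed residual injection: add the projected memory of the previous token.
            injected = matvec(proj, memory)
            h = [a + b for a, b in zip(h, injected)]
        h = toy_block(h, scale=0.1 * (layer + 1))
        if layer == source_layer:
            # This state is computed anyway; it is kept as the next token's memory.
            source_state = list(h)
    return h, source_state


def run(xs: List[Vector], proj: Matrix, num_layers: int = 6,
        inject_layer: int = 1, source_layer: int = 4) -> List[Vector]:
    """Process tokens left to right with exactly one pass through the stack each."""
    memory: Vector = [0.0] * len(xs[0])  # no memory before the first token
    outputs: List[Vector] = []
    for x in xs:
        out, memory = forward_token(x, memory, proj, num_layers, inject_layer, source_layer)
        outputs.append(out)
    return outputs
```

Setting `proj` to all zeros recovers the plain stack, which is a convenient sanity check that the memory pathway is purely additive in this toy; the first token is unaffected by any `proj`, since its memory is zero.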
Results
Across both the 20-layer and 24-layer backbones, LRT shifts the scaling curves toward lower BPB and higher CORE at comparable baseline-equivalent compute. For example, on the 24L model at compute 80, BPB improves from 0.699 for the baseline to 0.695 for LRT-shared and 0.693 for LRT-layerwise; on the 20L model at the same compute, CORE rises from 0.271 to 0.274 (LRT-shared) and 0.277 (LRT-layerwise). Ablations further show that combining KV Projection with Residual Injection works best, that an upper-middle source layer is the strongest choice, and that chunked training is weaker than interleaved parallel training.
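The gains quoted above are small in absolute terms, so expressing them relative to the baseline makes their size easier to judge. The short script below only reworks the numbers reported in this paragraph; `relative_change` is a hypothetical helper name.

```python
def relative_change(baseline: float, variant: float) -> float:
    """Signed change of variant vs. baseline, as a percentage of baseline."""
    return 100.0 * (variant - baseline) / baseline


# Reported values at compute 80.
bpb_24l = {"baseline": 0.699, "LRT-shared": 0.695, "LRT-layerwise": 0.693}
core_20l = {"baseline": 0.271, "LRT-shared": 0.274, "LRT-layerwise": 0.277}

for name, table in (("24L BPB (lower is better)", bpb_24l),
                    ("20L CORE (higher is better)", core_20l)):
    base = table["baseline"]
    for variant in ("LRT-shared", "LRT-layerwise"):
        delta = table[variant] - base
        pct = relative_change(base, table[variant])
        print(f"{name:28s} {variant:14s} delta={delta:+.3f} ({pct:+.2f}%)")
```

This gives BPB reductions of roughly 0.57% (LRT-shared) and 0.86% (LRT-layerwise) on the 24L model, and CORE increases of roughly 1.1% and 2.2% on the 20L model.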
Key Points
- LRT reuses the previous token's high-level source-layer hidden state as recurrent memory, creating a cross-token latent pathway while keeping one normal forward pass per generated token.
- The proposed interleaved parallel training scheme approximates token-level recurrence with a full initialization pass plus parallel subset refinements, giving recurrent-memory-aware supervision at about 2× baseline compute.
- Empirically, LRT improves language modeling loss and in-context evaluation across scaling experiments, and its default shared variant achieves most of the gain with only about 0.3% additional parameters.
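Two back-of-envelope consequences of the overhead figures can be computed under loudly stated assumptions: that the 0.3% parameter overhead applies to the nominal 1.3B and 2.1B backbone sizes (the summary does not say which backbone it was measured on), and that "matched effective compute" charges LRT its full ~2× per-token training cost. `extra_params`, `tokens_at_matched_compute`, and the 10B-token baseline budget in the example are hypothetical.

```python
def extra_params(backbone_params: float, overhead_fraction: float = 0.003) -> float:
    """Extra parameters implied by a fractional overhead (assumed ~0.3%)."""
    return backbone_params * overhead_fraction


def tokens_at_matched_compute(baseline_tokens: float, cost_ratio: float = 2.0) -> float:
    """Tokens an LRT run can afford if each token costs cost_ratio x the baseline
    (assumes compute is matched by charging the full training overhead)."""
    return baseline_tokens / cost_ratio


for name, n in (("1.3B", 1.3e9), ("2.1B", 2.1e9)):
    print(f"{name}: ~{extra_params(n) / 1e6:.1f}M extra parameters")

# Hypothetical baseline budget of 10B tokens.
print(f"LRT tokens at matched compute: {tokens_at_matched_compute(10e9):.2e}")
```

Under these assumptions the shared projection costs only a few million parameters, and the reported gains would come despite the LRT runs seeing roughly half as many training tokens as the compute-matched baseline.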