Latent Recurrent Transformer: Architecture Exploration, Training Strategies, and Scaling Behavior
Abstractの概要
本論文では、自己回帰型Transformerの軽量な改良版であるLatent Recurrent Transformer (LRT) について研究する。これは、前のトークンの高レベルの隠れ状態を次のトークンのためのリカレントメモリとして再利用するものである。この設計は、標準的なデコーダのみのバックボーン、アテンション機構、KVキャッシュインターフェースを維持しつつ、デフォルトの推論設定において追加のデコードステップなしで、トークン間およびレイヤー間の潜在的な経路を追加する。このリカレントな依存関係を大規模に学習できるようにするため、著者らはインターリーブ並列学習(interleaved parallel training)を導入している。これは、全シーケンスの初期化パスを使用し、その後共有バッファを用いて互いに素なトークンサブセットを並列に改良するものであり、計算量はベースラインの学習の約2倍となる。FineWeb-Edu 100BTで学習されたnanochatスタイルの13億パラメータおよび21億パラメータのバックボーンでの実験では、一致した実効計算量の下で、bits per byte (BPB) およびCORE数ショット評価における一貫した改善が示され、デフォルトの共有プロジェクション変種ではパラメータの増加は約0.3%に留まった。
新規性
本論文の主な新規性は、ポーズトークンの追加、追加の実装階層の再帰、あるいはKVキャッシュフォーマットの変更を行うのではなく、すでに計算された前のトークンのソースレイヤーの隠れ状態を再利用するリカレントメモリ機構である。また、大規模な学習の並列性を維持しながら、トークンレベルの再帰を実用的に近似するアプローチとして、インターリーブ並列学習を同時に考案・導入した点も画期的である。
成果
20層および24層の両バックボーンにおける実験において、LRTは、ベースラインと同等の計算量でスケーリング曲線をより低いBPBとより高いCOREの方向へシフトさせる。例えば、計算単位80での24層モデルでは、ベースラインのBPB 0.699からLRT(共有型)の0.695、LRT(レイヤー毎)の0.693へと改善し、同計算量の20層モデルではCORE評価が0.271から0.274、0.277へと向上する。アブレーション調査により、KVプロジェクションと残差注入の組み合わせが最も機能すること、中上段のソースレイヤーを用いるのが最も強力であること、そしてチャンク学習はインターリーブ並列学習よりも性能が劣ることが追加で示されている。
論文の注目点
- LRTは前のトークンの高レベルなソースレイヤーの隠れ状態をリカレントメモリとして再利用し、生成されるトークンごとに通常のフォワードパスを1回維持しつつ、トークン間の潜在的経路を作成する。
- 提案されたインターリーブ並列学習スキームは、全体の初期化パスと並列のサブセットの段階的改良によってトークンレベルの再帰を近似し、ベースラインの約2倍の計算量でリカレントメモリを考慮した学習を実現する。
- 実験的評価として、LRTは各種の層スケーリングにおいて言語モデリング損失とコンテキスト内評価を改善し、デフォルトの共有バリアントはわずか約0.3%の追加パラメータで改善効果の大部分を達成する。