Training Chain-of-Thought via Latent-Variable Inference
- URL: http://arxiv.org/abs/2312.02179v1
- Date: Tue, 28 Nov 2023 17:47:32 GMT
- Title: Training Chain-of-Thought via Latent-Variable Inference
- Authors: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
- Abstract summary: Large language models (LLMs) solve problems more accurately and interpretably when instructed to work out the answer step by step using a "chain-of-thought" (CoT) prompt.
Naively combining CoT with supervised tuning requires supervision not just of the correct answers, but also of detailed rationales that lead to those answers.
We propose a fine-tuning strategy that tries to maximize the marginal log-likelihood of generating a correct answer using CoT prompting.
- Score: 30.21067593018967
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Large language models (LLMs) solve problems more accurately and interpretably
when instructed to work out the answer step by step using a
"chain-of-thought" (CoT) prompt. One can also improve LLMs' performance on a
specific task by supervised fine-tuning, i.e., by using gradient ascent on some
tunable parameters to maximize the average log-likelihood of correct answers
from a labeled training set. Naively combining CoT with supervised tuning
requires supervision not just of the correct answers, but also of detailed
rationales that lead to those answers; these rationales are expensive to
produce by hand. Instead, we propose a fine-tuning strategy that tries to
maximize the marginal log-likelihood of generating a correct answer
using CoT prompting, approximately averaging over all possible rationales. The
core challenge is sampling from the posterior over rationales conditioned on
the correct answer; we address it using a simple Markov-chain Monte Carlo
(MCMC) expectation-maximization (EM) algorithm inspired by the self-taught
reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent
contrastive divergence. This algorithm also admits a novel control-variate
technique that drives the variance of our gradient estimates to zero as the
model improves. Applying our technique to GSM8K and the tasks in BIG-Bench
Hard, we find that this MCMC-EM fine-tuning technique typically improves the
model's accuracy on held-out examples more than STaR or prompt-tuning with or
without CoT.
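As a rough sketch of the objective described in the abstract, write x for the question, z for a rationale, y for the correct answer, and θ for the tunable parameters (this notation is assumed here for illustration and does not come from the listing). The marginal log-likelihood sums over all rationales, and by a standard identity its gradient is an expectation under the posterior over rationales, which is why sampling from that posterior is the core challenge:

```latex
\log p_\theta(y \mid x) = \log \sum_{z} p_\theta(z \mid x)\, p_\theta(y \mid z, x),
\qquad
\nabla_\theta \log p_\theta(y \mid x)
  = \mathbb{E}_{p_\theta(z \mid x, y)}\!\left[ \nabla_\theta \log p_\theta(z, y \mid x) \right].
```

An MCMC-EM scheme of the kind the abstract mentions would approximate this expectation with rationales drawn from a Markov chain that targets the posterior p_θ(z | x, y), then take a gradient step on θ.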
Related papers
- Improving LLM Reasoning through Scaling Inference Computation with Collaborative Verification [52.095460362197336]
Large language models (LLMs) struggle with consistent and accurate reasoning.
LLMs are trained primarily on correct solutions, reducing their ability to detect and learn from errors.
We propose a novel collaborative method integrating Chain-of-Thought (CoT) and Program-of-Thought (PoT) solutions for verification.
arXiv Detail & Related papers (2024-10-05T05:21:48Z) - Training Nonlinear Transformers for Chain-of-Thought Inference: A Theoretical Generalization Analysis [82.51626700527837]
Chain-of-Thought (CoT) is an efficient method that enables the reasoning ability of large language models by augmenting the query using examples with multiple intermediate steps.
We show that in-context learning without intermediate reasoning steps may fail to provide an accurate output in cases where CoT succeeds.
arXiv Detail & Related papers (2024-10-03T03:12:51Z) - Recursive Introspection: Teaching Language Model Agents How to Self-Improve [30.086494067593268]
We develop RISE: Recursive IntroSpEction, an approach for fine-tuning large language models.
Our experiments show that RISE enables Llama2, Llama3, and Mistral models to improve themselves with more turns on math reasoning tasks.
arXiv Detail & Related papers (2024-07-25T17:35:59Z) - CoRMF: Criticality-Ordered Recurrent Mean Field Ising Solver [4.364088891019632]
We propose an RNN-based efficient Ising model solver, the Criticality-ordered Recurrent Mean Field (CoRMF).
By leveraging the approximated tree structure of the underlying Ising graph, the newly-obtained criticality order enables the unification between variational mean-field and RNN.
CoRMF solves the Ising problems in a self-training fashion without data/evidence, and the inference tasks can be executed by directly sampling from the RNN.
arXiv Detail & Related papers (2024-03-05T16:55:06Z) - ReFT: Reasoning with Reinforced Fine-Tuning [9.80361828538909]
We propose a simple yet effective approach called Reinforced Fine-Tuning (ReFT) to enhance the generalizability of learning LLMs for reasoning.
ReFT first warms up the model with SFT, and then employs on-line reinforcement learning, specifically the PPO algorithm in this paper.
Experiments on GSM8K, MathQA, and SVAMP datasets show that ReFT significantly outperforms SFT.
arXiv Detail & Related papers (2024-01-17T04:43:21Z) - GRACE: Discriminator-Guided Chain-of-Thought Reasoning [75.35436025709049]
We propose Guiding chain-of-thought ReAsoning with a CorrectnEss Discriminator (GRACE) to steer the decoding process towards producing correct reasoning steps.
GRACE employs a discriminator trained with a contrastive loss over correct and incorrect steps, which is used during decoding to score next-step candidates.
arXiv Detail & Related papers (2023-05-24T09:16:51Z) - Policy Gradient With Serial Markov Chain Reasoning [10.152838128195468]
We introduce a new framework that performs decision-making in reinforcement learning as an iterative reasoning process.
We show our framework has several useful properties that are inherently missing from traditional RL.
Our resulting algorithm achieves state-of-the-art performance in popular Mujoco and DeepMind Control benchmarks.
arXiv Detail & Related papers (2022-10-13T06:15:29Z) - Navigating to the Best Policy in Markov Decision Processes [68.8204255655161]
We investigate the active pure exploration problem in Markov Decision Processes.
The agent sequentially selects actions and, from the resulting system trajectory, aims at identifying the best policy as fast as possible.
arXiv Detail & Related papers (2021-06-05T09:16:28Z) - Gaussian MRF Covariance Modeling for Efficient Black-Box Adversarial Attacks [86.88061841975482]
We study the problem of generating adversarial examples in a black-box setting, where we only have access to a zeroth order oracle.
We use this setting to find fast one-step adversarial attacks, akin to a black-box version of the Fast Gradient Sign Method (FGSM).
We show that the method uses fewer queries and achieves higher attack success rates than the current state of the art.
arXiv Detail & Related papers (2020-10-08T18:36:51Z) - Adaptive Sampling for Best Policy Identification in Markov Decision Processes [79.4957965474334]
We investigate the problem of best-policy identification in discounted Markov Decision Processes (MDPs) when the learner has access to a generative model.
The advantages of state-of-the-art algorithms are discussed and illustrated.
arXiv Detail & Related papers (2020-09-28T15:22:24Z)