Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization
- URL: http://arxiv.org/abs/2502.06061v1
- Date: Sun, 09 Feb 2025 22:45:15 GMT
- Title: Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization
- Authors: Jiajun Fan, Shuaike Shen, Chaoran Cheng, Yuxin Chen, Chumeng Liang, Ge Liu
- Abstract summary: We propose an easy-to-use and theoretically sound fine-tuning method for flow-based generative models.
By introducing an online reward-weighting mechanism, our approach guides the model to prioritize high-reward regions in the data manifold.
Our method achieves optimal policy convergence while allowing controllable trade-offs between reward and diversity.
- Score: 14.320131946691268
- Abstract: Recent advancements in reinforcement learning (RL) have achieved great success in fine-tuning diffusion-based generative models. However, fine-tuning continuous flow-based generative models to align with arbitrary user-defined reward functions remains challenging, particularly due to issues such as policy collapse from overoptimization and the prohibitively high computational cost of likelihoods in continuous-time flows. In this paper, we propose an easy-to-use and theoretically sound RL fine-tuning method, which we term Online Reward-Weighted Conditional Flow Matching with Wasserstein-2 Regularization (ORW-CFM-W2). Our method integrates RL into the flow matching framework to fine-tune generative models with arbitrary reward functions, without relying on gradients of rewards or filtered datasets. By introducing an online reward-weighting mechanism, our approach guides the model to prioritize high-reward regions in the data manifold. To prevent policy collapse and maintain diversity, we incorporate Wasserstein-2 (W2) distance regularization into our method and derive a tractable upper bound for it in flow matching, effectively balancing exploration and exploitation of policy optimization. We provide theoretical analyses to demonstrate the convergence properties and induced data distributions of our method, establishing connections with traditional RL algorithms featuring Kullback-Leibler (KL) regularization and offering a more comprehensive understanding of the underlying mechanisms and learning behavior of our approach. Extensive experiments on tasks including target image generation, image compression, and text-image alignment demonstrate the effectiveness of our method, where our method achieves optimal policy convergence while allowing controllable trade-offs between reward maximization and diversity preservation.
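As a rough illustration of the objective the abstract describes, the sketch below combines a reward-weighted conditional flow matching loss with a velocity-difference penalty. It is a minimal 1-D toy under stated assumptions, not the authors' implementation. It assumes three things. First, the weights are a batch softmax of beta * r(x1) over samples drawn online from the current model. Second, the probability path is the linear interpolation x_t = (1 - t) x0 + t x1. Third, the W2 term reduces, up to constants, to the squared gap between the fine-tuned field v_theta and a frozen reference field v_ref along that path; the paper's exact bound may differ. The names orw_cfm_w2_loss, softmax_weights, beta and lam are hypothetical. Rewards enter only as scalar weights, so no reward gradient is needed.

```python
import math
import random
from typing import Callable, List

# v(x, t): a velocity field over 1-D toy data (hypothetical stand-in for a network)
Velocity = Callable[[float, float], float]


def softmax_weights(rewards: List[float], beta: float) -> List[float]:
    """Batch-normalized exp(beta * r) weights (assumed weighting scheme)."""
    m = max(rewards)  # shift by the max reward for numerical stability
    exps = [math.exp(beta * (r - m)) for r in rewards]
    z = sum(exps)
    return [e / z for e in exps]


def orw_cfm_w2_loss(
    v_theta: Velocity,
    v_ref: Velocity,
    x1_samples: List[float],
    reward: Callable[[float], float],
    beta: float,
    lam: float,
    rng: random.Random,
) -> float:
    """Reward-weighted CFM loss plus an assumed W2-style velocity penalty."""
    # Rewards are only evaluated, never differentiated.
    weights = softmax_weights([reward(x1) for x1 in x1_samples], beta)
    n = len(x1_samples)
    fm_term = 0.0
    w2_term = 0.0
    for w, x1 in zip(weights, x1_samples):
        x0 = rng.gauss(0.0, 1.0)          # source (noise) sample
        t = rng.random()                   # t ~ U[0, 1)
        xt = (1.0 - t) * x0 + t * x1       # assumed linear probability path
        target = x1 - x0                   # conditional target velocity on that path
        pred = v_theta(xt, t)
        fm_term += w * (pred - target) ** 2          # weights sum to 1: weighted mean
        w2_term += (pred - v_ref(xt, t)) ** 2 / n    # plain mean of squared field gap
    return fm_term + lam * w2_term


# Usage with toy stand-ins (all hypothetical):
rng = random.Random(0)
v_ref = lambda x, t: 0.0            # frozen "pre-trained" field
v_theta = lambda x, t: 0.5 * x      # field being fine-tuned
online_samples = [rng.gauss(0.0, 1.0) for _ in range(64)]  # would come from sampling v_theta
loss = orw_cfm_w2_loss(v_theta, v_ref, online_samples,
                       reward=lambda x: -abs(x - 1.0), beta=2.0, lam=0.1, rng=rng)
```

Raising beta concentrates weight on high-reward samples (more exploitation). Raising lam keeps v_theta closer to v_ref (more diversity retained). Together they mirror the reward/diversity trade-off the abstract describes.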
Related papers
- CDSA: Conservative Denoising Score-based Algorithm for Offline Reinforcement Learning [25.071018803326254]
Distribution shift is a major obstacle in offline reinforcement learning.
Previous conservative offline RL algorithms struggle to generalize to unseen actions.
We propose to use the gradient fields of the dataset density generated from a pre-trained offline RL algorithm to adjust the original actions.
Date: 2024-06-11T17:59:29Z
- Diffusion-based Reinforcement Learning via Q-weighted Variational Policy Optimization [55.97310586039358]
Diffusion models have garnered widespread attention in Reinforcement Learning (RL) for their powerful expressiveness and multimodality.
We propose a novel model-free diffusion-based online RL algorithm, Q-weighted Variational Policy Optimization (QVPO).
Specifically, we introduce the Q-weighted variational loss, which can be proved to be a tight lower bound of the policy objective in online RL under certain conditions.
We also develop an efficient behavior policy to enhance sample efficiency by reducing the variance of the diffusion policy during online interactions.
Date: 2024-05-25T10:45:46Z
- Offline Policy Optimization in RL with Variance Regularizaton [142.87345258222942]
We propose variance regularization for offline RL algorithms, using stationary distribution corrections.
We show that by using Fenchel duality, we can avoid double sampling issues for computing the gradient of the variance regularizer.
The proposed algorithm for offline variance regularization (OVAR) can be used to augment any existing offline policy optimization algorithms.
Date: 2022-12-29T18:25:01Z
- Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning [70.20191211010847]
Offline reinforcement learning (RL) aims to learn an optimal policy using a previously collected static dataset.
We introduce Diffusion Q-learning (Diffusion-QL) that utilizes a conditional diffusion model to represent the policy.
We show that our method can achieve state-of-the-art performance on the majority of the D4RL benchmark tasks.
Date: 2022-08-12T09:54:11Z
- Model-Free Learning of Optimal Deterministic Resource Allocations in Wireless Systems via Action-Space Exploration [4.721069729610892]
We propose a technically grounded and scalable deterministic-dual gradient policy method for efficiently learning optimal parameterized resource allocation policies.
Our method not only efficiently exploits gradient availability of popular universal representations such as deep networks, but is also truly model-free, as it relies on consistent zeroth-order gradient approximations of associated random network services constructed via low-dimensional perturbations in action space.
Date: 2021-08-23T18:26:16Z
- COMBO: Conservative Offline Model-Based Policy Optimization [120.55713363569845]
Uncertainty estimation with complex models, such as deep neural networks, can be difficult and unreliable.
We develop a new model-based offline RL algorithm, COMBO, that regularizes the value function on out-of-support state-actions.
We find that COMBO consistently performs as well as or better than prior offline model-free and model-based methods.
Date: 2021-02-16T18:50:32Z
- Pareto Deterministic Policy Gradients and Its Application in 5G Massive MIMO Networks [32.099949375036495]
We consider jointly optimizing cell load balance and network throughput via a reinforcement learning (RL) approach.
Our rationale behind using RL is to circumvent the challenges of analytically modeling user mobility and network dynamics.
To accomplish this joint optimization, we integrate vector rewards into the RL value network and conduct RL action via a separate policy network.
Date: 2020-12-02T15:35:35Z
- Policy Gradient Methods for the Noisy Linear Quadratic Regulator over a Finite Horizon [3.867363075280544]
We explore reinforcement learning methods for finding the optimal policy in the linear quadratic regulator (LQR) problem.
We produce a global linear convergence guarantee for the setting of finite time horizon and state dynamics under weak assumptions.
We show results for the case where we assume a model for the underlying dynamics and where we apply the method to the data directly.
Date: 2020-11-20T09:51:49Z
- MOPO: Model-based Offline Policy Optimization [183.6449600580806]
Offline reinforcement learning (RL) refers to the problem of learning policies entirely from a large batch of previously collected data.
We show that an existing model-based RL algorithm already produces significant gains in the offline setting.
We propose to modify the existing model-based RL methods by applying them with rewards artificially penalized by the uncertainty of the dynamics.
Date: 2020-05-27T08:46:41Z
- Mixed Reinforcement Learning with Additive Stochastic Uncertainty [19.229447330293546]
Reinforcement learning (RL) methods often rely on massive amounts of exploration data to search for optimal policies and suffer from poor sample efficiency.
This paper presents a mixed RL algorithm by simultaneously using dual representations of environmental dynamics to search the optimal policy.
The effectiveness of the mixed RL is demonstrated by a typical optimal control problem of non-affine nonlinear systems.
Date: 2020-02-28T08:02:34Z
- Nested-Wasserstein Self-Imitation Learning for Sequence Generation [158.19606942252284]
We propose the concept of nested-Wasserstein distance for distributional semantic matching.
A novel nested-Wasserstein self-imitation learning framework is developed, encouraging the model to exploit historical high-reward sequences.
Date: 2020-01-20T02:19:13Z
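The MOPO entry above describes training on rewards penalized by the uncertainty of learned dynamics, i.e. r~(s, a) = r(s, a) - lam * u(s, a). The sketch below is minimal and assumption-laden. It estimates u as the disagreement of an ensemble of learned dynamics models, taken as the Euclidean norm of the per-dimension standard deviation of their next-state predictions. MOPO's own uncertainty estimator may differ. The names penalized_reward and lam are hypothetical.

```python
import statistics
from typing import Sequence


def penalized_reward(reward: float,
                     ensemble_next_states: Sequence[Sequence[float]],
                     lam: float) -> float:
    """Return r - lam * u, where u is an assumed ensemble-disagreement uncertainty."""
    dims = len(ensemble_next_states[0])
    # Population std of each state dimension across ensemble members.
    stds = [statistics.pstdev([pred[d] for pred in ensemble_next_states])
            for d in range(dims)]
    u = sum(s * s for s in stds) ** 0.5   # Euclidean norm of the std vector
    return reward - lam * u


# Two ensemble members that disagree on the first dimension of a 2-D next state:
r_tilde = penalized_reward(1.0, [[0.0, 1.0], [2.0, 1.0]], lam=0.5)  # 0.5
```

A larger lam makes the learned policy more conservative in state-action regions where the models disagree, while agreeing models leave the reward unchanged.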