Rethinking Sharpness-Aware Minimization as Variational Inference
- URL: http://arxiv.org/abs/2210.10452v1
- Date: Wed, 19 Oct 2022 10:35:54 GMT
- Title: Rethinking Sharpness-Aware Minimization as Variational Inference
- Authors: Szilvia Ujváry, Zsigmond Telek, Anna Kerekes, Anna Mészáros, Ferenc Huszár
- Abstract summary: Sharpness-aware minimization (SAM) aims to improve the generalisation of gradient-based learning by seeking out flat minima.
We establish connections between SAM and Mean-Field Variational Inference (MFVI) of neural network parameters.
- Score: 1.749935196721634
- License: http://arxiv.org/licenses/nonexclusive-distrib/1.0/
- Abstract: Sharpness-aware minimization (SAM) aims to improve the generalisation of
gradient-based learning by seeking out flat minima. In this work, we establish
connections between SAM and Mean-Field Variational Inference (MFVI) of neural
network parameters. We show that both these methods have interpretations as
optimizing notions of flatness, and when using the reparametrisation trick,
they both boil down to calculating the gradient at a perturbed version of the
current mean parameter. This thinking motivates our study of algorithms that
combine or interpolate between SAM and MFVI. We evaluate the proposed
variational algorithms on several benchmark datasets, and compare their
performance to variants of SAM. Taking a broader perspective, our work suggests
that SAM-like updates can be used as a drop-in replacement for the
reparametrisation trick.
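To make the connection concrete, here is a minimal NumPy sketch (not the authors' code) of the two perturbed-gradient updates on a toy quadratic loss: SAM evaluates the gradient at the parameters shifted along the normalized ascent direction with radius rho, while the MFVI mean update under the reparametrisation trick evaluates it at the parameters shifted by Gaussian noise scaled by the variational standard deviation sigma. The loss, rho, and sigma are illustrative assumptions, and the KL term of the variational objective is omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

def loss_grad(w):
    """Gradient of a toy quadratic loss L(w) = 0.5 * w^T A w (illustrative only)."""
    A = np.diag([10.0, 0.1])          # anisotropic curvature: one sharp, one flat direction
    return A @ w

def sam_step(w, lr=0.1, rho=0.05):
    """SAM: evaluate the gradient at w perturbed along the normalized ascent direction."""
    g = loss_grad(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # worst-case perturbation of radius rho
    g_perturbed = loss_grad(w + eps)              # second gradient evaluation
    return w - lr * g_perturbed

def mfvi_step(w, lr=0.1, sigma=0.05):
    """MFVI mean update via the reparametrisation trick: evaluate the gradient at a
    Gaussian-perturbed version of the current mean parameter (KL term omitted)."""
    eps = sigma * rng.standard_normal(w.shape)    # reparametrised sample: w + sigma * noise
    g_perturbed = loss_grad(w + eps)
    return w - lr * g_perturbed

w_sam = w_vi = np.array([1.0, 1.0])
for _ in range(100):
    w_sam = sam_step(w_sam)
    w_vi = mfvi_step(w_vi)
print("SAM :", w_sam)
print("MFVI:", w_vi)
```

Swapping the Gaussian draw in `mfvi_step` for the ascent direction used in `sam_step` is the kind of drop-in replacement of the reparametrisation trick that the abstract refers to.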
Related papers
- Enhancing Sharpness-Aware Minimization by Learning Perturbation Radius [6.78775404181577]
We propose a bilevel optimization framework called LEarning the perTurbation radiuS (LETS) to learn the perturbation radius for sharpness-aware minimization algorithms.
Experimental results on various architectures and benchmark datasets in computer vision and natural language processing demonstrate the effectiveness of the proposed LETS method.
arXiv Detail & Related papers (2024-08-15T15:40:57Z)
- A Universal Class of Sharpness-Aware Minimization Algorithms [57.29207151446387]
We introduce a new class of sharpness measures, leading to new sharpness-aware objective functions.
We prove that these measures are universally expressive, allowing any function of the training loss Hessian matrix to be represented by choosing appropriate hyperparameters.
arXiv Detail & Related papers (2024-06-06T01:52:09Z)
- Effective Gradient Sample Size via Variation Estimation for Accelerating Sharpness-aware Minimization [19.469113881229646]
Sharpness-aware Minimization (SAM) has been proposed recently to improve model generalization ability.
SAM calculates the gradient twice in each optimization step, thereby doubling the computation costs.
We propose a simple yet efficient sampling method to significantly accelerate SAM (the two gradient evaluations of a generic SAM step are sketched below).
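For context, the sketch below is a generic PyTorch illustration of where the two gradient evaluations occur in a single SAM step, together with a naive random-skip shortcut that falls back to a plain step. It is not the paper's variation-estimation scheme; the skip probability and hyperparameters are assumptions for illustration.

```python
import torch

def sam_or_plain_step(model, loss_fn, x, y, base_opt, rho=0.05, skip_prob=0.5):
    """One optimization step. With probability `skip_prob` (an illustrative stand-in for
    the paper's criterion) take a plain gradient step; otherwise take a full SAM step,
    which needs two forward/backward passes."""
    if float(torch.rand(())) < skip_prob:
        base_opt.zero_grad()
        loss_fn(model(x), y).backward()
        base_opt.step()                               # single gradient evaluation
        return

    # First pass: gradient at the current weights, used to build the perturbation.
    base_opt.zero_grad()
    loss_fn(model(x), y).backward()
    grads = [p.grad.detach().clone() if p.grad is not None else None
             for p in model.parameters()]
    norm = torch.sqrt(sum((g ** 2).sum() for g in grads if g is not None))
    with torch.no_grad():
        for p, g in zip(model.parameters(), grads):
            if g is not None:
                p.add_(rho * g / (norm + 1e-12))      # ascend to the perturbed point

    # Second pass: gradient at the perturbed weights; then restore and descend.
    base_opt.zero_grad()
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        for p, g in zip(model.parameters(), grads):
            if g is not None:
                p.sub_(rho * g / (norm + 1e-12))      # undo the perturbation
    base_opt.step()                                   # step with the SAM gradient
```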
arXiv Detail & Related papers (2024-02-24T05:48:05Z)
- Systematic Investigation of Sparse Perturbed Sharpness-Aware Minimization Optimizer [158.2634766682187]
Deep neural networks often suffer from poor generalization due to complex and non-convex loss landscapes.
Sharpness-Aware Minimization (SAM) is a popular solution that smooths the loss landscape by minimizing the maximized change of training loss when adding a perturbation to the weights.
In this paper, we propose Sparse SAM (SSAM), an efficient and effective training scheme that achieves sparse perturbation by a binary mask; a sketch of masked perturbation follows below.
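Below is a minimal sketch of the general idea of sparse perturbation via a binary mask, intended to replace the dense ascend/restore loops of the SAM step sketched earlier. The mask here is drawn at random purely for illustration, whereas SSAM selects and updates it, so this is not the paper's method.

```python
import torch

@torch.no_grad()
def sparse_perturb(params, grads, rho=0.05, sparsity=0.5):
    """Apply the SAM ascent step only to a masked subset of each weight tensor.
    The binary mask is random here (illustration only); SSAM selects it instead."""
    norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    eps_list = []
    for p, g in zip(params, grads):
        mask = (torch.rand_like(p) < sparsity).float()  # illustrative random binary mask
        eps = mask * rho * g / (norm + 1e-12)           # no perturbation where mask == 0
        p.add_(eps)
        eps_list.append(eps)
    return eps_list                                     # kept so the ascent can be undone

@torch.no_grad()
def undo_perturb(params, eps_list):
    """Restore the original weights before the descent step."""
    for p, eps in zip(params, eps_list):
        p.sub_(eps)
```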
arXiv Detail & Related papers (2023-06-30T09:33:41Z)
- Normalization Layers Are All That Sharpness-Aware Minimization Needs [53.799769473526275]
Sharpness-aware minimization (SAM) was proposed to reduce sharpness of minima.
We show that perturbing only the affine normalization parameters (typically comprising 0.1% of the total parameters) in the adversarial step of SAM can outperform perturbing all of the parameters; a sketch of such a restricted adversarial step follows below.
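A minimal sketch of restricting the adversarial perturbation to the affine parameters of normalization layers, selected here by module type; the selection heuristic and radius are illustrative assumptions rather than the paper's exact recipe.

```python
import torch
import torch.nn as nn

NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)

def norm_affine_parameters(model):
    """Collect only the affine (weight/bias) parameters of normalization layers."""
    for module in model.modules():
        if isinstance(module, NORM_TYPES):
            for p in (module.weight, module.bias):
                if p is not None:
                    yield p

@torch.no_grad()
def perturb_norm_layers_only(model, rho=0.05):
    """Adversarial (ascent) step of SAM applied only to normalization parameters;
    assumes .grad has already been populated by a backward pass."""
    params = [p for p in norm_affine_parameters(model) if p.grad is not None]
    norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
    eps_list = []
    for p in params:
        eps = rho * p.grad / (norm + 1e-12)
        p.add_(eps)
        eps_list.append((p, eps))
    return eps_list   # later: p.sub_(eps) for each pair to restore, then the descent step
```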
arXiv Detail & Related papers (2023-06-07T08:05:46Z)
- AdaSAM: Boosting Sharpness-Aware Minimization with Adaptive Learning Rate and Momentum for Training Deep Neural Networks [76.90477930208982]
Sharpness-aware minimization (SAM) has been extensively explored as it can generalize better for training deep neural networks.
Integrating SAM with an adaptive learning rate and momentum acceleration, dubbed AdaSAM, has already been explored.
We conduct several experiments on NLP tasks, which show that AdaSAM could achieve superior performance compared with SGD, AMSGrad, and SAM optimizers.
arXiv Detail & Related papers (2023-03-01T15:12:42Z)
- mSAM: Micro-Batch-Averaged Sharpness-Aware Minimization [20.560184120992094]
The Sharpness-Aware Minimization (SAM) technique modifies the underlying loss function to steer gradient descent methods toward flatter minima.
We extend a recently developed and well-studied general framework for flatness analysis to theoretically show that SAM achieves flatter minima than SGD, and mSAM achieves even flatter minima than SAM.
arXiv Detail & Related papers (2023-02-19T23:27:12Z)
- Improved Deep Neural Network Generalization Using m-Sharpness-Aware Minimization [14.40189851070842]
Sharpness-Aware Minimization (SAM) modifies the underlying loss function to guide descent methods towards flatter minima.
Recent work suggests that mSAM can outperform SAM in terms of test accuracy.
This paper presents a comprehensive empirical evaluation of mSAM on various tasks and datasets; a minimal sketch of micro-batch-averaged SAM follows below.
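Assuming, as the name suggests, that mSAM computes the SAM ascent step independently on m disjoint micro-batches of each batch and averages the resulting descent gradients, a minimal sketch looks as follows; the chunking scheme and hyperparameters are illustrative assumptions, not taken from the paper.

```python
import torch

def msam_gradient(model, loss_fn, x, y, rho=0.05, m=4):
    """Sketch of micro-batch-averaged SAM: split the batch into m micro-batches,
    run the SAM ascent/descent gradient computation on each, and average the
    resulting gradients. Assumes every parameter receives a gradient."""
    params = [p for p in model.parameters() if p.requires_grad]
    avg_grads = [torch.zeros_like(p) for p in params]

    for xb, yb in zip(x.chunk(m), y.chunk(m)):
        # Ascent step: gradient on this micro-batch at the current weights.
        model.zero_grad()
        loss_fn(model(xb), yb).backward()
        grads = [p.grad.detach().clone() for p in params]
        norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
        with torch.no_grad():
            for p, g in zip(params, grads):
                p.add_(rho * g / (norm + 1e-12))

        # Descent gradient at the micro-batch-specific perturbed weights.
        model.zero_grad()
        loss_fn(model(xb), yb).backward()
        with torch.no_grad():
            for p, g, acc in zip(params, grads, avg_grads):
                acc.add_(p.grad / m)                  # accumulate the averaged SAM gradient
                p.sub_(rho * g / (norm + 1e-12))      # restore the original weights
    return avg_grads                                  # apply with any base optimizer
```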
arXiv Detail & Related papers (2022-12-07T00:37:55Z)
- Efficient Sharpness-aware Minimization for Improved Training of Neural Networks [146.2011175973769]
This paper proposes the Efficient Sharpness Aware Minimizer (ESAM), which boosts SAM's efficiency at no cost to its generalization performance.
ESAM includes two novel and efficient training strategies: Stochastic Weight Perturbation and Sharpness-Sensitive Data Selection.
We show, via extensive experiments on the CIFAR and ImageNet datasets, that ESAM enhances the efficiency over SAM from requiring 100% extra computations to 40% vis-a-vis base optimizers.
arXiv Detail & Related papers (2021-10-07T02:20:37Z)
- Recursive Inference for Variational Autoencoders [34.552283758419506]
Inference networks of traditional Variational Autoencoders (VAEs) are typically amortized.
Recent semi-amortized approaches were proposed to address this drawback.
We introduce an accurate amortized inference algorithm.
arXiv Detail & Related papers (2020-11-17T10:22:12Z)
- Extrapolation for Large-batch Training in Deep Learning [72.61259487233214]
We show that a host of variations can be covered in a unified framework that we propose.
We prove the convergence of this novel scheme and rigorously evaluate its empirical performance on ResNet, LSTM, and Transformer.
arXiv Detail & Related papers (2020-06-10T08:22:41Z)