Feature learning as alignment: a structural property of gradient descent in non-linear neural networks
- URL: http://arxiv.org/abs/2402.05271v4
- Date: Sun, 17 Nov 2024 22:18:40 GMT
- Title: Feature learning as alignment: a structural property of gradient descent in non-linear neural networks
- Authors: Daniel Beaglehole, Ioannis Mitliagkas, Atish Agarwala
- Abstract summary: We show that the neural feature matrices (NFM) and the average gradient outer products (AGOP) become correlated during training, a statement known as the neural feature ansatz (NFA).
We establish that the alignment is driven by the interaction of weight changes induced by SGD with the pre-activation features.
We prove that the derivative alignment occurs almost surely in specific high-dimensional settings.
- Score: 13.032185349152492
- Abstract: Understanding the mechanisms through which neural networks extract statistics from input-label pairs through feature learning is one of the most important unsolved problems in supervised learning. Prior works demonstrated that the Gram matrices of the weights (the neural feature matrices, NFM) and the average gradient outer products (AGOP) become correlated during training, a statement known as the neural feature ansatz (NFA). Through the NFA, the authors introduce mapping with the AGOP as a general mechanism for neural feature learning. However, these works do not provide a theoretical explanation for this correlation or its origins. In this work, we further clarify the nature of this correlation and explain its emergence. We show that this correlation is equivalent to alignment between the left singular structure of the weight matrices and the newly defined pre-activation tangent features at each layer. We further establish that the alignment is driven by the interaction of weight changes induced by SGD with the pre-activation features, and we analyze the resulting dynamics analytically at early times in terms of simple statistics of the inputs and labels. We prove that the derivative alignment occurs almost surely in specific high-dimensional settings. Finally, we introduce a simple optimization rule motivated by our analysis of the centered correlation, which dramatically increases the NFA correlations at any given layer and improves the quality of the learned features.
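As a concrete illustration of the quantities in the abstract, the sketch below (PyTorch; not the authors' code) builds a toy two-layer ReLU network and computes the first-layer neural feature matrix W^T W, the AGOP of the network with respect to its inputs, and the cosine similarity between the two flattened matrices, which is the kind of NFA correlation the paper tracks. The architecture, sizes, and normalizations are placeholder assumptions, and some formulations of the NFA compare the NFM to a matrix power of the AGOP rather than to the AGOP itself.

```python
# Illustrative sketch (not the authors' code): estimating the first-layer NFA
# correlation for a toy ReLU network with random weights. All sizes, the
# architecture, and the normalizations are placeholder assumptions.
import torch

torch.manual_seed(0)
d_in, d_hidden, n = 20, 64, 256

# Toy two-layer network: f(x) = relu(x W1^T) W2^T (scalar output per sample).
W1 = torch.randn(d_hidden, d_in) / d_in ** 0.5
W2 = torch.randn(1, d_hidden) / d_hidden ** 0.5

def f(x):
    return torch.relu(x @ W1.T) @ W2.T

# Neural feature matrix (NFM) of the first layer: Gram matrix of its weights.
nfm = W1.T @ W1                                    # (d_in, d_in)

# AGOP: average outer product of the input gradients df/dx over the data.
X = torch.randn(n, d_in, requires_grad=True)
grads, = torch.autograd.grad(f(X).sum(), X)        # (n, d_in) per-sample gradients
agop = grads.T @ grads / n                         # (d_in, d_in)

# NFA correlation: cosine similarity between the flattened matrices.
def cosine(a, b):
    a, b = a.flatten(), b.flatten()
    return (a @ b / (a.norm() * b.norm())).item()

print("first-layer NFA correlation:", cosine(nfm, agop))
```

Tracking this correlation across SGD steps, rather than only at initialization as here, is what would exhibit the rising alignment that the paper analyzes.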
Related papers
- Minimum-Norm Interpolation Under Covariate Shift [14.863831433459902]
In-distribution research on high-dimensional linear regression has led to the identification of a phenomenon known as benign overfitting.
We prove the first non-asymptotic excess risk bounds for benignly-overfit linear interpolators in the transfer learning setting.
arXiv Detail & Related papers (2024-03-31T01:41:57Z) - Weak Correlations as the Underlying Principle for Linearization of Gradient-Based Learning Systems [1.0878040851638]
This paper delves into gradient descent-based learning algorithms that display a linear structure in their parameter dynamics.
We establish that this apparent linearity arises from weak correlations between the first and higher-order derivatives of the hypothesis function.
Exploiting the relationship between linearity and weak correlations, we derive a bound on deviations from linearity observed during the training trajectory of gradient descent.
arXiv Detail & Related papers (2024-01-08T16:44:23Z) - Decomposing neural networks as mappings of correlation functions [57.52754806616669]
We study the mapping between probability distributions implemented by a deep feed-forward network.
We identify essential statistics in the data, as well as different information representations that can be used by neural networks.
arXiv Detail & Related papers (2022-02-10T09:30:31Z) - Data-driven emergence of convolutional structure in neural networks [83.4920717252233]
We show how fully-connected neural networks solving a discrimination task can learn a convolutional structure directly from their inputs.
By carefully designing data models, we show that the emergence of this pattern is triggered by the non-Gaussian, higher-order local structure of the inputs.
arXiv Detail & Related papers (2022-02-01T17:11:13Z) - Modeling Implicit Bias with Fuzzy Cognitive Maps [0.0]
This paper presents a Fuzzy Cognitive Map model to quantify implicit bias in structured datasets.
We introduce a new reasoning mechanism equipped with a normalization-like transfer function that prevents neurons from saturating.
arXiv Detail & Related papers (2021-12-23T17:04:12Z) - Gradient Starvation: A Learning Proclivity in Neural Networks [97.02382916372594]
Gradient Starvation arises when cross-entropy loss is minimized by capturing only a subset of features relevant for the task.
This work provides a theoretical explanation for the emergence of such feature imbalance in neural networks.
arXiv Detail & Related papers (2020-11-18T18:52:08Z) - Connecting Weighted Automata, Tensor Networks and Recurrent Neural Networks through Spectral Learning [58.14930566993063]
We present connections between three models used in different research fields: weighted finite automata (WFA) from formal languages and linguistics, recurrent neural networks used in machine learning, and tensor networks.
We introduce the first provable learning algorithm for linear 2-RNNs defined over sequences of continuous input vectors.
arXiv Detail & Related papers (2020-10-19T15:28:00Z) - Provably Efficient Neural Estimation of Structural Equation Model: An Adversarial Approach [144.21892195917758]
We study estimation in a class of generalized structural equation models (SEMs).
We formulate the linear operator equation as a min-max game, where both players are parameterized by neural networks (NNs), and learn the parameters of these neural networks using gradient descent.
For the first time we provide a tractable estimation procedure for SEMs based on NNs with provable convergence and without the need for sample splitting.
arXiv Detail & Related papers (2020-07-02T17:55:47Z) - Revisiting Initialization of Neural Networks [72.24615341588846]
We propose a rigorous estimation of the global curvature of weights across layers by approximating and controlling the norm of their Hessian matrix.
Our experiments on Word2Vec and the MNIST/CIFAR image classification tasks confirm that tracking the Hessian norm is a useful diagnostic tool.
arXiv Detail & Related papers (2020-04-20T18:12:56Z) - Hierarchical Gaussian Process Priors for Bayesian Neural Network Weights [16.538973310830414]
A desirable class of priors would represent weights compactly, capture correlations between weights, and allow inclusion of prior knowledge.
This paper introduces two innovations: (i) a Gaussian process-based hierarchical model for network weights based on unit embeddings that can flexibly encode correlated weight structures, and (ii) input-dependent versions of these weight priors that can provide convenient ways to regularize the function space.
We show that these models provide desirable test-time uncertainty estimates on out-of-distribution data, demonstrate cases of modeling inductive biases for neural networks with kernels, and demonstrate competitive predictive performance on an active learning benchmark.
arXiv Detail & Related papers (2020-02-10T07:19:52Z)