Automatic Functional Differentiation in JAX
- URL: http://arxiv.org/abs/2311.18727v2
- Date: Sun, 28 Jan 2024 07:16:55 GMT
- Title: Automatic Functional Differentiation in JAX
- Authors: Min Lin
- Abstract summary: We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators)
By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions.
- Score: 8.536145202129827
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: We extend JAX with the capability to automatically differentiate higher-order
functions (functionals and operators). By representing functions as a
generalization of arrays, we seamlessly use JAX's existing primitive system to
implement higher-order functions. We present a set of primitive operators that
serve as foundational building blocks for constructing several key types of
functionals. For every introduced primitive operator, we derive and implement
both linearization and transposition rules, aligning with JAX's internal
protocols for forward and reverse mode automatic differentiation. This
enhancement allows for functional differentiation in the same syntax
traditionally used for functions. The resulting functional gradients are
themselves functions ready to be invoked in Python. We showcase this tool's
efficacy and simplicity through applications where functional derivatives are
indispensable. The source code of this work is released at
https://github.com/sail-sg/autofd .
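To make the notion of a functional gradient concrete, the sketch below discretizes a toy functional F[f] = ∫ f(x)² dx on a grid and recovers its functional derivative δF/δf(x) = 2 f(x) numerically. This is an illustrative stand-alone example, not the autofd API; the grid size, test function, and helper names are all assumptions.

```python
import math

# Toy functional: F[f] = ∫_0^1 f(x)^2 dx, discretized on a midpoint grid.
# Its functional derivative is δF/δf(x) = 2 f(x); we recover it here with
# central finite differences on the discretized F. (Illustration only --
# this is not the autofd API.)

N = 100
dx = 1.0 / N
xs = [(i + 0.5) * dx for i in range(N)]          # midpoint grid
f = [math.sin(math.pi * x) for x in xs]          # sampled test function

def F(v):
    """Discretized functional: F[f] ≈ Σ_i f(x_i)^2 Δx."""
    return sum(vi * vi for vi in v) * dx

def functional_gradient(F, v, eps=1e-6):
    """δF/δf at each grid point: perturb one sample, divide by Δx."""
    grad = []
    for i in range(len(v)):
        up = v[:]; up[i] += eps
        dn = v[:]; dn[i] -= eps
        grad.append((F(up) - F(dn)) / (2 * eps * dx))
    return grad

g = functional_gradient(F, f)
# Analytically, g[i] should approximate 2 * f[i] at every grid point.
```

The paper's contribution is that, with functions represented as generalized arrays, this kind of functional gradient is produced by the same `grad`-style syntax JAX already uses for ordinary functions, rather than by manual discretization as above.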
Related papers
- Expanded Gating Ranges Improve Activation Functions [0.0]
We find that Expanded ArcTan Linear Unit (xATLU), Expanded GELU (xGELU), and Expanded SiLU (xSiLU) outperform existing activation functions within a transformer architecture.
We also show that expanded gating ranges show promising results in improving first-order Gated Linear Units (GLU).
arXiv Detail & Related papers (2024-05-25T09:12:17Z)
- JaxDecompiler: Redefining Gradient-Informed Software Design [0.0]
This article introduces JaxDecompiler, a tool that transforms any JAX function into editable Python code.
arXiv Detail & Related papers (2024-03-14T20:32:31Z)
- gFaaS: Enabling Generic Functions in Serverless Computing [0.1433758865948252]
gFaaS is a novel framework that facilitates holistic development and management of functions across diverse FaaS platforms.
Results from our experiments demonstrate that gFaaS functions perform similarly to native platform-specific functions across various scenarios.
arXiv Detail & Related papers (2024-01-18T20:25:20Z)
- Functional Diffusion [55.251174506648454]
We propose a new class of generative diffusion models, called functional diffusion.
Functional diffusion can be seen as an extension of classical diffusion models to an infinite-dimensional domain.
We show generative results on complicated signed distance functions and deformation functions defined on 3D surfaces.
arXiv Detail & Related papers (2023-11-26T21:35:34Z)
- Equivariance with Learned Canonicalization Functions [77.32483958400282]
We show that learning a small neural network to perform canonicalization is better than using predefined ones.
Our experiments show that learning the canonicalization function is competitive with existing techniques for learning equivariant functions across many tasks.
arXiv Detail & Related papers (2022-11-11T21:58:15Z)
- Provable General Function Class Representation Learning in Multitask Bandits and MDPs [58.624124220900306]
Multitask representation learning is a popular approach in reinforcement learning to boost sample efficiency.
In this work, we extend the analysis to general function class representations.
We theoretically validate the benefit of multitask representation learning within general function class for bandits and linear MDP.
arXiv Detail & Related papers (2022-05-31T11:36:42Z)
- Equinox: neural networks in JAX via callable PyTrees and filtered transformations [4.264192013842096]
JAX and PyTorch are two popular Python autodifferentiation frameworks.
JAX is based around pure functions and functional programming.
PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions.
arXiv Detail & Related papers (2021-10-30T14:08:56Z)
- Differentiable Spline Approximations [48.10988598845873]
Differentiable programming has significantly enhanced the scope of machine learning.
Standard differentiable programming methods (such as autodiff) typically require that the machine learning models be differentiable.
We show that leveraging this redesigned Jacobian in the form of a differentiable "layer" in predictive models leads to improved performance in diverse applications.
arXiv Detail & Related papers (2021-10-04T16:04:46Z)
- On Correctness of Automatic Differentiation for Non-Differentiable Functions [14.222887950206658]
We show that autodiff systems are not correct in any formal sense when applied to non-differentiable functions.
We propose a new type of derivatives, called intensional derivatives, and prove that these derivatives always exist and coincide with standard derivatives for almost all inputs.
In this way, we formally establish the correctness of autodiff systems applied to non-differentiable functions.
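The subtlety at non-differentiable points can be seen in a minimal forward-mode autodiff sketch: at the kink of ReLU, what the system returns is determined by which branch of the program executes, yet it agrees with the standard derivative at almost all inputs. This is an illustrative sketch, not tied to any particular framework; the `Dual` and `deriv` names are assumptions.

```python
# Minimal forward-mode AD via dual numbers, showing what an autodiff
# system returns at a non-differentiable point. (Illustrative sketch,
# not any specific framework's implementation.)

class Dual:
    """Carries a value and its derivative (tangent) together."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)
    __rmul__ = __mul__

def relu(x):
    # The program branches on x.val; at the kink x = 0 the `else`
    # branch fires, so AD reports derivative 0 even though the
    # standard derivative does not exist there.
    return x if x.val > 0 else Dual(0.0, 0.0)

def deriv(f, x0):
    """Derivative of f at x0 by seeding a unit tangent."""
    return f(Dual(x0, 1.0)).dot
```

Here `deriv(relu, 0.0)` returns `0.0`, one valid intensional derivative at the kink, while for every nonzero input the result coincides with the standard derivative, matching the almost-everywhere agreement the paper proves.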
arXiv Detail & Related papers (2020-06-12T01:57:13Z)
- Automatic Differentiation in ROOT [62.997667081978825]
In mathematics and computer algebra, automatic differentiation (AD) is a set of techniques to evaluate the derivative of a function specified by a computer program.
This paper presents AD techniques available in ROOT, supported by Cling, to produce derivatives of arbitrary C/C++ functions.
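The core AD idea of differentiating "a function specified by a computer program" can be sketched with a tiny reverse-mode tape: record elementary operations during the forward pass, then propagate adjoints backward. This is a Python sketch of the general technique, not the ROOT/Cling implementation; the `Var` class and `backward` function are assumptions.

```python
# Tiny reverse-mode AD tape: record elementary ops forward, then
# propagate adjoints backward. (Python sketch of the general idea,
# not the ROOT implementation.)

class Var:
    def __init__(self, val, parents=()):
        # parents: list of (parent_var, local_partial_derivative)
        self.val, self.parents, self.grad = val, parents, 0.0
    def __add__(self, other):
        return Var(self.val + other.val, [(self, 1.0), (other, 1.0)])
    def __mul__(self, other):
        return Var(self.val * other.val,
                   [(self, other.val), (other, self.val)])

def backward(out):
    """Accumulate d(out)/d(node) into node.grad for every node.
    For brevity this assumes each intermediate feeds one consumer;
    a full implementation processes nodes in topological order."""
    out.grad = 1.0
    stack = [out]
    while stack:
        node = stack.pop()
        for parent, local in node.parents:
            parent.grad += node.grad * local
            stack.append(parent)

x, y = Var(3.0), Var(2.0)
z = x * y + x * x          # z = xy + x^2
backward(z)
# Analytically: dz/dx = y + 2x = 8, dz/dy = x = 3.
```

The same record-then-replay structure underlies reverse-mode AD in compiled settings; a source-to-source tool generates equivalent adjoint code at compile time instead of taping at runtime.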
arXiv Detail & Related papers (2020-04-09T09:18:50Z)
- Invariant Feature Coding using Tensor Product Representation [75.62232699377877]
We prove that the group-invariant feature vector contains sufficient discriminative information when learning a linear classifier.
A novel feature model that explicitly considers group actions is proposed for principal component analysis and k-means clustering.
arXiv Detail & Related papers (2019-06-05T07:15:17Z)
This list is automatically generated from the titles and abstracts of the papers in this site.
This site does not guarantee the quality of the listed information and is not responsible for any consequences arising from its use.