Equinox: neural networks in JAX via callable PyTrees and filtered
transformations
- URL: http://arxiv.org/abs/2111.00254v1
- Date: Sat, 30 Oct 2021 14:08:56 GMT
- Title: Equinox: neural networks in JAX via callable PyTrees and filtered
transformations
- Authors: Patrick Kidger and Cristian Garcia
- License: http://creativecommons.org/licenses/by/4.0/
- Abstract: JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is
based around pure functions and functional programming. PyTorch has popularised
the use of an object-oriented (OO) class-based syntax for defining
parameterised functions, such as neural networks. Because this seems like a
fundamental difference, current libraries for building parameterised
functions in JAX have either rejected the OO approach entirely (Stax) or have
introduced OO-to-functional transformations, multiple new abstractions, and
been limited in the extent to which they integrate with JAX (Flax, Haiku,
Objax). Either way, this OO/functional difference has been a source of tension.
Here, we introduce "Equinox", a small neural network library showing how a
PyTorch-like class-based approach may be admitted without sacrificing JAX-like
functional programming. We provide two main ideas. One: parameterised functions
are themselves represented as "PyTrees", which means that the parameterisation
of a function is transparent to the JAX framework. Two: we filter a PyTree to
isolate just those components that should be treated when transforming ("jit",
"grad" or "vmap"-ing) a higher-order function of a parameterised function,
such as a loss function applied to a model. Overall, Equinox resolves the above
tension without introducing any new programmatic abstractions: only PyTrees and
transformations, just as with regular JAX. Equinox is available at
https://github.com/patrick-kidger/equinox.
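The two ideas in the abstract can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not Equinox's implementation or API: `Linear`, `is_inexact_array`, `partition`, `combine`, `filter_value_and_grad` and `mse_loss` are hypothetical names written from scratch here, and the sketch assumes a JAX version that provides `jax.Array` (0.4.1 or later). A model class is registered as a PyTree (idea one), and a wrapper splits it into differentiable arrays and everything else before calling `jax.value_and_grad` (idea two).

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map

# Hypothetical helpers written for this sketch; none are imported from Equinox.


# Idea one (sketch): a parameterised function that is itself a PyTree.
@register_pytree_node_class
class Linear:
    def __init__(self, weight, bias, activation):
        self.weight = weight          # array leaf (trainable)
        self.bias = bias              # array leaf (trainable)
        self.activation = activation  # non-array leaf (a Python callable)

    def __call__(self, x):
        return self.activation(self.weight @ x + self.bias)

    def tree_flatten(self):
        # Every field is a child, so JAX sees the whole parameterisation.
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Idea two (sketch): filter a PyTree so a transformation only sees some leaves.
def is_inexact_array(x):
    # Floating-point (or complex) JAX arrays are the leaves we differentiate.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def partition(tree, keep):
    # Two trees with the same structure; filtered-out leaves become None.
    kept = tree_map(lambda x: x if keep(x) else None, tree)
    rest = tree_map(lambda x: None if keep(x) else x, tree)
    return kept, rest


def combine(kept, rest):
    # Inverse of `partition`: take each leaf from whichever tree is not None.
    return tree_map(lambda a, b: b if a is None else a, kept, rest,
                    is_leaf=lambda x: x is None)


def filter_value_and_grad(loss_fn):
    # Differentiate `loss_fn` only with respect to the model's inexact arrays;
    # everything else (here, the activation function) is held fixed.
    def wrapper(model, *args):
        diff, static = partition(model, is_inexact_array)

        def loss_on_diff(diff_part):
            return loss_fn(combine(diff_part, static), *args)

        return jax.value_and_grad(loss_on_diff)(diff)
    return wrapper


def mse_loss(model, x, y):
    # The model is an ordinary callable, so it can be vmapped over a batch.
    pred = jax.vmap(model.__call__)(x)
    return jnp.mean((pred - y) ** 2)


key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
model = Linear(jax.random.normal(key_w, (2, 3)),
               jax.random.normal(key_b, (2,)),
               jnp.tanh)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

loss, grads = filter_value_and_grad(mse_loss)(model, x, y)
# `grads` is again a Linear: array gradients for weight/bias, None for activation.
```

Passing `model` straight to `jax.grad` would be expected to fail here, because the `jnp.tanh` leaf is not a valid JAX type; partitioning first is what lets a class-based model coexist with JAX's function transformations in this sketch.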
Related papers
- JAXbind: Bind any function to JAX (2024-03-13)
  JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives.
  JAXbind allows a user to interface the JAX function engine with custom derivatives and rules, enabling all JAX transformations for the custom primitive.
- TopoX: A Suite of Python Packages for Machine Learning on Topological Domains (2024-02-04)
  TopoX is a Python software suite that provides reliable and user-friendly building blocks for computing and machine learning on topological domains.
  TopoX consists of three packages: TopoNetX, TopoEmbedX and TopoModelx.
- Automatic Functional Differentiation in JAX (2023-11-30)
  We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators).
  By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions.
- Hyperparameter Tuning Cookbook: A guide for scikit-learn, PyTorch, river, and spotPython (2023-07-17)
  This document provides a guide to hyperparameter tuning using spotPython for scikit-learn, PyTorch, and river.
  With a hands-on approach and step-by-step explanations, this cookbook serves as a practical starting point.
- PyTorch Hyperparameter Tuning - A Tutorial for spotPython (2023-05-19)
  This tutorial includes a brief comparison with Ray Tune, a Python library for running experiments and tuning hyperparameters.
  We show that spotPython achieves similar or even better results while being more flexible and transparent than Ray Tune.
- SequeL: A Continual Learning Library in PyTorch and JAX (2023-04-21)
  SequeL is a library for Continual Learning that supports both PyTorch and JAX frameworks.
  It provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based approaches, replay-based approaches, and hybrid approaches.
  We release SequeL as an open-source library, enabling researchers and developers to easily experiment and extend the library for their own purposes.
- Equivariance with Learned Canonicalization Functions (2022-11-11)
  We show that learning a small neural network to perform canonicalization is better than using predefined ones.
  Our experiments show that learning the canonicalization function is competitive with existing techniques for learning equivariant functions across many tasks.
- PyHopper -- Hyperparameter optimization (2022-10-10)
  We present PyHopper, a black-box optimization platform for machine learning researchers.
  PyHopper's goal is to integrate with existing code with minimal effort and run the optimization process with minimal necessary manual oversight.
  With simplicity as the primary theme, PyHopper is powered by a single robust Markov-chain Monte-Carlo optimization algorithm.
- JAXFit: Trust Region Method for Nonlinear Least-Squares Curve Fitting on the GPU (2022-08-25)
  We implement a trust region method on the GPU for nonlinear least squares curve fitting problems using a new deep learning Python library called JAX.
  Our open source package, JAXFit, works for both unconstrained and constrained curve fitting problems.
- PyHHMM: A Python Library for Heterogeneous Hidden Markov Models (2022-01-12)
  PyHHMM is an object-oriented Python implementation of Heterogeneous-Hidden Markov Models (HHMMs).
  PyHHMM emphasizes features not supported in similar available frameworks: a heterogeneous observation model, missing data inference, different model order selection criteria, and semi-supervised training.
  PyHHMM relies on the numpy, scipy, scikit-learn, and seaborn Python packages, and is distributed under the Apache-2.0 License.
- OPFython: A Python-Inspired Optimum-Path Forest Classifier (2020-01-28)
  This paper proposes a Python-based Optimum-Path Forest framework, denoted as OPFython.
  As OPFython is a Python-based library, it provides a more friendly environment and a faster prototyping workspace than the C language.