07-09, 13:15–13:45 (US/Pacific), Ballroom
This talk provides an overview of several libraries in the open-source JAX ecosystem (such as Equinox, Diffrax, Optimistix, ...). In short, we have been building an "autodifferentiable GPU-capable SciPy". These libraries offer the foundational core of tools that have made it possible for us to train neural networks (e.g. score-based diffusions for image generation), solve PDEs, and smoothly handle hybridisations of the two (e.g. fitting neural ODEs to scientific data). By the end of the talk, the goal is for you to walk away with a slew of new modelling tools, suitable for tackling problems both in ML and in science.
Modern deep learning frameworks provide a plethora of tools for numerical computing, offering advanced functionality like autodifferentiation, autoparallelism, and GPU support. Many of these are of great use for scientific modelling! For example, stiff differential equation solvers (for biochemical reactions, ...) benefit from autodifferentiation to compute Jacobians, whilst large-scale simulations (weather, astrophysics, ...) benefit from autoparallelism across GPU clusters. This has become well appreciated by the scientific modelling community, and there is now a substantial effort underway to build the necessary tools in these modern deep learning frameworks.
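As a flavour of the first point, here is a minimal illustrative sketch (the vector field below is a made-up toy, not taken from the talk) of obtaining a Jacobian via autodifferentiation, as a stiff/implicit solver would need:

    import jax
    import jax.numpy as jnp

    # A toy vector field dy/dt = f(y), standing in for e.g. biochemical kinetics.
    def f(y):
        return jnp.array([-50.0 * y[0] + 10.0 * y[1], y[0] - y[1]])

    y = jnp.array([1.0, 0.5])
    jacobian = jax.jacfwd(f)(y)  # the 2x2 Jacobian df/dy, via forward-mode autodiff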
JAX offers an excellent computational framework for such efforts: its jax.numpy API provides an easy onboarding experience for users of existing NumPy-based libraries, whilst its autodifferentiation and autoparallelism tools are best-in-class.
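For NumPy users the onboarding is essentially a change of import; a minimal sketch (the function below is just an illustrative example):

    import jax
    import jax.numpy as jnp

    def f(x):
        # Familiar NumPy-style code, written against jax.numpy.
        return jnp.sum(jnp.sin(x) ** 2)

    x = jnp.arange(3.0)
    f(x)            # evaluates just like NumPy
    jax.grad(f)(x)  # gradient via autodifferentiation
    jax.jit(f)(x)   # JIT-compiled, runs on CPU/GPU/TPU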
This talk offers an introduction to the JAX scientific ecosystem, which is a well-developed effort of this sort, offering many of the numerical computing primitives needed for scientific modelling: libraries like Diffrax offer solvers for ODEs, SDEs, and some PDEs, whilst libraries like Optimistix offer solvers for nonlinear problems such as root finding and nonlinear least squares. The primitives of the ecosystem are generally at the same level as the tools provided by SciPy, and so we often refer to this effort as an 'autodifferentiable GPU-capable SciPy'.
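To give a concrete feel for these APIs, here is a minimal sketch (the particular ODE and root-finding problem are toy examples chosen purely for illustration):

    import diffrax
    import optimistix as optx
    import jax.numpy as jnp

    # Diffrax: solve the ODE dy/dt = -y over t in [0, 1], starting from y(0) = 1.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0))
    sol.ys  # approximately exp(-1)

    # Optimistix: find a root of y - tanh(y + 1) = 0 with a Newton solver.
    def fn(y, args):
        return y - jnp.tanh(y + 1)

    root = optx.root_find(fn, optx.Newton(rtol=1e-8, atol=1e-8), y0=jnp.array(0.0))
    root.value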
There is a further benefit to this approach: many deep learning ideas are themselves finding direct use in scientific problems (modern scientific models often feature neural networks somewhere in them!), and so we will also discuss libraries like Equinox for expressing neural networks, and Optax, which provides first-order gradient methods for minimisation problems.
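As a taste of how these compose, a minimal sketch of one training step for a small Equinox MLP optimised with Optax (the hyperparameters and data shapes are arbitrary placeholders):

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import optax

    # A small MLP; Equinox models are just PyTrees.
    model = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=jax.random.PRNGKey(0))
    optim = optax.adam(1e-3)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def step(model, opt_state, x, y):
        def loss_fn(m):
            pred = jax.vmap(m)(x)              # batch the model over the data
            return jnp.mean((pred - y) ** 2)   # mean squared error
        loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss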
All libraries are available as permissively-licensed open-source projects on GitHub; we refer to https://github.com/patrick-kidger/equinox (2.3k stars) as a starting point, which links onward to the rest of the ecosystem.
This talk is intended for those already familiar with NumPy and SciPy; no familiarity with JAX will be assumed. By the end, the audience will have gained a basic familiarity with JAX, its abstractions, the suite of NumPy- and SciPy-like tools available within it, and several examples of these being applied to existing domain-specific problems. We expect they will then feel sufficiently empowered to go forward and deploy these tools on their own problems!
Patrick is a tech lead on ML for protein optimization at Cradle.bio, and founded much of the open-source scientific JAX ecosystem. He has previously worked as an ML researcher at Google X, held a visiting appointment at Imperial College London, and received a PhD from Oxford on neural differential equations.