Ray Summit

JAX: Accelerated Machine Learning Research via Composable Function Transformations in Python

Wednesday, June 23, 5:05PM UTC

Matt Johnson, Research Scientist, Google Brain | Tech Lead, JAX


This talk is about JAX, a system for high-performance machine learning research and numerical computing. JAX offers the familiarity of Python+NumPy together with hardware acceleration. It combines these features with user-wielded function transformations, including automatic differentiation, automatic vectorized batching, end-to-end compilation (via XLA), parallelization over multiple accelerators, and more. Composing these transformations is the key to JAX's power and simplicity. Researchers use it for a wide range of advanced applications, from large-scale neural net training to probabilistic programming to scientific applications in physics and biology.
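As an illustration of what "composing transformations" looks like in practice, here is a minimal sketch using three of JAX's public transformations: `jax.grad` (automatic differentiation), `jax.vmap` (vectorized batching) and `jax.jit` (XLA compilation). The linear-model `loss` function and the toy data are hypothetical examples chosen for this sketch, not material from the talk. Multi-accelerator parallelism (`jax.pmap`) is left out because it needs more than one device to be meaningful.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: squared error of a linear model on a single example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: a new function returning d(loss)/dw (derivative w.r.t. the first argument).
grad_loss = jax.grad(loss)

# vmap: batch grad_loss over rows of x and entries of y, sharing the same w,
# which yields per-example gradients without writing a Python loop.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed grad-of-vmap function end-to-end with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient of length 2 per example
```

Each transformation takes a Python function and returns another Python function, which is why they can be nested freely: here `jit(vmap(grad(loss)))` computes per-example gradients in a single compiled call.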

Speakers

Matt Johnson

Research Scientist, Google Brain | Tech Lead, JAX

Matt Johnson is a research scientist at Google Brain interested in software systems powering machine learning research. He's the tech lead for JAX, a system for high-performance machine learning research and numerical computing. When moonlighting as a machine learning researcher, he works on making neural ODEs faster to solve, automatically exploiting conjugacy in probabilistic programs, and composing graphical models with neural networks. Matt was a postdoc with Ryan Adams at the Harvard Intelligent Probabilistic Systems Group and with Bob Datta in the Datta Lab at Harvard Medical School. His Ph.D. is from MIT in EECS, where he worked with Alan Willsky on Bayesian time series models and scalable inference. He was an undergrad at UC Berkeley (Go Bears!).