Ray Summit
JAX: Accelerated Machine Learning Research via Composable Function Transformations in Python
Wednesday, June 23, 10:05 AM PDT
Matt Johnson, Research Scientist, Google Brain | Tech Lead, JAX
This talk is about JAX, a system for high-performance machine learning research and numerical computing. JAX offers the familiarity of Python+NumPy together with hardware acceleration. It combines these features with user-wielded function transformations, including automatic differentiation, automatic vectorized batching, end-to-end compilation (via XLA), parallelization across multiple accelerators, and more. Composing these transformations is the key to JAX's power and simplicity. Researchers use JAX for a wide range of advanced applications, from large-scale neural network training to probabilistic programming to scientific applications in physics and biology.
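To make "composing these transformations" concrete, here is a minimal sketch of how `jax.grad`, `jax.vmap`, and `jax.jit` stack on an ordinary Python function to produce compiled per-example gradients. The squared-error `loss` and the names `w`, `xs`, `ys` are hypothetical, chosen only for illustration and not taken from the talk. Parallelization across accelerators (e.g. `jax.pmap`) is omitted because it needs multiple devices to show anything.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: squared error of a linear model on a single example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: a new function returning d(loss)/dw (derivative w.r.t. the first argument).
grad_loss = jax.grad(loss)

# vmap: batch grad_loss over examples; w is shared (None), x and y map over axis 0.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function end to end with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_per_example_grads(w, xs, ys)
print(g)  # → one gradient row per example: [[2, 0], [0, 4], [6, 6]]
```

Each transformation takes a function and returns a function, so they nest freely: the same result comes from the one-liner `jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))`.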
Speakers

Matt Johnson
Research Scientist, Google Brain | Tech Lead, JAX