JAX is a high-performance Python library that brings just-in-time compilation, automatic differentiation and easy parallelism to NumPy-style array programming.
Developed by Google, with contributions from Nvidia and the wider community, JAX combines Autograd-style automatic differentiation with XLA-powered just-in-time (JIT) compilation. It exposes both as composable program transformations, such as `grad` (differentiation), `jit` (compilation), `vmap` (vectorization) and `pmap` (parallelization across devices).
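The sketch below shows differentiation and JIT compilation applied to an ordinary NumPy-style function. It is an illustration only: the `loss` function and the small arrays are hypothetical and do not come from this document. It uses only the public `jax.grad`, `jax.jit` and `jax.numpy` APIs.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model,
# written with NumPy-style operations from jax.numpy.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad builds a new function returning d(loss)/dw
# (it differentiates with respect to the first argument by default).
grad_loss = jax.grad(loss)

# jax.jit traces the gradient function and compiles it with XLA;
# later calls with the same shapes and dtypes reuse the compiled code.
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad_loss(w, x, y)
print(g)  # analytically, the gradient is (-7, -10)
```

Because `grad` and `jit` each take a function and return a function, their order can be changed and they can be nested freely. This is what "program transformations" means in practice.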
JAX underpins research libraries and frameworks such as Flax, Haiku, Optax, T5X and Scenic. It is widely used for large-scale machine-learning research, differentiable scientific computing, physics simulation and rapid prototyping of novel model architectures.
| Capability | Benefit |
|---|---|
| NumPy-compatible API (`jax.numpy`) | Minimal learning curve for Python users |
| Composable transforms (`grad`, `jit`, `vmap`, `pmap`) | Nest differentiation, compilation and vectorization to express complex algorithms concisely |
| Multi-accelerator back-ends | Run the same code on CPU, GPU or TPU |
| Pure-Python workflow | No new DSL to learn; leverage standard tooling |
| Apache-2.0 licence | Free for commercial and academic projects |
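The example below is a minimal sketch of how the transforms in the table compose. It computes per-example gradients by wrapping `grad` in `vmap` and then compiling the result with `jit`. The per-example squared-error `loss` and the data are hypothetical. `in_axes=(None, 0, 0)` tells `vmap` to share the weights `w` across the batch while mapping over the rows of `x` and `y`.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss: x is a single feature vector, y a scalar.
def loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2

# Compose transforms: differentiate w.r.t. w, vectorize over the batch
# axis of x and y (w is not mapped, hence None), then JIT-compile.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

grads = per_example_grads(w, x, y)  # one gradient row per example

# The same code runs on whichever backend JAX detects (CPU, GPU or TPU).
print(jax.devices())
print(grads.shape)
```

Writing the single-example function and letting `vmap` add the batch dimension keeps the model code simple. It also avoids a Python loop, which `jit` would otherwise have to unroll.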