Distributed training with JAX & Flax

Training models on accelerators with JAX and Flax differs slightly from training on a CPU.