Keras Newsletter (October 6, 2023)

Keras 3 new features overview

  • New distribution API
    • Enable data and model parallel training across devices
    • Initially via JAX, with TensorFlow and PyTorch support in later versions
  • SparseTensor support with TensorFlow
  • Train a model in one framework, reload in another framework
  • Use the runtime most appropriate for your hardware / environment, without any code change (see the snippet after this list)
  • Write framework-agnostic layers, models, losses, metrics, optimizers, and reuse them in native TF/JAX/PyTorch workflows
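  • For example, the backend is selected with the KERAS_BACKEND environment variable before import (a minimal sketch using the keras-core preview package):

      import os
      os.environ["KERAS_BACKEND"] = "jax"   # or "tensorflow" / "torch"

      import keras_core as keras   # plain `import keras` once Keras 3 ships

      # The same model definition now runs unchanged on every backend.
      model = keras.Sequential([keras.layers.Dense(1)])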

Keras 3.0 pre-release timeline

  • How to install Keras 3?
    • Now: pip install keras-core
    • In a few weeks: pip install keras-nightly
    • In November (tentative): pip install keras
  • Contributing:

tf.keras compatibility

  • TensorFlow 2.16 will use Keras 3 by default.
  • Upgrade from Keras 2 to Keras 3 for TensorFlow users
    • No changes needed in user code
    • If anything breaks, use the legacy tf.keras
  • How to install Keras 2 (the legacy tf.keras)?
    • TensorFlow 2.15 and earlier:
      • pip install tensorflow
    • From the TensorFlow 2.16 release (Q1 2024):
      • pip uninstall keras
      • pip install tf-keras
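  • For example, once tf-keras is installed, Keras 2 code keeps working (a sketch, assuming the tf-keras package above; TF_USE_LEGACY_KERAS is the switch TensorFlow 2.16 provides):

      # Option 1: import the legacy package explicitly.
      import tf_keras as keras

      # Option 2: set TF_USE_LEGACY_KERAS=1 in the environment so that
      # tf.keras itself resolves to Keras 2.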

Keras multi-backend distribute API

  • Lightweight data parallel / model parallel distribution API built on top of backend sharding primitives (initially jax.sharding)

  • All the heavy lifting is already done in XLA and GSPMD!

  • Main classes:

    • DataParallel / ModelParallel: configure how model weights and input data are distributed across devices
    • LayoutMap: maps your model's variable paths to their sharding layouts
  • Other API classes map directly onto backend primitives:

    • DeviceMesh β†’ jax.sharding.Mesh
    • TensorLayout β†’ jax.sharding.NamedSharding
  • Example:
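    • A sketch using the classes above (the API was JAX-only and still evolving at the time, so exact signatures may differ):

        import keras
        from keras import distribution

        # Data parallel: replicate weights, shard the input batch.
        devices = distribution.list_devices()
        distribution.set_distribution(distribution.DataParallel(devices=devices))

        # Model parallel: shard weights across a 2 x 4 device mesh (8 devices assumed).
        mesh = distribution.DeviceMesh(
            shape=(2, 4), axis_names=("batch", "model"), devices=devices)
        layout_map = distribution.LayoutMap(mesh)
        layout_map["dense.*kernel"] = (None, "model")   # shard Dense kernels on the "model" axis
        distribution.set_distribution(distribution.ModelParallel(
            layout_map=layout_map, batch_dim_name="batch"))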

  • Users can also train their Keras models using:

    • A custom training loop
    • Backend-specific distribution APIs directly
  • The roadmap:

    • jax.sharding implementation - DONE
    • Multi-worker/process training
    • Model saving and checkpointing for distributed models
    • Utility for distributing datasets (multi-backend equivalent of tf.distribute.Strategy.distribute_datasets_from_function)
    • Sharding constraints on intermediate values (equivalent of jax.lax.with_sharding_constraint)
    • MultiSlice capability (jax.experimental.mesh_utils.create_hybrid_device_mesh)
    • PyTorch/XLA sharding implementation
    • TensorFlow DTensor implementation

KerasCV updates

  • 0.7.0 release coming soon with some exciting new features
    • Semantic segmentation support (see the sketch after this list)
      • Full coverage of augmentation layers
      • DeepLabV3Plus
      • SegFormer
    • Segment Anything Model (SAM)
      • A guide will follow the 0.7 release
    • Various small bug fixes and improvements
      • Switched to programmatic API export
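  • For instance, the new segmentation models are expected to follow the usual KerasCV backbone + task pattern (a sketch written ahead of the 0.7 release, so the preset and constructor names are assumptions):

      import keras_cv

      backbone = keras_cv.models.ResNet50V2Backbone.from_preset("resnet50_v2_imagenet")
      model = keras_cv.models.DeepLabV3Plus(backbone=backbone, num_classes=21)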

KerasNLP updates

  • Upcoming features for LLM workflows.

    • Preconfigured model parallelism for backbones.
    • LoRA API for efficient fine-tuning.
  • Let’s take a look.

    • Note that these are still under development.
    • APIs might change!
KerasNLP - LoRA high-level
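
  • A rough sketch of the planned high-level flow (still under development, so the enable_lora entry point and its rank argument are assumptions):

      import keras_nlp

      causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
      causal_lm.backbone.enable_lora(rank=4)   # freeze base weights, add low-rank adapters
      causal_lm.fit(train_ds)                  # train_ds: your fine-tuning dataset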

KerasNLP - LoRA low-level
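
  • At the layer level, the same idea might look like this (a sketch assuming a per-layer enable_lora hook on Dense/EinsumDense layers):

      import keras

      layer = keras.layers.Dense(1024)
      layer.build((None, 1024))
      layer.enable_lora(rank=4)   # kernel is frozen; two small rank-4 factors are trained instead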

KerasNLP - Model parallelism
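
  • Preconfigured model parallelism for backbones might look roughly like this (hypothetical sketch: get_layout_map is an assumed helper that returns a ready-made LayoutMap):

      import keras
      import keras_nlp

      devices = keras.distribution.list_devices()
      mesh = keras.distribution.DeviceMesh(
          shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices)

      layout_map = keras_nlp.models.GPT2Backbone.get_layout_map(mesh)  # assumed helper
      keras.distribution.set_distribution(keras.distribution.ModelParallel(
          layout_map=layout_map, batch_dim_name="batch"))

      causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")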

KerasTuner updates

  • KerasTuner supports multi-backend Keras
    • From KerasTuner v1.4
  • Only exporting public APIs
    • From KerasTuner v1.4
    • All private APIs are hidden under keras_tuner.src
      • Change import keras_tuner.*** to import keras_tuner.src.***
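  • A minimal sketch of KerasTuner v1.4+ driving a non-TensorFlow backend (the model and search space are illustrative):

      import os
      os.environ["KERAS_BACKEND"] = "torch"   # any Keras 3 backend works

      import keras
      import keras_tuner

      def build_model(hp):
          model = keras.Sequential([
              keras.layers.Dense(hp.Int("units", 32, 256, step=32), activation="relu"),
              keras.layers.Dense(1),
          ])
          model.compile(optimizer="adam", loss="mse")
          return model

      tuner = keras_tuner.RandomSearch(build_model, objective="val_loss", max_trials=3)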