Support Gradient Checkpointing in TensorFlow 2!

Is there a high-level API in TensorFlow 2 (Keras) for gradient checkpointing, similar to the one-line .set_grad_checkpointing() interface available for PyTorch models (e.g., in timm)?
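For context, TensorFlow 2 does expose a lower-level building block, `tf.recompute_grad`, which wraps a function so that its intermediate activations are recomputed during backprop instead of being kept from the forward pass. It is not a model-wide switch: you choose the function boundary (e.g. one block of layers) by hand. Below is a minimal sketch that assumes a hypothetical two-matmul block built from plain `tf.Variable`s to keep it short. It also checks that the checkpointed gradients match the ordinary ones.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical weights for one "segment" of a network (e.g. an MLP block).
w1 = tf.Variable(tf.random.normal((128, 512), stddev=0.05))
w2 = tf.Variable(tf.random.normal((512, 128), stddev=0.05))


def block(x):
    # The (batch, 512) hidden activation is what checkpointing avoids storing.
    h = tf.nn.relu(tf.matmul(x, w1))
    return tf.matmul(h, w2)


# Intermediates of `block` are dropped after the forward pass and recomputed
# by re-running `block` when gradients are requested.
checkpointed_block = tf.recompute_grad(block)

x = tf.random.normal((64, 128))

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(checkpointed_block(x)))
ckpt_grads = tape.gradient(loss, [w1, w2])

# Same loss without checkpointing, for comparison.
with tf.GradientTape() as tape:
    ref_loss = tf.reduce_sum(tf.square(block(x)))
ref_grads = tape.gradient(ref_loss, [w1, w2])

for g, r in zip(ckpt_grads, ref_grads):
    # Recomputation should reproduce the same gradients.
    tf.debugging.assert_near(g, r, rtol=1e-5, atol=1e-5)
```

In a real model, each wrapped block trades one extra forward pass through that block for not keeping its internal activations alive until backprop.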

With gradient checkpointing in PyTorch, memory consumption drops significantly (e.g., from 11 GB to 4 GB on a P100 when training ViT-B/16 with a batch size of up to 64). In TensorFlow, however, the same setup runs out of memory.

There are existing TensorFlow issues about this problem, but the TF team has not responded to them so far. TensorFlow appears to consume all available VRAM, which makes training large models extremely difficult, if not impossible. Adding gradient checkpointing could potentially solve this.
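On the VRAM point: by default TensorFlow maps nearly all of the memory of every visible GPU when it initializes. As a result, tools like nvidia-smi show the card as full regardless of what the model actually needs. The documented way to switch to on-demand allocation is `tf.config.experimental.set_memory_growth`, or setting the environment variable `TF_FORCE_GPU_ALLOW_GROWTH=true`. This only changes how memory is reserved. It does not by itself lower the peak memory training requires, so it is not a substitute for checkpointing. Here is a minimal sketch:

```python
import tensorflow as tf

# Must be called before any op initializes the GPUs; otherwise TensorFlow
# raises a RuntimeError.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    # Allocate GPU memory on demand instead of reserving (almost) all of it.
    tf.config.experimental.set_memory_growth(gpu, True)

print(f"Memory growth enabled on {len(gpus)} GPU(s)")
```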

Old issues.

A TPU doc mentions:

The TPU runtime attempts to optimize operators to fit the model in memory (called rematerialization, similar to gradient checkpointing),

Can someone shed some light on this? What does rematerialization mean here, and can it be used on GPU?
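Stripped of framework details, rematerialization means not keeping some forward-pass values. When the backward pass needs them, they are recomputed from an earlier value that was kept. Below is a framework-free sketch of the classic segment-checkpointing scheme. It assumes a hypothetical chain of scalar sin/tanh "layers" and is not the TPU runtime's actual algorithm, whose heuristics the quoted doc does not describe. Nothing in the idea itself is TPU-specific.

```python
import math

# Hypothetical 16-"layer" chain of scalar functions, each paired with its
# derivative. It stands in for a stack of network layers.
LAYERS = [(math.sin, math.cos), (math.tanh, lambda v: 1.0 - math.tanh(v) ** 2)] * 8


def grad_store_all(x):
    """Plain reverse mode: keep every layer input from the forward pass."""
    inputs = []
    for f, _ in LAYERS:
        inputs.append(x)
        x = f(x)
    g = 1.0  # d(output)/d(output)
    for (_, df), v in zip(reversed(LAYERS), reversed(inputs)):
        g *= df(v)  # chain rule, last layer first
    return x, g, len(inputs)


def grad_checkpointed(x, every):
    """Keep only every `every`-th layer input; recompute the rest in backprop."""
    checkpoints = {}
    for i, (f, _) in enumerate(LAYERS):
        if i % every == 0:
            checkpoints[i] = x  # the only forward values we retain
        x = f(x)
    g = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):  # last segment first
        segment = LAYERS[start:start + every]
        # Rematerialize this segment's layer inputs from its checkpoint.
        v = checkpoints[start]
        seg_inputs = []
        for f, _ in segment:
            seg_inputs.append(v)
            v = f(v)
        peak = max(peak, len(checkpoints) + len(seg_inputs))
        for (_, df), u in zip(reversed(segment), reversed(seg_inputs)):
            g *= df(u)
    return x, g, peak
```

With `every=4`, `grad_checkpointed` holds at most 8 layer inputs at once instead of 16 and returns the same gradient. The cost is that every layer's forward function runs twice.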

Rematerialization / Checkpointing

“Rematerialization” or “checkpointing” is a technique for trading off compute time for lower peak memory utilization when performing reverse-mode automatic differentiation. JAX offers several different default rematerialization “policies” that dictate which kinds of intermediate values are preserved from the forward-pass to the backwards-pass calculation, and which are discarded to be recomputed anew in the backwards-pass.
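The passage above describes JAX's `jax.checkpoint` (also exposed as `jax.remat`), whose `policy` argument takes one of the functions in `jax.checkpoint_policies`. Below is a minimal sketch applying two built-in policies to a toy two-layer function; `mlp_block`, the shapes and the random inputs are made up for illustration. The gradients are the same in every case. Only what is stored versus recomputed changes.

```python
import jax
import jax.numpy as jnp


def mlp_block(params, x):
    # Two matmuls with a nonlinearity; the dot products are the expensive parts.
    w1, w2 = params
    h = jnp.tanh(x @ w1)
    return jnp.sum(jnp.tanh(h @ w2))


# Save nothing from inside the block: everything is recomputed in the backward
# pass (lowest memory, most recompute).
remat_nothing = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.nothing_saveable
)

# Save only dot/matmul results that have no batch dimensions; cheap
# elementwise ops such as tanh are recomputed.
remat_dots = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(key1, (16, 32)), jax.random.normal(key2, (32, 8)))
x = jax.random.normal(key3, (4, 16))

# Gradients with respect to params (argument 0) under each strategy.
g_plain = jax.grad(mlp_block)(params, x)
g_nothing = jax.grad(remat_nothing)(params, x)
g_dots = jax.grad(remat_dots)(params, x)
```

`nothing_saveable` and `everything_saveable` sit at the two extremes. The dot-based policies are a middle ground: matmul outputs are expensive to recompute, while elementwise ops are cheap.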