Support Gradient Checkpointing in TensorFlow 2!

Is there a high-level API in TF 2 (Keras) to support gradient checkpointing, similar to the one-line `.set_grad_checkpointing()` interface available on the PyTorch side?
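
For context, this is roughly the one-liner being referred to. (Strictly speaking, `set_grad_checkpointing()` is exposed by timm's model classes rather than PyTorch core; the model name and batch size below are only illustrative.)

```python
import timm
import torch

# Illustrative sketch (assumes a recent timm version): timm ViT models expose a
# single call that enables gradient checkpointing for all transformer blocks.
model = timm.create_model("vit_base_patch16_224", pretrained=False)
model.set_grad_checkpointing(True)  # one line: recompute activations on backprop

x = torch.randn(64, 3, 224, 224)    # batch size 64, as in the memory numbers below
loss = model(x).mean()
loss.backward()                     # block activations are recomputed here
```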

With gradient checkpointing in PyTorch, memory consumption drops significantly (e.g., from 11 GB to 4 GB on a P100 when training ViT-B/16 with a batch size of up to 64). In TF, however, the same setup runs out of memory.
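
The closest primitive I'm aware of in TF 2 is the low-level `tf.recompute_grad`, which has to be wired in manually per block and has historically had rough edges with Keras-managed variables. A minimal sketch of how it might be used (the block structure and sizes here are made up, not the ViT-B/16 setup above):

```python
import tensorflow as tf

# Rough sketch: wrap a block's forward pass with tf.recompute_grad so its
# activations are recomputed during backprop instead of being kept in memory.
class CheckpointedBlock(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(units, activation="relu")
        self.dense2 = tf.keras.layers.Dense(units)

    def call(self, x):
        def block(h):
            return self.dense2(self.dense1(h))
        # Recompute `block` on the backward pass to trade compute for memory.
        # Note: interaction with Keras-managed variables has been buggy in
        # some TF versions, so verify gradients are actually produced.
        return tf.recompute_grad(block)(x)

model = tf.keras.Sequential(
    [tf.keras.layers.InputLayer(input_shape=(768,))]
    + [CheckpointedBlock(768) for _ in range(12)]
)

x = tf.random.normal((64, 768))
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(model(x) ** 2)
grads = tape.gradient(loss, model.trainable_variables)
```

This is per-block boilerplate, which is exactly why a Keras-level one-liner (or a model-level flag) would be valuable.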

There are existing TF issues related to this problem, but there has been no active response from the TF team so far. TF appears to consume all of the available VRAM, making it extremely difficult (if not impossible) to train large models. Adding a gradient checkpointing feature would go a long way toward solving this.
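
One aside on the "consumes all available VRAM" point: by default TF reserves nearly all GPU memory up front. That behaviour can be disabled, but it only changes the allocation strategy; the peak activation memory of a large model is unchanged, so it is not a substitute for checkpointing:

```python
import tensorflow as tf

# Turn off TF's default "grab (almost) all GPU memory at startup" behaviour.
# This makes reported memory usage reflect what the model actually needs,
# but it does not reduce peak activation memory, so large models still OOM
# without gradient checkpointing.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
```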

Old issues: