Why separate pruning and QAT APIs in the Model Optimization Toolkit?

I've been happily using tfmot in various projects for a while now… but one aspect of the current design puzzles the heck out of me. Why are there disjoint APIs for QAT and pruning?

After all:

  • You can model pruning as “just another kind of Quantizer”: one that maps some values to 0.0 and leaves the rest unchanged (see the sketch after this list).
  • Though less vital than for pruning, supporting scheduling of the degree of quantization applied by Quantizers can be useful in QAT (especially when quantizing down to sub-8-bit bit-widths).
  • Pruning-as-a-kind-of-QAT also avoids the need for a special-cased two-step training process when QAT and pruning are to be combined.
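
To make the first bullet concrete, here is roughly the shape I have in mind: a custom `tfmot.quantization.keras.quantizers.Quantizer` whose only job is to map the smallest-magnitude weights to zero. (A simplified sketch, not the PoC code itself; `MagnitudePruningQuantizer` and the fixed `target_sparsity` are purely illustrative.)

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot


class MagnitudePruningQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Illustrative 'Quantizer' that maps the smallest-magnitude weights to 0.0
  and leaves the rest unchanged."""

  def __init__(self, target_sparsity=0.5):
    self.target_sparsity = target_sparsity

  def build(self, tensor_shape, name, layer):
    # No extra variables needed for this sketch.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    # Keep the (1 - target_sparsity) fraction of weights with the largest
    # magnitude; everything else is mapped to 0.0.
    n_keep = tf.cast(
        tf.math.ceil(tf.cast(tf.size(inputs), tf.float32)
                     * (1.0 - self.target_sparsity)), tf.int32)
    n_keep = tf.maximum(n_keep, 1)
    abs_w = tf.sort(tf.abs(tf.reshape(inputs, [-1])), direction='DESCENDING')
    threshold = tf.gather(abs_w, n_keep - 1)
    mask = tf.cast(tf.abs(inputs) >= threshold, inputs.dtype)
    return inputs * mask

  def get_config(self):
    return {'target_sparsity': self.target_sparsity}
```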

A quick PoC implementation (a Quantizer composition operator plus an incremental-pruning “Quantizer”), created to simplify porting models that use a legacy internal library, seems to work just fine.
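
Roughly, the composition operator amounts to nothing more than applying a list of Quantizers in sequence (again a sketch rather than the actual PoC code; `SequentialQuantizer` is not a tfmot API):

```python
import tensorflow_model_optimization as tfmot


class SequentialQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Illustrative composition of Quantizers: applies each one in turn, so a
  pruning 'Quantizer' can be chained with an ordinary fake-quant Quantizer."""

  def __init__(self, quantizers):
    self.quantizers = list(quantizers)

  def build(self, tensor_shape, name, layer):
    # Gather the variables each wrapped Quantizer asks for, keyed by position.
    return {str(i): q.build(tensor_shape, '{}_{}'.format(name, i), layer)
            for i, q in enumerate(self.quantizers)}

  def __call__(self, inputs, training, weights, **kwargs):
    outputs = inputs
    for i, q in enumerate(self.quantizers):
      outputs = q(outputs, training, weights[str(i)], **kwargs)
    return outputs

  def get_config(self):
    # Serialization of the wrapped Quantizers is omitted in this sketch.
    return {}
```

In a custom `QuantizeConfig` the weight quantizer could then be something like `SequentialQuantizer([MagnitudePruningQuantizer(0.5), tfmot.quantization.keras.quantizers.LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False)])`, i.e. pruning and fake-quant in a single training pass.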

On a similar note: the pruning API seems to go to some trouble to prune by overwriting the pruned weight variables rather than simply “injecting” a masking operation (with a straight-through estimator for the gradient). Surely, given the constant folding applied when models are optimized for inference (e.g. by the TFLite converter), the end result of masking would be the same for less coding effort?
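
For concreteness, by “injecting” a masking operation I mean something along these lines (a sketch only; how the mask itself is computed and scheduled is omitted):

```python
import tensorflow as tf


@tf.custom_gradient
def masked_weights(weights, mask):
  """Forward pass: zero out pruned weights. Backward pass: straight-through
  estimator, i.e. the gradient flows to all weights, pruned or not."""
  def grad(dy):
    # Pass the upstream gradient straight through to the weights;
    # no gradient is defined for the (non-trainable) mask.
    return dy, None
  return weights * mask, grad
```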

What am I missing? Anyone from the tfmot Team(s) care to shed any light?