Update only non-zero weights during model retraining

Hi
I have a pruned TF model that I need to retrain on streaming data. I want to retrain only the non-zero weights of the 80%-pruned model, and I'd like to avoid creating a mask and performing the extra calculations that come with it, because I want to minimize retraining time. Here's the training function I'm currently using:

@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value
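
For reference, this is roughly the mask-based variant I'm trying to avoid. It's just a sketch: masks is a placeholder list of 0/1 tensors I'd precompute once from the pruned weights, and model, loss_fn, and optimizer are the same objects as above.

masks = [tf.cast(tf.not_equal(w, 0.0), w.dtype) for w in model.trainable_weights]

@tf.function
def masked_train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    # Zero the gradients of the pruned (zero) weights before applying them
    masked_grads = [g * m for g, m in zip(grads, masks)]
    optimizer.apply_gradients(zip(masked_grads, model.trainable_weights))
    return loss_value

This still computes dense gradients for every weight and then throws most of them away, which is exactly the extra work I'd like to skip.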

Is there any way to compute the gradients only for the non-zero weights, or to pass only the non-zero weights from model.trainable_weights? If not, is there a way to use tf.IndexedSlices to update the non-zero weights efficiently?
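
For the tf.IndexedSlices route, the sketch below is the kind of thing I had in mind (rough and untested; nonzero_rows is a name I made up). As far as I understand, tf.IndexedSlices addresses whole rows along the first axis, so this would only skip rows that are entirely pruned, not individual zero elements.

# Precompute, per weight tensor, the indices of first-axis rows that still
# contain at least one non-zero element.
nonzero_rows = []
for w in model.trainable_weights:
    nz = tf.not_equal(w, 0.0)
    if w.shape.rank > 1:
        nz = tf.reduce_any(nz, axis=list(range(1, w.shape.rank)))
    nonzero_rows.append(tf.where(nz)[:, 0])  # int64 row indices

@tf.function
def sparse_train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    # Wrap each gradient as IndexedSlices covering only the surviving rows
    sparse_grads = [
        tf.IndexedSlices(values=tf.gather(g, idx),
                         indices=idx,
                         dense_shape=tf.shape(g, out_type=tf.int64))
        for g, idx in zip(grads, nonzero_rows)
    ]
    optimizer.apply_gradients(zip(sparse_grads, model.trainable_weights))
    return loss_value

I realize the tape still computes the full dense gradients here, so I'm not sure this actually saves any time, which is part of my question.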

I'd highly appreciate any support on this.