I have a pruned TF model that I need to retrain on streaming data. I want to retrain only the non-zero weights of the 80% pruned model. I’d like to avoid creating a mask and performing additional calculations, because I want to minimize retraining time. Here’s the training function I’m currently using:
```python
import tensorflow as tf

# model, loss_fn and optimizer are defined elsewhere.
@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value
```
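For context, here is a minimal sketch of why a plain dense step is a problem for a pruned model. It uses a hypothetical toy kernel with 80% zeros and random data, and plain SGD via `assign_sub` stands in for the real `optimizer`. A single dense update generally gives the pruned zeros non-zero values, so the sparsity is lost.

```python
import tensorflow as tf

# Hypothetical toy stand-in for one pruned layer (not the real model):
# a 10x4 kernel in which only every 5th entry is non-zero (80% sparsity).
tf.random.set_seed(0)
keep = tf.reshape(tf.equal(tf.range(40) % 5, 0), [10, 4])
kernel = tf.Variable(tf.where(keep, tf.random.normal([10, 4]), tf.zeros([10, 4])))
learning_rate = 0.01  # plain SGD stands in for the real optimizer (assumption)

inputs = tf.random.normal([32, 10])  # hypothetical batch
targets = tf.random.normal([32, 4])

nnz_before = int(tf.math.count_nonzero(kernel))

# One dense step, analogous to train_step above.
with tf.GradientTape() as tape:
    loss_value = tf.reduce_mean(tf.square(targets - tf.matmul(inputs, kernel)))
grad = tape.gradient(loss_value, kernel)
kernel.assign_sub(learning_rate * grad)  # updates every entry, pruned or not

nnz_after = int(tf.math.count_nonzero(kernel))
print(nnz_before, nnz_after)
```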
Is there any way to compute the gradients only for the non-zero weights, or to use only the non-zero weights in `model.trainable_weights`? If not, is there any way to use `tf.IndexedSlices` to update only the non-zero weights efficiently?
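One possible approach, sketched under assumptions. As far as I know, `tape.gradient` still returns a dense gradient for a dense variable, so the full gradient is computed either way; what can be restricted is the update. `tf.IndexedSlices` represents whole slices along the first dimension (rows), which does not map directly onto element-level pruning. This sketch therefore precomputes the coordinates of the non-zero weights once with `tf.where`, then uses `tf.gather_nd` and `tf.Variable.scatter_nd_sub` inside the step. The toy kernel, the data, and the plain-SGD `learning_rate` are hypothetical stand-ins for `model`, `loss_fn`, and `optimizer`.

```python
import tensorflow as tf

# Hypothetical toy setup (not the real model): a 10x4 kernel with 80% zeros.
tf.random.set_seed(0)
keep = tf.reshape(tf.equal(tf.range(40) % 5, 0), [10, 4])
kernel = tf.Variable(tf.where(keep, tf.random.normal([10, 4]), tf.zeros([10, 4])))
learning_rate = 0.01  # plain SGD stands in for the real optimizer (assumption)

# Coordinates of the surviving weights, computed ONCE before training.
nz_idx = tf.where(tf.not_equal(kernel, 0.0))  # int64, shape [num_nonzero, 2]

@tf.function
def sparse_train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = tf.matmul(inputs, kernel)
        loss_value = tf.reduce_mean(tf.square(targets - predictions))
    # The gradient w.r.t. the dense kernel is still a dense tensor here.
    grad = tape.gradient(loss_value, kernel)
    # Pick out only the gradient entries at the non-zero coordinates ...
    grad_nz = tf.gather_nd(grad, nz_idx)
    # ... and update just those entries in place; pruned zeros are never touched.
    kernel.scatter_nd_sub(nz_idx, learning_rate * grad_nz)
    return loss_value

inputs = tf.random.normal([32, 10])  # hypothetical batch
targets = tf.random.normal([32, 4])
before = kernel.numpy().copy()
for _ in range(5):
    sparse_train_step(inputs, targets)
after = kernel.numpy()
```

For a real Keras model, the same pattern would be applied to each variable in `model.trainable_weights`, with one precomputed index tensor per variable. This hand-written update bypasses `optimizer.apply_gradients`, so optimizer state such as Adam's moment estimates is not used, and whether the gather/scatter is actually faster than multiplying the gradient by a precomputed mask would need to be measured.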
Any support on this would be highly appreciated.