Why doesn't tf.while_loop work after setting jit_compile=True?

I'm trying to use jit_compile to accelerate training of my model, but after enabling it, the loop in my code no longer works. Here is the training code:

import tensorflow as tf

@tf.function(autograph=True, jit_compile=True)
def train_step(self, label, fea_ids, fea_vals, model):
    with tf.GradientTape() as tape:
        pred = model([fea_ids, fea_vals])
        loss = model.loss(label, pred)
        loss = loss - 0.5 * pred
    # Compute and apply gradients outside the tape context so the
    # gradient computation itself is not recorded by the tape
    gradients = tape.gradient(loss, model.trainable_weights)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return tf.reduce_mean(loss)

And here is the loop:

single_mask = tf.where(feat_index > 0, True, False)
# before_multihot_single_mask = single_mask

def single_mask_while_loop(single_mask):
    def cond(i, single_mask):
        return i < self.len_multihot_fea

    def body(i, single_mask):
        # Keep only positions that fall outside the [start, end) range
        # of the i-th multi-hot feature
        single_mask = single_mask & (
            tf.where(feat_index < self.multi_hot_fea_tf[i, 0], True, False)
            | tf.where(feat_index >= self.multi_hot_fea_tf[i, 1], True, False)
        )
        return i + 1, single_mask

    i = tf.constant(0, dtype=tf.int64)
    i, single_mask = tf.while_loop(cond, body, [i, single_mask])
    return single_mask

single_mask = single_mask_while_loop(single_mask)

After checking, I found that single_mask is unchanged after the loop. How can I solve this? Thanks for the help.
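Because each loop iteration only ANDs in an independent per-range condition, the same mask can also be computed without tf.while_loop by broadcasting every index against all ranges at once. The sketch below is an alternative, not a diagnosis of why the loop misbehaves under XLA. It assumes (hypothetically) that `feat_index` is an integer tensor of feature indices and that `multi_hot_fea_tf` is a `[K, 2]` tensor of `[start, end)` ranges with `K == len_multihot_fea`; the function names and sample values are illustrative only. The loop version is run without XLA purely as a reference.

```python
import numpy as np
import tensorflow as tf


def mask_with_while_loop(feat_index, ranges):
    """Reference: mirrors the question's tf.while_loop, run eagerly (no XLA)."""
    single_mask = feat_index > 0
    num_ranges = tf.shape(ranges, out_type=tf.int64)[0]

    def cond(i, mask):
        return i < num_ranges

    def body(i, mask):
        # Keep positions outside the i-th [start, end) range.
        outside = (feat_index < ranges[i, 0]) | (feat_index >= ranges[i, 1])
        return i + 1, mask & outside

    i = tf.constant(0, dtype=tf.int64)
    _, single_mask = tf.while_loop(cond, body, [i, single_mask])
    return single_mask


@tf.function(jit_compile=True)
def mask_vectorized(feat_index, ranges):
    """Same mask with no loop: compare every index against all K ranges."""
    idx = tf.expand_dims(feat_index, -1)        # shape [..., 1]
    starts, ends = ranges[:, 0], ranges[:, 1]   # shape [K] each
    in_any_range = tf.reduce_any((idx >= starts) & (idx < ends), axis=-1)
    return (feat_index > 0) & tf.logical_not(in_any_range)


# Hypothetical sample inputs.
feat_index = tf.constant([[0, 1, 3, 5], [2, 4, 6, 7]], dtype=tf.int64)
ranges = tf.constant([[2, 4], [6, 7]], dtype=tf.int64)

loop_mask = mask_with_while_loop(feat_index, ranges).numpy()
xla_mask = mask_vectorized(feat_index, ranges).numpy()
print(np.array_equal(loop_mask, xla_mask))  # True
```

The broadcast creates an intermediate of shape `[..., K]`, so memory grows with the number of multi-hot ranges; for a small, fixed `K` this is usually cheap, and all shapes stay static, which XLA handles well.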

Hi @luoyang102605, I ran a sample that uses tf.while_loop with jit_compile enabled and did not encounter any errors. Here is the sample code I used:

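import tensorflow as tf

@tf.function(jit_compile=True)  # jit_compile enabled, as described above
def sumSquare2(n):
  # Sum of i * i for i in [0, n)
  i, result = tf.constant(0), tf.constant(0)
  c = lambda i, _: tf.less(i, n)
  b = lambda i, result: (i + 1, result + i * i)
  return tf.while_loop(c, b, [i, result])[1]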
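One way to confirm that a loop behaves the same with and without XLA is to wrap the same Python function twice, once with `jit_compile=True` and once with `jit_compile=False`, and compare the outputs. A minimal sketch; the name `sum_squares` and the input `10` are illustrative.

```python
import tensorflow as tf


def sum_squares(n):
    # Sum of i * i for i in [0, n), computed with tf.while_loop.
    i, result = tf.constant(0), tf.constant(0)
    cond = lambda i, _: tf.less(i, n)
    body = lambda i, result: (i + 1, result + i * i)
    return tf.while_loop(cond, body, [i, result])[1]


# The same Python function traced twice: once compiled with XLA, once not.
sum_squares_xla = tf.function(sum_squares, jit_compile=True)
sum_squares_graph = tf.function(sum_squares, jit_compile=False)

n = tf.constant(10)
print(int(sum_squares_xla(n)), int(sum_squares_graph(n)))  # 285 285
```

Applying the same side-by-side comparison to the mask loop from the question would show whether the difference comes from XLA compilation or from the inputs themselves.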

If possible, could you please provide the complete code to reproduce the issue? Thank you.

Hi @Kiran_Sai_Ramineni, here is the code for the model and the interface:

Thanks for the help!