Converting snake-dqn code to Python raises ValueError

I’m currently converting the snake-dqn project (a TensorFlow.js demo: tfjs-examples/snake-dqn at master · tensorflow/tfjs-examples · GitHub) from TensorFlow.js to TensorFlow in Python.
The code below is my Python port of the trainOnReplayBatch function from snake-dqn’s agent.js.

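    def train_on_replay_batch(self, batch_size, gamma, optimizer):
        """
        :param batch_size: number of transitions to sample from replay memory
        :type batch_size: int
        :param gamma: reward discount factor
        :type gamma: float
        :param optimizer: optimizer used to update the network weights
        :type optimizer: tf.keras.optimizers.Optimizer
        """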
        batch = self.replay_memory.sample(batch_size)
        # I modified get_state_tensor because I don't need the h and w parameters.
        # I don't think this is what causes the error.
        state_tensor = get_state_tensor([example[0] for example in batch])
        action_tensor = tf.constant([example[1] for example in batch], dtype=tf.int32)
        reward_tensor = tf.constant([example[2] for example in batch], dtype=tf.int32)
        next_state_tensor = get_state_tensor([example[4] for example in batch])
        with tf.GradientTape() as tape:
            qs = tf.Variable(tf.cast(tf.math.reduce_sum(
                self.online_network(state_tensor, training=True) * tf.one_hot(action_tensor, SQUARE_SIZE), axis=-1),
                dtype=tf.int32))
            result = tf.Variable(self.target_network(next_state_tensor))
            next_max_q_tensor = tf.math.reduce_max(result, axis=-1)
            done_mask = tf.cast(
                tf.Variable(1, dtype=tf.int32) - tf.Variable([example[3] for example in batch], dtype=tf.int32),
                dtype=tf.float32)
            target_qs = reward_tensor + tf.cast(next_max_q_tensor * done_mask * gamma, dtype=tf.int32)
            loss_fn = keras.losses.MeanSquaredError()
            loss = tf.reduce_mean(loss_fn(target_qs, qs))
        grads = tape.gradient(target=tf.cast(loss, dtype=tf.float32), sources=self.target_network.trainable_variables)

When I run this code, I get `ValueError: No gradients provided for any variable`.
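For reference, here is a minimal sketch (assuming TensorFlow 2.x) of how this error usually arises: `tape.gradient` silently returns `None` for every variable the loss is not connected to, and `optimizer.apply_gradients` then raises this ValueError. The small Dense model is a hypothetical stand-in for the networks in the question. Wrapping the model output in `tf.Variable` (as in the `qs` and `result` lines above) and casting it to an integer dtype both cut that connection.

```python
import tensorflow as tf

# Hypothetical stand-in for online_network: 4 input features, 3 outputs.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
x = tf.ones((2, 4))


def grads_for(compute_loss):
    """Record compute_loss() on a tape and return d(loss)/d(model weights)."""
    with tf.GradientTape() as tape:
        loss = compute_loss()
    return tape.gradient(loss, model.trainable_variables)


# Connected: the loss is built directly from the model output.
ok = grads_for(lambda: tf.reduce_mean(model(x)))

# tf.Variable(...) copies the value into a brand-new variable; the tape cannot
# trace back through that copy to the model weights, so every gradient is None.
via_variable = grads_for(lambda: tf.reduce_mean(tf.Variable(model(x))))

# Casting to int32 also cuts the path: integer tensors carry no gradient.
via_int = grads_for(
    lambda: tf.cast(tf.reduce_mean(tf.cast(model(x), tf.int32)), tf.float32))

# With only None gradients, apply_gradients raises the error from the question.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
try:
    optimizer.apply_gradients(zip(via_variable, model.trainable_variables))
    raised = False
except ValueError as err:
    raised = True
    print(err)
```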
I asked about this error on other websites and got the following suggestions:

  1. Make next_max_q_tensor and done_mask leaf variables, i.e. remove that code and insert tf.stop_gradient(next_max_q_tensor) instead.
  2. This code tries to update the parameters of target_network using next_max_q_tensor (which is a leaf variable), so change it to update online_network instead.
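Both suggestions can be checked on a small example. This sketch assumes TF 2.x and uses two hypothetical single-layer networks in place of online_network and target_network. It wraps the bootstrapped target in `tf.stop_gradient` and asks the tape for the gradients of both networks: only the online network's weights receive gradients, which is why the update has to be computed for, and applied to, online_network.

```python
import tensorflow as tf


def make_net():
    # Hypothetical stand-in network: 4 state features -> 3 action values.
    return tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])


online_network, target_network = make_net(), make_net()
states = tf.random.normal((5, 4))
next_states = tf.random.normal((5, 4))
actions = tf.constant([0, 1, 2, 0, 1])
rewards = tf.ones((5,))
gamma = 0.9

with tf.GradientTape() as tape:
    # Q(s, a) for the taken actions, computed by the online network.
    qs = tf.reduce_sum(online_network(states, training=True) * tf.one_hot(actions, 3), axis=-1)
    # Suggestion 1: treat the bootstrapped target as a constant label.
    next_max_q = tf.stop_gradient(tf.reduce_max(target_network(next_states), axis=-1))
    target_qs = rewards + gamma * next_max_q
    loss = tf.reduce_mean(tf.square(target_qs - qs))

# Ask for both sets of gradients in one call (sources may be a nested list).
online_grads, target_grads = tape.gradient(
    loss, [online_network.trainable_variables, target_network.trainable_variables])
```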

So I rewrote the code as follows:

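    def train_on_replay_batch(self, batch_size, gamma, optimizer):
        """
        :param batch_size: number of transitions to sample from replay memory
        :type batch_size: int
        :param gamma: reward discount factor
        :type gamma: float
        :param optimizer: optimizer used to update the network weights
        :type optimizer: tf.keras.optimizers.Optimizer
        """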
        batch = self.replay_memory.sample(batch_size)
        # I modified get_state_tensor because I don't need the h and w parameters.
        # I don't think this is what causes the error.
        state_tensor = get_state_tensor([example[0] for example in batch])
        action_tensor = tf.constant([example[1] for example in batch], dtype=tf.int32)
        reward_tensor = tf.constant([example[2] for example in batch], dtype=tf.float32)
        next_state_tensor = get_state_tensor([example[4] for example in batch])
        loss_fn = keras.losses.MeanSquaredError()
        with tf.GradientTape() as tape:
            qs = tf.Variable(tf.cast(tf.math.reduce_sum(
                self.online_network(state_tensor, training=True) * tf.one_hot(action_tensor, SQUARE_SIZE), axis=-1),
                dtype=tf.float32))
            result = tf.Variable(self.target_network(next_state_tensor))
            next_max_q_tensor = tf.math.reduce_max(result, axis=-1)
            done_mask = tf.Variable(1, dtype=tf.float32) - tf.Variable([example[3] for example in batch], dtype=tf.float32)
            target_qs = reward_tensor + next_max_q_tensor * done_mask * gamma
            loss = tf.reduce_mean(loss_fn(target_qs, qs))
        grads = tape.gradient(target=tf.cast(loss, dtype=tf.float32), sources=self.target_network.trainable_variables)
        optimizer.apply_gradients(zip(grads, self.online_network.trainable_variables))

But I still get the same error. What should I do to fix it?
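A possible corrected version, offered as a sketch rather than a drop-in replacement. It is written as a standalone function so that it runs on its own, and it assumes each replay example is a (state, action, reward, done, next_state) tuple whose states are already numeric (i.e. get_state_tensor has been applied). `num_actions` is a hypothetical parameter: the one-hot depth should equal the number of action values the network outputs. The changes relative to the question: no `tf.Variable` wrappers, float32 throughout, the target wrapped in `tf.stop_gradient`, `loss` passed straight to `tape.gradient`, and gradients taken and applied with respect to the online network's variables.

```python
import numpy as np
import tensorflow as tf


def train_on_replay_batch(online_network, target_network, batch,
                          num_actions, gamma, optimizer):
    """One DQN update on a replay batch (sketch).

    Assumes each example is (state, action, reward, done, next_state)
    with numeric states (already passed through get_state_tensor).
    """
    states = tf.constant(np.stack([e[0] for e in batch]).astype(np.float32))
    actions = tf.constant(np.array([e[1] for e in batch], dtype=np.int32))
    rewards = tf.constant(np.array([e[2] for e in batch], dtype=np.float32))
    dones = tf.constant(np.array([e[3] for e in batch], dtype=np.float32))
    next_states = tf.constant(np.stack([e[4] for e in batch]).astype(np.float32))

    # Bootstrapped target r + gamma * max_a' Q_target(s', a') * (1 - done).
    # It only serves as a label, so stop_gradient keeps it out of backprop.
    next_max_q = tf.reduce_max(target_network(next_states, training=False), axis=-1)
    target_qs = tf.stop_gradient(rewards + gamma * next_max_q * (1.0 - dones))

    with tf.GradientTape() as tape:
        # Q(s, a) for the actions actually taken; no tf.Variable, no int casts.
        q_all = online_network(states, training=True)
        qs = tf.reduce_sum(q_all * tf.one_hot(actions, num_actions), axis=-1)
        loss = tf.reduce_mean(tf.square(target_qs - qs))  # mean squared error

    # Differentiate and update the network that produced qs: the online network.
    grads = tape.gradient(loss, online_network.trainable_variables)
    optimizer.apply_gradients(zip(grads, online_network.trainable_variables))
    return loss


# Hypothetical usage: tiny networks with 4 state features and 3 actions.
def make_net():
    return tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])


online, target = make_net(), make_net()
target.set_weights(online.get_weights())
rng = np.random.default_rng(0)
batch = [(rng.normal(size=4), int(rng.integers(3)), 1.0, i % 2 == 0, rng.normal(size=4))
         for i in range(8)]
before_online = [w.copy() for w in online.get_weights()]
before_target = [w.copy() for w in target.get_weights()]
loss = train_on_replay_batch(online, target, batch, num_actions=3, gamma=0.99,
                             optimizer=tf.keras.optimizers.Adam(1e-3))
print(float(loss))
```

In the class-based version from the question, the same changes amount to deleting the `tf.Variable(...)` and int32 `tf.cast(...)` wrappers and using `self.online_network.trainable_variables` in both `tape.gradient` and `apply_gradients`.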