ValueError: Shape must be rank 2 but is rank 3

Full error message:
ValueError: Shape must be rank 2 but is rank 3 for '{{node in_top_k/InTopKV2}} = InTopKV2[T=DT_INT64](sequential_1/dense_85/Softmax, ArgMax, in_top_k/InTopKV2/k)' with input shapes: [1,6,1], [1,6], [].

Traceback:

in user code:

    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\engine\training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\engine\training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\engine\training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\engine\training.py", line 864, in train_step
        return self.compute_metrics(x, y, y_pred, sample_weight)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\engine\training.py", line 957, in compute_metrics
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\engine\compile_utils.py", line 459, in update_state
        metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\utils\metrics_utils.py", line 70, in decorated
        update_op = update_state_fn(*args, **kwargs)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\metrics.py", line 178, in update_state_fn
        return ag_update_state(*args, **kwargs)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\metrics.py", line 729, in update_state  **
        matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "C:\Users\Wellson\AppData\Roaming\Python\Python310\site-packages\keras\metrics.py", line 4112, in top_k_categorical_accuracy
        tf.compat.v1.math.in_top_k(

    ValueError: Shape must be rank 2 but is rank 3 for '{{node in_top_k/InTopKV2}} = InTopKV2[T=DT_INT64](sequential_1/dense_85/Softmax, ArgMax, in_top_k/InTopKV2/k)' with input shapes: [1,6,1], [1,6], [].

I ran into this error when using the top_k_categorical_accuracy metric, but I am not sure why a shape problem appears only with this particular metric.
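
For reference, my rough understanding so far: according to the traceback, top_k_categorical_accuracy calls tf.compat.v1.math.in_top_k, which is documented to take a (batch_size, classes) predictions tensor and a (batch_size,) vector of class ids. Metrics such as categorical_accuracy only compare argmaxes element-wise, so they do not have this rank restriction. Below is a minimal NumPy sketch of that check, not the real Keras/TensorFlow code. The helper names are hypothetical and ties are not handled the way in_top_k handles them. It reuses the shapes from the error message and assumes that the middle axis of the (1, 6, 1) softmax output is a per-timestep axis.

```python
import numpy as np


def in_top_k_sketch(predictions, targets, k):
    """Hypothetical NumPy stand-in for the rank check in tf.compat.v1.math.in_top_k.

    predictions: (batch, num_classes) scores; targets: (batch,) class ids.
    Returns a bool array of shape (batch,). Ties are NOT treated like in_top_k does.
    """
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    if predictions.ndim != 2:
        raise ValueError(f"Shape must be rank 2 but is rank {predictions.ndim}")
    if targets.ndim != 1:
        raise ValueError(f"Shape must be rank 1 but is rank {targets.ndim}")
    # Column indices of the k largest scores in each row.
    top_k = np.argsort(predictions, axis=1)[:, -k:]
    return np.any(top_k == targets[:, None], axis=1)


def top_k_categorical_accuracy_sketch(y_true, y_pred, k=5):
    # Same pattern as the traceback: argmax over one-hot labels, then in_top_k.
    return in_top_k_sketch(y_pred, np.argmax(y_true, axis=-1), k)


# Shapes from the error message: y_pred is (1, 6, 1) and argmax(y_true) is (1, 6).
y_pred = np.ones((1, 6, 1))  # a softmax over a size-1 axis is always 1.0
y_true = np.ones((1, 6, 1))
try:
    top_k_categorical_accuracy_sketch(y_true, y_pred)
except ValueError as err:
    print(err)  # Shape must be rank 2 but is rank 3

# Assumption: treating each of the 6 timesteps as its own sample makes the inputs rank 2.
flat_pred = y_pred.reshape(-1, y_pred.shape[-1])  # (6, 1)
flat_true = y_true.reshape(-1, y_true.shape[-1])  # (6, 1)
# With a single class column, every target is trivially inside the top k.
print(top_k_categorical_accuracy_sketch(flat_true, flat_pred).shape)  # (6,)
```

If this reading is right, the metric is not broken as such. It simply expects rank-2 predictions, while the model's last layer outputs a rank-3 tensor.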

I then ran more tests to check whether this is really a metric issue or something else. Below are my findings so far:

*Note: Error = 1 means the error occurred; Error = 0 means it did not.

Please let me know whether I am supposed to use this metric differently from other metrics, or whether this is really an implementation issue (i.e., a bug).

Thanks!