Spurious dimensionality during evaluation

Our team has implemented a subclassed tf.keras.Model that overrides the train_step() method. We also implemented the call() method of the underlying model.

Now, during evaluation (i.e., when running inference on the data passed to validation_data), we observe that a spurious dimension gets added to intermediate values. For example, where the expected shape is (128, 768) (128 being the batch size), we get (128, 128, 768) instead. However, when we also implement a custom test_step(), this behavior goes away.
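For reference, the setup described above (a subclassed tf.keras.Model with custom call(), train_step(), and test_step()) can be sketched roughly as follows. This is a hypothetical skeleton, not our actual code: ToyModel, the layer sizes, and the random data are placeholders. It assumes a recent TensorFlow 2.x where Model.compute_loss() is available. The tf.debugging.assert_rank() call in call() is one way to make the extra dimension fail loudly at the point where it first appears, instead of surfacing later as a shape mismatch.

```python
import numpy as np
import tensorflow as tf


class ToyModel(tf.keras.Model):
    """Hypothetical minimal model; names and sizes are illustrative only."""

    def __init__(self, hidden_dim=768, num_classes=2):
        super().__init__()
        self.encoder = tf.keras.layers.Dense(hidden_dim, activation="relu")
        self.head = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=False):
        hidden = self.encoder(inputs)
        # Raise an error if the intermediate value is not (batch_size, hidden_dim).
        tf.debugging.assert_rank(hidden, 2, message="unexpected extra dimension")
        return self.head(hidden)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(x=x, y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        # Same forward pass and loss as train_step, without the gradient update.
        x, y = data
        y_pred = self(x, training=False)
        loss = self.compute_loss(x=x, y=y, y_pred=y_pred)
        return {"loss": loss}


# Placeholder data: 256 examples with 32 features and binary integer labels.
x = np.random.rand(256, 32).astype("float32")
y = np.random.randint(0, 2, size=(256,))

model = ToyModel()
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
history = model.fit(
    x, y, batch_size=128, epochs=1, validation_data=(x, y), verbose=0
)
```

Removing the test_step() override from this skeleton and rerunning would show whether the default evaluation path alone triggers the rank check.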

Has anyone faced something similar?

Cc: @nilabhra

Do you have a minimized version you could share so we can reproduce this?