Our team has implemented a tf.keras.Model subclass that overrides the train_step() method. We also implemented the call() method of the underlying model.
Now, during evaluation (i.e., when running inference on the data passed as validation_data), we observe that a spurious dimension gets added to intermediate values. For example, where we expect a shape of (128, 768), with 128 being the batch size, we instead get (128, 128, 768). When we also implement a separate test_step(), this behavior goes away.
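For reference, here is a minimal sketch (not our actual code) of the kind of explicit test_step() that makes the problem disappear. It assumes a hypothetical ProjectionModel in which a single Dense layer stands in for the real network and mean squared error stands in for the real loss. The test_step() repeats the forward pass from train_step() with training=False and no gradient update, and an added rank check would flag a (128, 128, 768) tensor as soon as it appears.

```python
import tensorflow as tf


class ProjectionModel(tf.keras.Model):
    """Hypothetical stand-in: one Dense layer mapping (batch, d) -> (batch, hidden)."""

    def __init__(self, hidden=768, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden)
        # Assumed loss; the real model's loss may differ.
        self.loss_fn = tf.keras.losses.MeanSquaredError()
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def call(self, inputs, training=False):
        return self.dense(inputs)

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch.
        return [self.loss_tracker]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Same forward pass as train_step, but in inference mode and without gradients.
        x, y = data
        y_pred = self(x, training=False)
        # Fails if an extra axis shows up, e.g. (batch, batch, hidden).
        tf.debugging.assert_rank(y_pred, 2, message="unexpected extra dimension")
        loss = self.loss_fn(y, y_pred)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
```

Because train_step() and test_step() unpack the data and call the model in the same way, training and validation go through matching code paths. A quick usage check with small random arrays could be: `model = ProjectionModel(hidden=16); model.compile(optimizer="adam"); model.fit(x, y, batch_size=8, validation_data=(x, y))`.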
Has anyone faced something similar?