Does model.evaluate reshape output shapes under the hood?

I’m on TensorFlow 2.10.0.

My questions below assume that model.evaluate(X, y) is equivalent to calling model.predict(X) and then manually computing the metric against y.

I’m confused about why model.evaluate reports an MAE of 5.5, while my manual calculation using preds (the output of model.predict) gives 3.5.

Only by manually reshaping preds from 2-D, shape (2, 1), back to 3-D, shape (2, 1, 1), to get preds_3d can I recover the 5.5 MAE.

What is model.evaluate doing under the hood? I tried stepping through it in a debugger, but there are too many nested functions and too much infrastructure code for me to find where the actual computation happens. (If someone can show how to trace it, I’m happy to learn.)

Assuming model.evaluate does call model.predict first, how does the 2-D output of model.predict get compared with the 3-D y? (I’m assuming TensorFlow computes this the same way as my manual NumPy calculation, since TensorFlow depends on NumPy.)
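
To isolate the NumPy side of this, here is a minimal sketch that needs no TensorFlow. The values are made up (they are not the arrays from my script); I only chose them so that the two results come out as the same 3.5 and 5.5. It suggests that my 3.5 comes from NumPy broadcasting (2, 1, 1) against (2, 1) into a (2, 2, 1) array, so each target is compared with every prediction. It says nothing about what Keras does internally, which is still what I’m asking about.

```python
import numpy as np

# Made-up values (NOT the arrays from my script), chosen only so that the
# two results match the 3.5 and 5.5 reported above.
y = np.array([[[6.0]], [[2.0]]])   # shape (2, 1, 1), like my y
preds = np.array([[1.0], [8.0]])   # shape (2, 1), like the model.predict output

# NumPy right-aligns the shapes: (2, 1, 1) vs (2, 1) is treated as
# (2, 1, 1) vs (1, 2, 1), so every target is paired with every prediction.
diff_broadcast = y - preds
print(diff_broadcast.shape)        # (2, 2, 1)
mae_broadcast = np.abs(diff_broadcast).mean()

# Reshaping preds to y's shape pairs each target with its own prediction only.
diff_elementwise = y - preds.reshape(y.shape)
print(diff_elementwise.shape)      # (2, 1, 1)
mae_elementwise = np.abs(diff_elementwise).mean()

print(mae_broadcast, mae_elementwise)  # 3.5 5.5
```

So the manual 3.5 looks like an average over four pairs rather than two, while the reshaped version averages over the two matching pairs.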

from tensorflow.keras import layers, models, optimizers

import numpy as np
np.random.seed(42)

X  = np.random.randint(0,10, size=(2,112,3))
y  = np.random.randint(0,10, size=(2,1,1))
print(f'{X.shape = }\n{y.shape = }')

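def init_baseline():
    # Baseline with no trainable weights: it returns the last time step of feature 1.
    model = models.Sequential()
    # x[:, -1, 1, None] maps (batch, 112, 3) -> (batch, 1)
    model.add(layers.Lambda(lambda x: x[:, -1, 1, None]))

    adam = optimizers.Adam(learning_rate=0.02)
    model.compile(loss='mse', optimizer=adam, metrics=["mae"])

    return model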

model = init_baseline()
model.evaluate(X,y)
preds = model.predict(X)
preds_3d = preds.reshape(2,1,1)

print(f'{preds.shape = }\n{preds_3d.shape = }')
print('preds mae',np.abs(y - preds).mean())
print('preds_3d mae',np.abs(y - preds_3d).mean())

Outputs

X.shape = (2, 112, 3)
y.shape = (2, 1, 1)
1/1 [==============================] - 0s 285ms/step - loss: 30.0000 - mae: 5.5000
[30.0, 5.5]
1/1 [==============================] - 0s 71ms/step
preds.shape = (2, 1)
preds_3d.shape = (2, 1, 1)
preds mae 3.5
preds_3d mae 5.5