Why is the first inference slow?

I have two simple network implementations:

import numpy as np
import tensorflow as tf
import time

class SimpleNet(tf.keras.Model):
    def __init__(self):
        super(SimpleNet, self).__init__()
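        # Feature block: a single 3x3 convolution with 64 filters
        self.mod1 = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, (3, 3), padding="same"),
        ])
        # Output layer: 3x3 convolution back to 3 channels
        self.final = tf.keras.layers.Conv2D(3, (3, 3), padding="same")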

    def call(self, x):
        x = self.mod1(x)
        x = self.final(x)
        return x

def simple_net():
    x = tf.keras.Input(shape=(256, 256, 3))

    # Same layers as SimpleNet, wired with the functional API
    y = tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3), padding="same"),
    ])(x)

    y = tf.keras.layers.Conv2D(3, (3, 3), padding="same")(y)

    return tf.keras.Model(x, y)

if __name__ == "__main__":

    x_test = np.random.random((32, 256, 256, 3))
    x_test = tf.convert_to_tensor(x_test)
    print(x_test.shape)

    model1 = SimpleNet()
    model1.build(tf.TensorShape((None, 256, 256, 3)))
    model1.summary()

    model2 = simple_net()
    model2.summary()


    for i in range(4):
        s = time.time()
        pred = model1(x_test)
        e = time.time()
        print(f"time: {e-s} s")


    for i in range(4):
        s = time.time()
        pred = model2(x_test)
        e = time.time()
        print(f"time: {e-s} s")

The result is:

2022-07-14 11:32:45.087770: I tensorflow/core/util/util.cc:175] Experimental oneDNN custom operations are on. If you experience issues, please turn them off by setting the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
(32, 256, 256, 3)
Model: "simple_net"
 Layer (type)                Output Shape              Param #
 sequential (Sequential)     multiple                  1792

 conv2d_1 (Conv2D)           multiple                  1731

Total params: 3,523
Trainable params: 3,523
Non-trainable params: 0
Model: "model"
 Layer (type)                Output Shape              Param #
 input_1 (InputLayer)        [(None, 256, 256, 3)]     0

 sequential_1 (Sequential)   multiple                  1792

 conv2d_3 (Conv2D)           (None, 256, 256, 3)       1731

Total params: 3,523
Trainable params: 3,523
Non-trainable params: 0
time: 0.029999732971191406 s
time: 0.006798267364501953 s
time: 0.0066530704498291016 s
time: 0.00666046142578125 s

time: 0.006479501724243164 s
time: 0.006349802017211914 s
time: 0.006381988525390625 s
time: 0.006410121917724609 s

The first inference of SimpleNet looks odd to me: it takes significantly longer than the later ones. I also notice that the first implementation is consistently slower than the second, so I would like to ask for help.
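
To put numbers on the two observations, here is a small sketch that only does arithmetic on the times printed above. The variable names are mine, and treating calls 2 to 4 of SimpleNet as its "steady state" is an assumption for illustration.

```python
from statistics import mean

# Timings copied from the printed output above (seconds).
simple_net_class = [0.029999732971191406, 0.006798267364501953,
                    0.0066530704498291016, 0.00666046142578125]
simple_net_functional = [0.006479501724243164, 0.006349802017211914,
                         0.006381988525390625, 0.006410121917724609]

# Assumption: calls 2..4 of the subclassed model represent its steady state.
steady_class = mean(simple_net_class[1:])
steady_functional = mean(simple_net_functional)

# How much slower the first call is than the later calls of the same model.
first_call_ratio = simple_net_class[0] / steady_class

# How much slower the subclassed model is than the functional one after warm-up.
steady_ratio = steady_class / steady_functional

print(f"first call vs. later calls: {first_call_ratio:.2f}x")
print(f"subclassed vs. functional (steady state): {steady_ratio:.3f}x")
```

With these four samples per model, the first call is roughly 4.5x slower than the later ones, while the steady-state gap between the two implementations is only about 5%. Four samples are too few to say whether that smaller gap is real or noise.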

  1. Did I write the code incorrectly? How can I make the first inference as fast as the others?
  2. Why do the two implementations take different amounts of time (the first is slower than the second)?
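
In case it helps anyone reproduce this, below is a minimal sketch of how the timing loops could exclude a warm-up call. It assumes the extra time in the first call is one-time setup cost, which the output above does not confirm. `time_calls` is a hypothetical helper of mine, not a TensorFlow API, and it works with any callable.

```python
import time


def time_calls(fn, *args, warmup=1, repeats=4):
    """Run fn(*args) `warmup` times untimed, then return `repeats` timings in seconds."""
    # Untimed warm-up calls absorb any one-time setup cost.
    for _ in range(warmup):
        fn(*args)

    timings = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn(*args)
        timings.append(time.perf_counter() - start)
    return timings


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without TensorFlow installed.
    def workload(n):
        return sum(range(n))

    for t in time_calls(workload, 100_000):
        print(f"time: {t} s")
```

In the script above this would be called as `time_calls(model1, x_test)` and `time_calls(model2, x_test)`, so both models are compared only after each has run once.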

This seems to be exactly the same issue as mine.