Keras Preprocessing - adapt multiple layers in one go

I’m a huge fan of tf.data, mainly for how it speeds up the preprocessing of large datasets that don’t fit in memory. I have been using the Keras preprocessing layers for a while now, and I’m still struggling with one main issue: adapting multiple layers at once.

In the example used to introduce preprocessing layers in Keras, the author shows this snippet:

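import tensorflow as tf

# train_ds: a tf.data.Dataset of (text, label) batches, as in the original example
text_vectorizer = tf.keras.layers.TextVectorization(
     output_mode='multi_hot', max_tokens=2500)
features = train_ds.map(lambda x, y: x)
text_vectorizer.adapt(features)  # first full pass over the dataset

normalizer = tf.keras.layers.Normalization(axis=None)
normalizer.adapt(features.map(lambda x: tf.strings.length(x)))  # second full pass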

def preprocess(x):
  multi_hot_terms = text_vectorizer(x)
  normalized_length = normalizer(tf.strings.length(x))
  # Combine the multi-hot encoding with review length.
  return tf.keras.layers.concatenate((multi_hot_terms, normalized_length))

def forward_pass(x):
  return tf.keras.layers.Dense(1)(x)  # Linear model.

inputs = tf.keras.Input(shape=(1,), dtype='string')
outputs = forward_pass(preprocess(inputs))
model = tf.keras.Model(inputs, outputs)
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
model.fit(train_ds, epochs=5)

Because we call adapt twice, this code iterates over the dataset twice. Is there a way to change it so that both layers are adapted in a single pass over the data? Something like a Model class, but for preprocessing, for example:

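class Preprocessor(tf.made_up_class.Preprocess):  # made-up base class: the API I'm wishing for
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.text_vectorizer = tf.keras.layers.TextVectorization(
     output_mode='multi_hot', max_tokens=2500)
    self.normalizer = tf.keras.layers.Normalization(axis=None)

  def adapt(self, x):
    # ideally, both layers would update their state from the same batch
    vectorized_text = self.text_vectorizer(x)
    normalized_length = self.normalizer(tf.strings.length(x))
    return vectorized_text, normalized_length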


preprocessor = Preprocessor()
preprocessor.adapt(features)
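One workaround, sketched here under stated assumptions rather than as a known answer, is to skip adapt entirely: compute the statistics yourself in a single pass and hand them to the layers' constructors (TextVectorization accepts a `vocabulary` argument, Normalization accepts `mean` and `variance`). The sketch assumes 1-D batches of strings, e.g. from `features.as_numpy_iterator()`; the helper names `one_pass_stats` and `build_layers` are hypothetical. It also swaps the default `lower_and_strip_punctuation` standardization for plain lower-casing, so that the Python-side tokenization matches the layer's.

```python
from collections import Counter


def one_pass_stats(batches, max_tokens=2500):
    """Collect a token vocabulary and string-length statistics in ONE pass.

    Hypothetical helper. `batches` is any iterable of 1-D batches of raw
    strings (str or bytes), e.g. features.as_numpy_iterator().
    Returns (vocab, mean, variance) where:
      vocab: the max_tokens - 1 most frequent tokens (one slot is left
             for the OOV token that TextVectorization adds itself)
      mean, variance: population statistics of the UTF-8 byte lengths,
             which is what tf.strings.length measures by default
    """
    counts = Counter()
    n, total, total_sq = 0, 0.0, 0.0
    for batch in batches:
        for text in batch:
            raw = text.encode("utf-8") if isinstance(text, str) else bytes(text)
            # ASCII lower-casing + whitespace split, mirroring
            # standardize=tf.strings.lower and split='whitespace' below
            counts.update(raw.lower().split())
            length = len(raw)
            n += 1
            total += length
            total_sq += length * length
    if n == 0:
        raise ValueError("no examples seen")
    mean = total / n
    variance = max(total_sq / n - mean * mean, 0.0)  # clamp rounding noise
    vocab = [tok.decode("utf-8") for tok, _ in counts.most_common(max_tokens - 1)]
    return vocab, mean, variance


def build_layers(vocab, mean, variance):
    """Create already-'adapted' layers from precomputed statistics."""
    import tensorflow as tf  # only needed here, when the layers are built

    text_vectorizer = tf.keras.layers.TextVectorization(
        standardize=tf.strings.lower,  # must match the tokenization above
        output_mode="multi_hot",
        vocabulary=vocab,
    )
    normalizer = tf.keras.layers.Normalization(
        axis=None, mean=mean, variance=variance)
    return text_vectorizer, normalizer
```

With `features` from the first snippet this would be `vocab, mean, variance = one_pass_stats(features.as_numpy_iterator())` followed by `text_vectorizer, normalizer = build_layers(vocab, mean, variance)`, leaving `preprocess` unchanged. The loop runs in plain Python, so whether it actually beats two graph-mode adapt passes on a big dataset would need measuring.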

Maybe this specific example is tricky, but with structured data one often ends up adapting many StringLookup layers, one per column, which can take hours when the data is large.
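For the structured-data case, the same idea could extend to many columns: one Counter per column, all filled during a single pass, then one StringLookup per column built with its `vocabulary` argument instead of calling adapt. This is a minimal sketch assuming each dataset element is a dict of 1-D string batches; `collect_vocabularies`, `build_lookups` and `max_terms` are hypothetical names, and the ordering of equally frequent values may not match what adapt would produce.

```python
from collections import Counter


def collect_vocabularies(batches, columns, max_terms=None):
    """Count the values of several columns in ONE pass over the data.

    Hypothetical helper. `batches` is an iterable of dicts mapping a column
    name to a 1-D batch of values, e.g. ds.as_numpy_iterator() for a
    tf.data dataset whose elements are dicts of string features.
    Returns {column: vocabulary}, most frequent value first, truncated to
    `max_terms` entries per column (None keeps everything).
    """
    counters = {col: Counter() for col in columns}
    for batch in batches:
        for col in columns:
            for value in batch[col]:
                if isinstance(value, bytes):  # tf.data yields string values as bytes
                    value = value.decode("utf-8")
                counters[col][value] += 1
    return {
        col: [value for value, _ in counter.most_common(max_terms)]
        for col, counter in counters.items()
    }


def build_lookups(vocabularies):
    """One StringLookup per column, built from precomputed vocabularies."""
    import tensorflow as tf  # only needed here, when the layers are built

    return {
        col: tf.keras.layers.StringLookup(vocabulary=vocab)
        for col, vocab in vocabularies.items()
    }
```

Usage would look like `lookups = build_lookups(collect_vocabularies(ds.as_numpy_iterator(), ["city", "color"]))`, where `ds` and the column names are placeholders.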

I saw this post about a new package, but I’m not sure it preprocesses the features in one pass.