Distribution transformed by an IAF bijector has a log_prob problem with tfkl.Input


I am not an expert on TF, and I am trying to implement some normalizing flows to set up exercises on density estimation. I am working with TF 2.8.2 and import the usual TF libraries:

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl

Now I run into a problem when trying to use a simple IAF layer:

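DTYPE = tf.float32

base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros([2], DTYPE), name='base_dist')

# MADE network producing a shift and a log-scale per dimension (params=2)
made = tfb.AutoregressiveNetwork(params=2, hidden_units=[512, 512], activation='relu')

# IAF = inverted MAF
flow_bijector = tfb.Invert(tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=made, name='IAF'))

trans_dist = tfd.TransformedDistribution(distribution=base_dist, bijector=flow_bijector)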

The following calls seem to work fine:

x_test = trans_dist.sample()
trans_dist.log_prob(x_test)


<tf.Tensor: shape=(), dtype=float32, numpy=-3.4625406>

However, preparing a Keras model for training raises an error:

x_ = tfkl.Input(shape=(2,), dtype=tf.float32)
log_prob_ = trans_dist.log_prob(x_)
model = tfk.Model(x_, log_prob_)
model.compile(loss=lambda _, log_prob: -tf.reduce_mean(log_prob))


----> 2 log_prob_ = trans_dist.log_prob(x_)

TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.int32, name=None), 
inferred_value=[2], name='tf.math.reduce_prod_2/Prod:0', description="created by layer 
'tf.math.reduce_prod_2'"), an intermediate Keras symbolic input/output, to a TF API that does not allow 
registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras 
Functional model construction only supports TF API calls that *do* support dispatching, such as 
`tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Kerasinputs/outputs. 
You can work around this limitation by putting the operation in a custom Keras layer `call` and calling 
that layer on this symbolic input/output.

Notice that if I use a MAF layer instead of the IAF,

flow_bijector = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2, hidden_units=[512, 512], activation='relu'),
    name='MAF')

then no error is raised.

If someone can help me, that would be great.

Update: there is a similar problem with Real NVP when using the Keras Model pattern, as I noticed in this post.

I suspect that both failures have a common origin.