Hello,

I am trying to integrate a neural network (Keras) into a probabilistic model (TensorFlow Probability's `JointDistributionCoroutineAutoBatched`). The probabilistic part works fine, but I am struggling to integrate the neural network. I would appreciate any help on this!

Something similar was done a few years ago using Edward: https://willwolf.io/2017/06/15/random-effects-neural-networks/

Here is the working probabilistic (mixed-effects) model. It is essentially the same as in this tutorial and uses the same data (radon): https://www.tensorflow.org/probability/examples/Linear_Mixed_Effects_Model_Variational_Inference

```
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# floor, county and n_counties come from the radon data used in the tutorial.
def make_joint_distribution_coroutine(floor, county, n_counties):
    def model():
        # Hyperpriors:
        # mu_alpha ~ Normal(0, 0.1)
        mu_alpha = yield tfd.Normal(loc=0., scale=.1, name='mu_alpha')
        # sigma_alpha ~ HalfNormal(0.1)
        sigma_alpha = yield tfd.HalfNormal(scale=.1, name='sigma_alpha')
        # Priors:
        # alpha ~ Normal(mu_alpha, sigma_alpha), one intercept per county
        alpha = yield tfd.Normal(loc=mu_alpha*tf.ones(n_counties),
                                 scale=sigma_alpha,
                                 name='alpha')
        # beta ~ Normal(0, 0.1)
        beta = yield tfd.Normal(loc=0., scale=.1, name='beta')
        # sigma ~ HalfNormal(0.1)
        sigma = yield tfd.HalfNormal(scale=.1, name='sigma')
        # Likelihood
        fixed_effect = floor*beta
        random_effect = tf.gather(alpha, county, axis=-1)
        mu = random_effect + fixed_effect
        yield tfd.Normal(loc=mu, scale=sigma, name='likelihood')
    return tfd.JointDistributionCoroutineAutoBatched(model)

joint = make_joint_distribution_coroutine(floor, county, n_counties)
```

Now the goal is to replace the `fixed_effect` part with a neural network. I tried the following approaches, each of which fails with an error. Since everything in the model revolves around distributions, I thought it would make sense to have the neural network return a distribution as well. Here `floor*beta` is replaced by `neural_network(floor, n_observations)`:

```
@tf.function
def neural_network(floor, n_observations):
    inputs = tf.keras.layers.Input(shape = (n_observations,), name = "inputs")
    outputs = tf.keras.layers.Dense(1, activation = 'linear', kernel_regularizer = tf.keras.regularizers.l2(.001))(inputs)
    return tfp.layers.DistributionLambda(make_distribution_fn= lambda outputs: tfd.Normal(loc=outputs, scale=1))

def make_joint_distribution_coroutine(floor, county, n_counties):
    def model():
        # Hyperpriors:
        # mu_alpha ~ Normal(0,1)
        mu_alpha = yield tfd.Normal(loc=0., scale=.1, name = 'mu_alpha')
        # sigma_alpha ~ HalfNormal(1)
        sigma_alpha = yield tfd.HalfNormal(scale=.1, name = 'sigma_alpha')
        # Priors:
        # alpha ~ Normal(mu_alpha, sigma_alpha)
        alpha = yield tfd.Normal(loc=mu_alpha*tf.ones(n_counties),
                                 scale=sigma_alpha,
                                 name='alpha')
        # sigma ~ HalfNormal(1)
        sigma = yield tfd.HalfNormal(scale=.1, name = 'sigma')
        # Likelihood
        fixed_effect = neural_network(floor, n_observations)
        random_effect = tf.gather(alpha, county, axis=-1)
        mu = random_effect + fixed_effect
        yield tfd.Normal(loc=mu, scale=sigma, name = 'likelihood')
    return tfd.JointDistributionCoroutineAutoBatched(model)

joint = make_joint_distribution_coroutine(floor, county, n_counties)
```

But this yields the following error:

```
Error in py_call_impl(callable, dots$args, dots$keywords) :
TypeError: To be compatible with tf.function, Python functions must return zero or more Tensors or ExtensionTypes or None values; in compilation of <function neural_network at 0x000001E849878310>, found return value of type DistributionLambda, which is not a Tensor or ExtensionType.
```

More in line with the Edward example posted above, I also tried the following:

```
@tf.function
def neural_network(floor, n_observations):
    inputs = tf.keras.layers.Input(shape = (n_observations,), name = "inputs")
    outputs = tf.keras.layers.Dense(1, activation = 'linear', kernel_regularizer = tf.keras.regularizers.l2(.001))(inputs)
    return tf.keras.backend.squeeze(outputs, axis=1)
```

which yields:

```
Error in py_call_impl(callable, dots$args, dots$keywords) :
TypeError: To be compatible with tf.function, Python functions must return zero or more Tensors or ExtensionTypes or None values; in compilation of <function neural_network at 0x000001E849B6B9D0>, found return value of type KerasTensor, which is not a Tensor or ExtensionType.
```
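
One possible reading of both errors: `tf.keras.layers.Input` creates a symbolic placeholder, so neither attempt ever passes the actual `floor` data through the network. The first version returns the `DistributionLambda` layer object itself, the second returns a symbolic `KerasTensor`, and `tf.function` accepts neither as a return value. The sketch below only illustrates that point, under several assumptions: the network is built once outside `tf.function`, each observation is treated as a single feature (input shape `(1,)` rather than `(n_observations,)`), and `floor` is replaced by hypothetical toy data. It has not been checked inside `JointDistributionCoroutineAutoBatched`, and the weights here are ordinary trainable variables, not random variables with priors.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the radon `floor` column (0 = basement, 1 = first floor).
n_observations = 8
floor = tf.constant(
    np.random.RandomState(0).randint(0, 2, size=n_observations),
    dtype=tf.float32)

# Build the network once, outside tf.function, so its variables are created
# a single time. Assumption: one feature per observation.
network = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(
        1,
        activation='linear',
        kernel_regularizer=tf.keras.regularizers.l2(.001)),
])

@tf.function
def neural_network(floor):
    # Reshape [n_observations] -> [n_observations, 1] so every observation
    # is one example with one feature.
    features = tf.reshape(floor, [-1, 1])
    # Calling the model on real data returns an ordinary tf.Tensor.
    outputs = network(features)
    # Drop the trailing unit dimension to get one value per observation.
    return tf.squeeze(outputs, axis=-1)

fixed_effect = neural_network(floor)
print(type(fixed_effect), fixed_effect.shape)
```

With this shape, `fixed_effect` has one entry per observation, which matches what `floor*beta` produced in the working model. Whether the call can sit directly inside the coroutine, and how the network weights should be trained alongside the other parameters, remains the open part of the question.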