"ValueError: Dimensions must be equal" after tfp.vi.fit_surrogate_posterior

Hi,

After running

losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior,
    optimizer=optimizer,
    num_steps=1000,
    seed=42,
    sample_size=100)

I receive the following error:

ValueError: Dimensions must be equal, but are 1462 and 100 for '{{node monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched_CONSTRUCTED_AT_top_level/log_prob/make_rank_polymorphic/loop_body/fn_of_vectorized_args/add}} = AddV2[T=DT_FLOAT](monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched_CONSTRUCTED_AT_top_level/log_prob/make_rank_polymorphic/loop_body/fn_of_vectorized_args/GatherV2, monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched_CONSTRUCTED_AT_top_level/log_prob/make_rank_polymorphic/loop_body/GatherV2_3)' with input shapes: [1462], [100].
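
Reading the node names, the failing op appears to be the add inside the model function (fn_of_vectorized_args/add). One input is a GatherV2, plausibly tf.gather(alpha, year, axis=-1), and the other has length 100, the sample_size passed to fit_surrogate_posterior. That reading is an assumption, not something the message confirms. TensorFlow follows NumPy's broadcasting rules, so the same incompatibility can be reproduced in plain NumPy with dummy arrays of the shapes from the message (broadcast_shape is a hypothetical helper, not part of any library):

```python
import numpy as np


def broadcast_shape(a_shape, b_shape):
    """Return the shape of a + b, or None if NumPy cannot broadcast the shapes."""
    try:
        return (np.zeros(a_shape) + np.zeros(b_shape)).shape
    except ValueError:
        return None


# Shapes copied from the error message: [1462] (per observation) and [100] (sample_size).
print(broadcast_shape((1462,), (100,)))   # None: incompatible, like the AddV2 node

# For contrast, operands with matching trailing dimensions add without trouble.
print(broadcast_shape((1462,), (1462,)))  # (1462,)
```

Trailing dimensions must be equal or 1 to broadcast, and 1462 vs. 100 is neither, so the add fails at graph-construction time.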

The model integrates a neural net into a multilevel model:

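import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras

tfd = tfp.distributions

nn_model_layers = keras.Sequential([
  keras.layers.InputLayer(input_shape=(1,)),
  keras.layers.Dense(1),
  # Outputs a Normal whose loc is the Dense layer's output.
  tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
  ])

nn_model = keras.Model(inputs=nn_model_layers.inputs,
                       outputs=nn_model_layers.outputs)
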
def make_joint_distribution_coroutine(genre, year, num_years, num_observations, nn_model):
  def model():
    # Hyperpriors:
    # mu_alpha ~ Normal(0, 0.1)
    mu_alpha = yield tfd.Normal(loc=0., scale=.1, name='alpha_mu')
    # sigma_alpha ~ HalfNormal(0.1)
    sigma_alpha = yield tfd.HalfNormal(scale=.1, name='alpha_sigma')
    # Priors:
    # alpha ~ Normal(alpha_mu, alpha_sigma), one intercept per year
    alpha = yield tfd.Normal(loc=mu_alpha*tf.ones(num_years),
                             scale=sigma_alpha,
                             name='alpha')
    # beta ~ Normal(nn_model(genre), 1), built by the DistributionLambda layer
    beta = yield nn_model(genre)
    # sigma ~ HalfNormal(0.1)
    sigma = yield tfd.HalfNormal(scale=.1, name='sigma')

    # Likelihood
    random_effect = tf.gather(alpha, year, axis=-1)
    mu = random_effect + beta
    yield tfd.Normal(loc=mu, scale=sigma, name='likelihood')
  return tfd.JointDistributionCoroutineAutoBatched(model)

The line beta = yield nn_model(genre) seems to cause the error. nn_model returns a distribution, and yielding a distribution is what a JointDistributionCoroutine model expects, right?
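
Apart from the 1462-vs-100 mismatch, there is a second shape detail on that line that may be worth checking. Assume genre is passed as a (1462, 1) array, which InputLayer(input_shape=(1,)) suggests, and that year is 1-D; both are assumptions about data not shown here. Dense(1) then returns an output of shape (N, 1), so the Normal built by the DistributionLambda has batch shape (N, 1), while tf.gather(alpha, year, axis=-1) has shape (N,). Adding the two silently broadcasts to (N, N) instead of raising an error. A NumPy sketch with dummy arrays of those shapes:

```python
import numpy as np

num_observations = 1462  # the per-observation length from the error message

# Assumed shapes: tf.gather(alpha, year, axis=-1) with a 1-D year -> (N,);
# a Dense(1) output, and hence the Normal's batch shape, -> (N, 1).
random_effect = np.zeros(num_observations)
beta = np.zeros((num_observations, 1))

# (N,) + (N, 1) broadcasts to (N, N): every observation paired with every other one.
mu = random_effect + beta
print(mu.shape)  # (1462, 1462)

# Dropping the trailing unit axis keeps one mean per observation.
mu_per_obs = random_effect + np.squeeze(beta, axis=-1)
print(mu_per_obs.shape)  # (1462,)
```

In the Keras model, the analogous change would be something like lambda t: tfd.Normal(loc=tf.squeeze(t, axis=-1), scale=1) in the DistributionLambda. This is untested, and it is not confirmed to clear the original error on its own.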

Grateful for any hints on how to resolve this.