How can I make the convert_to_tensor_fn of a tfp.layers.DistributionLambda layer depend on the Keras learning phase?

This question is about TensorFlow Probability and the tf.keras backend.

Could someone provide an example of how to set the convert_to_tensor_fn parameter of a DistributionLambda layer so that the layer returns a sample while in training and the mean otherwise?

I’m trying this with an IndependentNormal layer and am not making any progress. If I set convert_to_tensor_fn to a function that returns tf.keras.backend.in_train_phase(d.sample(), d.mean()) and then use the layer, the mean is always returned, even during training. I also tried branching on tf.keras.backend.learning_phase() inside a tf.cond() statement, without success.

One thing I noticed is that neither in_train_phase nor learning_phase is mentioned in the TensorFlow 2.11 documentation, so I’m wondering whether they are deprecated or no longer supported. They seem to have disappeared from the docs around TF 2.4.

The idea is that I’d like this layer to sample from the distribution during training, but to return the mean of the distribution during validation and inference. This might be a common use case, but I haven’t found any examples that work.
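
To make the target behavior concrete, here is a minimal sketch of what I’m after, written as a plain Keras layer that branches on the training argument of call() instead of using DistributionLambda. The SampleOrMean class, the softplus scale parameterization, and the input layout (loc and raw scale concatenated on the last axis) are assumptions made up for illustration. It is not a working convert_to_tensor_fn solution, and I have only reasoned about it in eager mode.

```python
import tensorflow as tf


class SampleOrMean(tf.keras.layers.Layer):
    """Hypothetical layer: sample from Normal(loc, scale) when training,
    return loc (the mean) otherwise."""

    def call(self, inputs, training=None):
        # Assumed layout: the last axis holds [loc, raw_scale] side by side.
        loc, raw_scale = tf.split(inputs, 2, axis=-1)
        if training:
            # Softplus keeps the scale positive (an illustrative choice).
            scale = tf.math.softplus(raw_scale)
            # Reparameterized sample: loc + scale * eps, eps ~ N(0, 1).
            return loc + scale * tf.random.normal(tf.shape(loc))
        # Validation / inference: return the mean of the distribution.
        return loc


layer = SampleOrMean()
params = tf.constant([[0.0, 1.0], [2.0, 1.0]])

mean_out = layer(params, training=False)   # the means, shape (2, 1)
sample_out = layer(params, training=True)  # random draws, shape (2, 1)
```

What I would like is the same switch between sample and mean, but driven through convert_to_tensor_fn on the IndependentNormal layer, so that the layer still outputs a distribution object.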