Poisson matrix factorization in TFP with variational inference and surrogate posterior

Dear TensorFlow community,

I am new to TensorFlow and have run into a problem while implementing Poisson matrix factorization (on the Wisconsin Breast Cancer dataset). The problem is the following:

I wish to build a gamma/Poisson version of the PPCA model described here:

I have defined my model with gamma and Poisson distributions, along with the target log joint probability (the unnormalised posterior), and initialised the two latent variables in my model (u, v). That is:

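import functools

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# x_train and holdout_mask are assumed to be defined earlier
N, M = x_train.shape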
L = 5
min_scale = 1e-5
mask = 1-holdout_mask

# Number of data points = N, data dimension = M, latent dimension = L
def pmf_model(M, L, N, gamma_prior = 0.1, mask = mask):
    v = yield tfd.Gamma(concentration = gamma_prior * tf.ones([L, M]),
                        rate = gamma_prior * tf.ones([L, M]),
                        name = "v")  # parameter
    u = yield tfd.Gamma(concentration = gamma_prior * tf.ones([N, L]),
                        rate = gamma_prior * tf.ones([N, L]),
                        name = "u")  # local latent variable
    x = yield tfd.Poisson(rate = tf.multiply(tf.matmul(u, v), mask), name="x")  # (modeled) data
    
pmf_model(M = M, L = L, N = N, mask = mask)

concrete_pmf_model = functools.partial(pmf_model,
                                       M = M,
                                       L = L,
                                       N = N,
                                       mask = mask)

model = tfd.JointDistributionCoroutineAutoBatched(concrete_pmf_model)

# Target log joint probability
target_log_prob_fn = lambda v, u: model.log_prob((v, u, x_train))

# Initialize v and u as TensorFlow variables
v = tf.Variable(tf.random.gamma([L, M], alpha = 0.1))
u = tf.Variable(tf.random.gamma([N, L], alpha = 0.1))
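
One thing that may be worth checking in the model above: multiplying the Poisson rate by `mask` sets the rate to exactly zero at held-out entries. The plain-Python sketch below (standard library only, not TFP code; the function names and the toy numbers are made up for illustration) evaluates the Poisson log-pmf by hand. It shows that a single nonzero count paired with a zero rate drives the whole log-likelihood to minus infinity, whereas masking the per-entry log-probabilities keeps it finite. Whether this actually happens with x_train depends on whether its held-out entries are nonzero, which is an assumption here.

```python
import math


def poisson_log_prob(count, rate):
    """Poisson log-pmf, using the convention 0 * log(0) = 0."""
    if rate == 0.0:
        # A zero rate puts all mass on count == 0.
        return 0.0 if count == 0 else -math.inf
    return count * math.log(rate) - rate - math.lgamma(count + 1)


def log_lik_masked_rate(counts, rates, mask):
    # Mask applied to the rate, mirroring rate = (u @ v) * mask.
    return sum(poisson_log_prob(c, r * m)
               for c, r, m in zip(counts, rates, mask))


def log_lik_masked_terms(counts, rates, mask):
    # Mask applied to the per-entry log-probabilities instead:
    # held-out entries are simply left out of the sum.
    return sum(poisson_log_prob(c, r)
               for c, r, m in zip(counts, rates, mask) if m == 1)


# Toy data (hypothetical): the third entry is held out but has a nonzero count.
counts = [3, 0, 2]
rates = [2.5, 1.0, 4.0]
mask = [1, 1, 0]

masked_rate_value = log_lik_masked_rate(counts, rates, mask)
masked_terms_value = log_lik_masked_terms(counts, rates, mask)
print("mask on rate: ", masked_rate_value)
print("mask on terms:", masked_terms_value)
```

An infinite log-probability usually does not raise an error by itself, but gradients taken through it are typically not finite, which would be consistent with parameters turning into NaN after an optimizer step.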

Then I need to declare the trainable variables/parameters, which I do in the following (possibly wrong) way:


qV_variable0 = tf.Variable(tf.random.uniform([L, M]))
qU_variable0 = tf.Variable(tf.random.uniform([N, L]))

qV_variable1 = tf.maximum(tfp.util.TransformedVariable(tf.random.uniform([L, M]), 
                                                        bijector=tfb.Softplus()), min_scale)
qU_variable1 = tf.maximum(tfp.util.TransformedVariable(tf.random.uniform([N, L]), 
                                                        bijector=tfb.Softplus()), min_scale)
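
For comparison, here is a minimal plain-Python sketch (standard library only, not TFP code; all function names are hypothetical) of another way to keep a scale above `min_scale`: adding the floor to a softplus instead of clipping with a maximum. Clipping has zero gradient wherever the clip is active, while the shifted softplus is smooth everywhere and can be inverted to initialise the unconstrained parameter. In TFP the analogous construction would presumably be a `TransformedVariable` whose bijector chains a shift with a softplus, e.g. `tfb.Chain([tfb.Shift(min_scale), tfb.Softplus()])`, but that is an assumption and is not exercised here.

```python
import math

MIN_SCALE = 1e-5  # same floor as min_scale in the post


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def softplus_inverse(y):
    """Inverse of softplus for y > 0, i.e. log(exp(y) - 1)."""
    return y + math.log(-math.expm1(-y))


def constrained_scale(raw):
    # Maps any real number to a value that is never below MIN_SCALE.
    return MIN_SCALE + softplus(raw)


def unconstrained_from_scale(scale):
    # Inverse map, useful for picking the initial unconstrained value.
    if scale <= MIN_SCALE:
        raise ValueError("scale must be strictly greater than MIN_SCALE")
    return softplus_inverse(scale - MIN_SCALE)


for raw in (-50.0, -1.0, 0.0, 3.0, 40.0):
    print(raw, constrained_scale(raw))
```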


Finally, I define the surrogate posterior, fit it, and record the losses:

def factored_pmf_variational_model():
    
    qv = yield tfd.TransformedDistribution(distribution = tfd.Normal(loc = qV_variable0,
                                                                     scale = qV_variable1), 
                                           bijector = tfb.Exp(),
                                           name = "qv")
    
    qu = yield tfd.TransformedDistribution(distribution = tfd.Normal(loc = qU_variable0,
                                                                     scale = qU_variable1), 
                                           bijector = tfb.Exp(),
                                           name = "qu")

surrogate_posterior = tfd.JointDistributionCoroutineAutoBatched(
    factored_pmf_variational_model)


losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    num_steps=500)
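
As a reference point for what the `Exp` bijector in the surrogate does, the following is a minimal plain-Python sketch (standard library only, not TFP code; function names are hypothetical) of the change-of-variables rule for Y = exp(Z) with Z normally distributed. The log-density of Y is the normal log-density at log(y) plus the log-Jacobian of the inverse map, which is -log(y). The numerical integral at the end is only a sanity check that the resulting density is properly normalised on the positive reals.

```python
import math


def normal_log_prob(z, loc, scale):
    """Log-density of Normal(loc, scale) at z."""
    return (-0.5 * ((z - loc) / scale) ** 2
            - math.log(scale)
            - 0.5 * math.log(2.0 * math.pi))


def exp_transformed_log_prob(y, loc, scale):
    """Log-density of Y = exp(Z), Z ~ Normal(loc, scale), for y > 0."""
    z = math.log(y)                 # inverse of the exp transform
    inverse_log_det_jacobian = -z   # log |d log(y) / dy| = -log(y)
    return normal_log_prob(z, loc, scale) + inverse_log_det_jacobian


def integrate_density(loc, scale, upper=20.0, num_cells=200_000):
    """Midpoint-rule integral of the transformed density over (0, upper)."""
    width = upper / num_cells
    total = 0.0
    for i in range(num_cells):
        y = (i + 0.5) * width
        total += math.exp(exp_transformed_log_prob(y, loc, scale)) * width
    return total


total_mass = integrate_density(loc=0.0, scale=0.5)
print("integral of the transformed density:", total_mass)
```

The transformed variable is always positive, which matches the support of the gamma priors on u and v.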

My code does NOT raise an error. However, after running the entire script shown here, the trainable parameters qV_variable0 and qU_variable0 are NaN. Could some kind person tell me why it goes wrong? It would also be lovely to see a demonstration of how to use bijectors correctly in TensorFlow Probability distributions for models fitted with variational inference. Please also let me know whether it is my target model or my understanding of the surrogate posterior that is wrong.

Thank you so much in advance!

Maybe @Christopher_Suter can help.
