Poisson matrix factorization in TFP. Why are losses = Inf and variables not updated?

Hi,

When I run the following code, my TensorFlow variables (the parameters of the variational distribution) are not updated and the losses are Inf. Can anyone help? Thanks!

PS: I believe my mistake is in the step “# Initialize the parameters in the variational distribution” and/or the step “# Variational model and surrogate posterior” in the code below:

from sklearn.datasets import load_breast_cancer
# Load the breast-cancer data set and keep its first `num_fea` feature columns.
data = load_breast_cancer()
num_fea = 30
df = pd.DataFrame(data["data"][:, :num_fea],
                  columns=data["feature_names"][:num_fea])

# Standardize every column (subtract the column mean, divide by the column
# standard deviation). The frame built above is `df`; the previous code read
# an undefined name `dfX` here and therefore raised a NameError.
# NOTE(review): centered data necessarily contains negative values, which lie
# outside the support of the Poisson likelihood used in `pmf_model` -- this
# presumably makes the log-probability -inf and the loss Inf. Confirm, and
# consider feeding non-negative data to the Poisson model instead.
X = np.array((df - df.mean()) / df.std())
num_datapoints, data_dim = X.shape

# Fraction of matrix entries to draw (with replacement) for the hold-out set.
holdout_portion = 0.2
n_holdout = int(holdout_portion * num_datapoints * data_dim)

# Random (row, column) positions of the held-out entries; rows are drawn first.
holdout_row = np.random.randint(num_datapoints, size=n_holdout)
holdout_col = np.random.randint(data_dim, size=n_holdout)
holdout_subjects = np.unique(holdout_row)

# 0/1 indicator matrix of held-out entries. Writing 1 at each sampled position
# means a position that was drawn several times still ends up as exactly 1.
holdout_mask = np.zeros(X.shape)
holdout_mask[holdout_row, holdout_col] = 1.0

# Validation data keeps only the held-out entries, training data only the
# rest; the other entries are zeroed out in each case.
x_vad = holdout_mask * X
x_train = (1 - holdout_mask) * X

num_datapoints, data_dim = x_train.shape
latent_dim = 5
def pmf_model(data_dim, latent_dim, num_datapoints, mask, gamma_prior=0.1):
    """Coroutine describing a masked Poisson matrix factorization.

    Yields, in this order, the distributions named "w", "z" and "x".

    Args:
        data_dim: number of columns of "w".
        latent_dim: number of rows of "w" and number of columns of "z".
        num_datapoints: number of rows of "z".
        mask: multiplied element-wise into the Poisson rate `z @ w`.
        gamma_prior: value used for both the concentration and the rate of
            every Gamma prior entry.
    """
    # Constant prior-parameter tensors, one per latent matrix.
    w_prior_param = gamma_prior * tf.ones([latent_dim, data_dim])
    z_prior_param = gamma_prior * tf.ones([num_datapoints, latent_dim])

    loadings = yield tfd.Gamma(concentration=w_prior_param,
                               rate=w_prior_param,
                               name="w")  # parameter
    scores = yield tfd.Gamma(concentration=z_prior_param,
                             rate=z_prior_param,
                             name="z")  # local latent variable / substitute confounder

    # Entries where `mask` is zero get a Poisson rate of zero.
    masked_rate = tf.multiply(tf.matmul(scores, loadings), mask)
    yield tfd.Poisson(rate=masked_rate, name="x")  # (modeled) data
    


# Bind the problem sizes and the observation mask to the generative model.
# The name `mask` was never defined in this script, so the previous
# `mask=mask` raised a NameError. The entries used for training are the
# complement of the held-out ones, i.e. `1 - holdout_mask`, which is the same
# factor that is used to build `x_train`.
concrete_pmf_model = functools.partial(pmf_model,
                                       data_dim=data_dim,
                                       latent_dim=latent_dim,
                                       num_datapoints=num_datapoints,
                                       mask=1 - holdout_mask)

# Joint distribution over (w, z, x) defined by the coroutine above.
model = tfd.JointDistributionCoroutineAutoBatched(concrete_pmf_model)


# Initialize w and z as TensorFlow variables holding gamma draws (alpha = 0.1).
# Neither variable is referenced again in the code shown here: the `w` and `z`
# inside `target_log_prob_fn` are that function's own parameters.
w = tf.Variable(tf.random.gamma([latent_dim, data_dim], alpha=0.1))
z = tf.Variable(tf.random.gamma([num_datapoints, latent_dim], alpha=0.1))


# Target log joint probability
def target_log_prob_fn(w, z):
    """Return the model's log joint probability of (w, z) and `x_train`."""
    return model.log_prob((w, z, x_train))

# Initialize the parameters in the variational distribution
#
# Every parameter must own a tf.Variable, otherwise the optimizer has nothing
# to update and the parameters keep their initial values:
#   * the `*_conc` parameters used to be plain tensors returned by
#     tf.random.uniform; they are now wrapped in tf.Variable;
#   * the `*_rate` parameters used to be passed through tf.maximum(...), which
#     converts the TransformedVariable into an ordinary tensor; they are now
#     kept as TransformedVariables. The Softplus bijector keeps them positive,
#     and `minval=1e-5` keeps the initial value away from 0, where the inverse
#     of Softplus is not finite.
# The names keep their "conc"/"rate" suffixes because the surrogate posterior
# reads them under these names; there they serve as the loc and the scale of
# a Normal distribution.
qw_conc = tf.Variable(
    tf.random.uniform([latent_dim, data_dim], minval=1e-5))
qz_conc = tf.Variable(
    tf.random.uniform([num_datapoints, latent_dim], minval=1e-5))
qw_rate = tfp.util.TransformedVariable(
    tf.random.uniform([latent_dim, data_dim], minval=1e-5),
    bijector=tfb.Softplus())
qz_rate = tfp.util.TransformedVariable(
    tf.random.uniform([num_datapoints, latent_dim], minval=1e-5),
    bijector=tfb.Softplus())

# Variational model and surrogate posterior:
def factored_gamma_variational_model():
    """Coroutine yielding the two independent factors "qw" and "qz".

    Despite the function name, each factor is a Normal distribution pushed
    through `tfb.Exp()`, so its samples are positive. `q*_conc` is used as the
    Normal's loc and `q*_rate` as its scale.
    """
    factor_specs = (
        ("qw", qw_conc, qw_rate),
        ("qz", qz_conc, qz_rate),
    )
    for factor_name, loc, scale in factor_specs:
        yield tfd.TransformedDistribution(
            distribution=tfd.Normal(loc=loc, scale=scale),
            bijector=tfb.Exp(),
            name=factor_name)

# Turn the variational coroutine into a joint distribution over (qw, qz).
surrogate_posterior = tfd.JointDistributionCoroutineAutoBatched(
    factored_gamma_variational_model
)

# Fit the surrogate posterior to the target log joint: 500 optimization steps
# with Adam at a learning rate of 0.05. `losses` is whatever
# fit_surrogate_posterior returns (presumably its per-step loss trace).
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    num_steps=500,
)

# Bare expression: displays the result when run as the last line of a cell.
losses