# TFP with JAX backend: how to train a bijector?

Dear experts,
I have been trying to port this demo of training a custom Bijector to the JAX backend.

I am running on Google Colab. Here is the part of the code that seems to work; the problems appear later, in the training code.

``````
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
mpl.rc('image', cmap='jet')
mpl.rcParams['font.size'] = 16

import jax
import jax.numpy as jnp
import numpy as np

import tensorflow.compat.v2 as tf
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

class Cubic(tfb.Bijector):
    """Bijector implementing y = (a*x + b)**3."""

    def __init__(self, a, b, validate_args=False, name='Cubic'):
        self.a = jnp.atleast_1d(a)
        self.b = jnp.atleast_1d(b)
        super(Cubic, self).__init__(
            validate_args=validate_args, forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        return jnp.squeeze(jnp.power(self.a*x + self.b, 3))

    def _inverse(self, y):
        return (jnp.sign(y)*jnp.power(jnp.abs(y), 1/3) - self.b)/self.a

    def _forward_log_det_jacobian(self, x):
        # log|dy/dx| = log(3*|a|) + 2*log|a*x + b|
        return jnp.log(3.*jnp.abs(self.a)) + 2.*jnp.log(jnp.abs(self.a*x + self.b))

cubic = Cubic([0.25],[-0.1])
x = jnp.linspace(-10,10,500).reshape(-1,1)
plt.plot(x,cubic.forward(x))
plt.show()

plt.plot(x,cubic.inverse(x),lw=3)
plt.plot(x,tfb.Invert(cubic).forward(x),ls="--",c="cyan")
plt.show()

plt.plot(x,cubic.forward_log_det_jacobian(x,event_ndims=0))
plt.show()

# Target distrib
probs = [0.45,0.55]
mix_gauss = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs),
    components_distribution=tfd.Normal(
        loc=[2.3, -0.8],    # One for each component.
        scale=[0.4, 0.4]))  # And same here.

x = jnp.linspace(-5.0,5.0,100)
plt.plot(x,mix_gauss.prob(x))
plt.title('Data distribution')
plt.show()

``````
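As a quick sanity check (my own addition, with made-up test values), the analytic `forward_log_det_jacobian` can be compared against the derivative computed by `jax.grad`, and `forward(inverse(y))` should recover `y`:

``````
# Consistency checks on the bijector defined above.
y = jnp.linspace(-3.0, 3.0, 7)
print(jnp.allclose(cubic.forward(cubic.inverse(y)), y, atol=1e-5))

# The analytic log|dy/dx| should match the value obtained by autodiff.
x0 = 0.7
autodiff_ldj = jnp.log(jnp.abs(jax.grad(lambda x: cubic.forward(x))(x0)))
print(autodiff_ldj, cubic.forward_log_det_jacobian(x0, event_ndims=0))
``````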

Now I would like to build a training and a validation dataset, so I proceed as follows:

``````
x_train = mix_gauss.sample(10000, seed = jax.random.PRNGKey(0))
x_train = tf.data.Dataset.from_tensor_slices(x_train)
x_train = x_train.batch(128)

x_valid = mix_gauss.sample(1000, seed = jax.random.PRNGKey(1))
x_valid = tf.data.Dataset.from_tensor_slices(x_valid)
x_valid = x_valid.batch(128)
``````
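Side note: since the samples are already JAX arrays, I also considered skipping `tf.data` and batching them directly (just a sketch of an alternative, with a helper `make_batches` that I made up; I don't know which approach is preferred):

``````
def make_batches(samples, batch_size=128):
    # Drop the remainder so the samples reshape cleanly into (num_batches, batch_size).
    n_batches = samples.shape[0] // batch_size
    return samples[:n_batches * batch_size].reshape(n_batches, batch_size)

x_train_batches = make_batches(mix_gauss.sample(10000, seed=jax.random.PRNGKey(0)))
x_valid_batches = make_batches(mix_gauss.sample(1000, seed=jax.random.PRNGKey(1)))
``````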

Then,

``````
trainable_inv_cubic = tfb.Invert(Cubic(a=0.25,b=-0.1))
# Base distribution
normal = tfd.Normal(loc=0.,scale=1.)
# Trainable distribution
trainable_dist = tfd.TransformedDistribution(normal,trainable_inv_cubic)

x = jnp.linspace(-5,5,100)
plt.figure(figsize=(12,4))
plt.plot(x,mix_gauss.prob(x),label='data')
plt.plot(x,trainable_dist.prob(x),label='trainable')
plt.title('Data & Trainable distribution')
plt.legend()
plt.show()
``````

The problem is the following: since there are no `trainable_inv_cubic.trainable_variables`, the gradient computation cannot be done.

``````
num_epochs = 10
opt = tf.keras.optimizers.Adam()
train_losses = []
valid_losses = []

for epoch in range(num_epochs):
    print("Epoch {}...".format(epoch))
    train_loss = tf.keras.metrics.Mean()
    val_loss = tf.keras.metrics.Mean()

    # Train
    for train_batch in x_train:
        with tf.GradientTape() as tape:
            tape.watch(trainable_inv_cubic.trainable_variables)
            loss = -trainable_dist.log_prob(train_batch)
        train_loss(loss)
        grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
        opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
    train_losses.append(train_loss.result().numpy())

    # Validation
    for valid_batch in x_valid:
        loss = -trainable_dist.log_prob(valid_batch)
        val_loss(loss)
    valid_losses.append(val_loss.result().numpy())
``````

I am used to performing optimization with pure JAX code and I am not a TF expert at all, so if someone could help me with this translation I would be very grateful. Thanks!
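For concreteness, this is the kind of pure-JAX loop I would have tried instead, rebuilding the distribution from explicit parameters `a` and `b` inside the loss so that `jax.grad` can differentiate with respect to them (only a sketch under my own assumptions, using plain gradient descent and the `x_train_batches` from the batching sketch above; I don't know whether this is the intended way to make a TFP bijector trainable on the JAX substrate):

``````
num_epochs = 10
learning_rate = 1e-2

def make_dist(params):
    # Rebuild the trainable distribution from explicit parameters.
    bijector = tfb.Invert(Cubic(a=params['a'], b=params['b']))
    return tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bijector)

def nll(params, batch):
    # Mean negative log-likelihood of one batch under the trainable distribution.
    return -jnp.mean(make_dist(params).log_prob(batch))

@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(nll)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params, loss

params = {'a': jnp.array([0.25]), 'b': jnp.array([-0.1])}
train_losses = []
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in x_train_batches:
        params, loss = train_step(params, batch)
        epoch_losses.append(loss)
    train_losses.append(float(jnp.mean(jnp.stack(epoch_losses))))
``````

Is something along these lines the recommended pattern, or is there a way to get `trainable_variables` to work with the JAX substrate?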

(It may also be a bug: the bijector forward function weakly caches the result->input mapping to make downstream inverses and log-determinants fast, but somehow this also interferes with the gradient. A suggested workaround is adding a `del out`.)

Here is the error message

``````
AttributeError                            Traceback (most recent call last)

<ipython-input-17-9a8348b0b336> in <module>
15             loss = -trainable_dist.log_prob(train_batch)
16         train_loss(loss)
19     train_losses.append(train_loss.result().numpy())

1 frames