TFP JAX: TransformedTransitionKernel drastically decreases NUTS speed

Dear all,

I am currently learning Bayesian analysis and using tensorflow_probability.substrates.jax, but I’ve run into a performance issue. With jax and jit, NUTS alone is quite fast. However, as soon as it is combined with TransformedTransitionKernel, the speed drops drastically. Here is a summary of the time taken (a rough sketch of the kernel stack follows the list):

  • TFP GPU: NUTS alone took 118.2952 seconds
  • TFP GPU: NUTS + Bijector took 1986.8306 seconds
  • TFP GPU: NUTS + DualAveragingStepSizeAdaptation took 141.0955 seconds
  • TFP GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 2397.5875 seconds
  • NumPyro GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 180 seconds
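
For reference, the kernel stack is composed roughly like this. This is a minimal sketch with a placeholder model, data, and step sizes, not the exact code from the Colab:

```python
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Placeholder data and target, just to show the kernel composition.
observed = jnp.array([0.3, -1.2, 0.8, 1.5])

def target_log_prob_fn(loc, scale):
    # Simple target with a positivity-constrained scale parameter.
    lp = tfd.Normal(0., 1.).log_prob(loc)
    lp += tfd.HalfNormal(1.).log_prob(scale)
    lp += jnp.sum(tfd.Normal(loc, scale).log_prob(observed))
    return lp

num_burnin = 500

nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,
    step_size=[jnp.array(0.1), jnp.array(0.1)])

# Constrain `scale` to be positive while NUTS moves in unconstrained space.
transformed = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts,
    bijector=[tfb.Identity(), tfb.Softplus()])

adaptive = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=transformed,
    num_adaptation_steps=int(0.8 * num_burnin),
    target_accept_prob=0.8)

@jax.jit
def run_chain(seed):
    return tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=num_burnin,
        current_state=[jnp.array(0.), jnp.array(1.)],
        kernel=adaptive,
        # NUTS results sit two levels down: adaptation -> transformed -> NUTS.
        trace_fn=lambda _, pkr: pkr.inner_results.inner_results.target_log_prob,
        seed=seed)

samples, log_probs = run_chain(jax.random.PRNGKey(0))
```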

I’ve run speed comparisons against NumPyro, and essentially NumPyro with dual-averaging step-size adaptation and parameter constraints is about as fast as tensorflow_probability’s NUTS alone. The NumPyro run was set up roughly as sketched below.
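
Again a sketch with a placeholder model, not the exact code; NumPyro applies the constraining transforms and dual-averaging step-size adaptation during warmup by default:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Placeholder model mirroring the TFP sketch above: the HalfNormal scale
# is automatically transformed to unconstrained space by NumPyro.
def model(y=None):
    loc = numpyro.sample("loc", dist.Normal(0., 1.))
    scale = numpyro.sample("scale", dist.HalfNormal(1.))
    numpyro.sample("y", dist.Normal(loc, scale), obs=y)

y = jnp.array([0.3, -1.2, 0.8, 1.5])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), y=y)
mcmc.print_summary()
```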

Could there be something I’ve missed? Is there room for optimization in this process?

Please find the data and code for reproducibility below:

Data

Google Colab

Please note that I’m only using the first 100 lines of the data.

Additionally, as a potential cause, I observed a similar slowdown when using the LKJ distribution in other models (I can post one of them if needed); a sketch of the pattern is below.
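
For example, the pattern looks roughly like this (placeholder dimension and concentration, not the actual model):

```python
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Sampling a correlation Cholesky factor: CholeskyLKJ prior, constrained
# through the CorrelationCholesky bijector so NUTS works in unconstrained space.
def lkj_log_prob(chol_corr):
    return tfd.CholeskyLKJ(dimension=3, concentration=2.).log_prob(chol_corr)

lkj_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.NoUTurnSampler(lkj_log_prob, step_size=0.1),
    bijector=tfb.CorrelationCholesky())
```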

Thank you in advance for your assistance.

Sebastian