Distribution parameter as a function of a trainable variable

I am currently trying to fit a hidden Markov model (HMM) with a time-varying transition distribution using the Adam optimizer. To make the transition distribution depend on some time-varying covariates, I express its logits as a function of a trainable variable (a set of regression coefficients). However, I'm not sure how to train the variable that the distribution's parameters depend on, rather than training the distribution's parameters directly as I've done in the past. Here is the code I have so far:

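import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
rng = np.random.default_rng()

# Set the number of states and covariates for the model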
num_states = 2
num_covariates = 1

# Randomly initialize the initial state distribution
initial_logits = tf.Variable(rng.random([num_states]), name='initial_logits', dtype=tf.float32)
initial_distribution = tfd.Categorical(logits=initial_logits)

# Randomly initialize the regression coefficients.
# This is the variable I want to train that my transition distribution's logits depend on
regression_coeffs = tf.Variable(
    rng.random([num_covariates + 1, num_states, num_states]) * (1 - np.diag([1] * num_states)),
    dtype=tf.float32)

# Randomly generate some covariates
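# Cast to float32 so the covariates can be multiplied with the float32 coefficients
covariates = rng.random([num_observations - 1, num_covariates]).astype(np.float32)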

# Compute the transition logits for each time step
# I think this is the part that I need to rewrite in order to allow the regression_coeffs variable to be trained
transition_logits = []
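for i in range(num_observations - 1):
    # Intercept matrix plus the covariate-weighted sum of the slope matrices
    transition_logits.append(
        regression_coeffs[0]
        + tf.reduce_sum(
            regression_coeffs[1:] * covariates[i].reshape(num_covariates, 1, 1),
            axis=0))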

# Package up all of the logits for each time step into a tensor and create the transition distribution from them
transition_logits = tf.Variable(np.array(transition_logits), name='transition_logits', dtype=tf.float32, trainable=False)
transition_distribution = tfd.Categorical(logits=transition_logits)

# I have some other stuff to define my observation distribution...

# Defining my hidden markov model
hmm = tfd.HiddenMarkovModel(
    initial_distribution = initial_distribution,
    transition_distribution = transition_distribution,
    observation_distribution = joint_dists,
    num_steps = num_observations,
    time_varying_transition_distribution = True
)

# Define a loss function
def compute_loss():
    return -tf.reduce_logsumexp(hmm.log_prob(observations))

# Define an optimizer to update the trainable variables from their gradients
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
criterion = tfp.optimizer.convergence_criteria.LossNotDecreasing(rtol=0.01)

def trace_fn(traceable_quantities):
    return traceable_quantities.loss

# And finally, training the HMM
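loss_history = tfp.math.minimize(
    compute_loss,
    num_steps=1000,  # hypothetical step count
    optimizer=optimizer,
    convergence_criterion=criterion,
    trace_fn=trace_fn)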

Please let me know if I should clarify anything or include more of my code. Any help would be greatly appreciated!
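Why the code above does not train `regression_coeffs` can be reproduced in a few lines of plain TensorFlow. The logits are computed once, outside any gradient tape, and then copied into a new `tf.Variable` with `trainable=False`, so the loss has no recorded path back to the coefficients. The sketch below uses a toy variable `coeffs` and a toy squared loss (both hypothetical stand-ins for `regression_coeffs` and the HMM loss) to compare that pattern with computing the logits inside the tape.

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for regression_coeffs.
coeffs = tf.Variable([1.0, 2.0])

# Pattern from the question: logits computed once, outside any tape,
# then copied into a new non-trainable Variable.
frozen_logits = tf.Variable(np.array(coeffs * 3.0), trainable=False)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(frozen_logits ** 2)
# The loss never reads `coeffs`, so no gradient path exists.
grad_frozen = tape.gradient(loss, coeffs)

# Alternative: recompute the logits from `coeffs` inside the tape each step.
with tf.GradientTape() as tape:
    logits = coeffs * 3.0
    loss = tf.reduce_sum(logits ** 2)
grad_live = tape.gradient(loss, coeffs)

print(grad_frozen)        # None
print(grad_live.numpy())  # [18. 36.]
```

Since `tfp.math.minimize` evaluates the loss function under a gradient tape on every step, whatever is computed inside that function (and only that) is differentiable with respect to the variables it reads.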

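OK, I figured out a solution. I changed my loss function to recompute the transition logits and rebuild the HMM on every training step. Because the logits are now computed from regression_coeffs inside the loss function, which tfp.math.minimize evaluates under a gradient tape, the gradient can flow back to the coefficients. My new loss function looks like this: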

def compute_loss():
    hmm = tfd.HiddenMarkovModel(
        initial_distribution = initial_distribution,
        transition_distribution = tfd.Categorical(logits=get_transition_logits()),
        observation_distribution = joint_dists,
        num_steps = num_observations,
        time_varying_transition_distribution = True
    )
    return -tf.reduce_logsumexp(hmm.log_prob(observations))
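The answer calls a `get_transition_logits()` helper without showing it. Below is a minimal sketch of what it could look like, assuming it implements the same regression as the loop in the question (an intercept matrix plus a covariate-weighted sum of slope matrices). It swaps the Python loop for a single `tf.einsum`; the helper body, the `off_diag` mask, the seed, and `num_observations = 5` are hypothetical choices made only so the example runs. Everything is built from TensorFlow ops, so calling it inside the loss function keeps the gradient path to `regression_coeffs`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes for illustration (num_observations is not given in the post).
num_states = 2
num_covariates = 1
num_observations = 5

rng = np.random.default_rng(0)

# Zero the diagonal at initialization, mirroring the question's (1 - I) factor.
off_diag = (1.0 - np.eye(num_states)).astype(np.float32)
regression_coeffs = tf.Variable(
    rng.random([num_covariates + 1, num_states, num_states]).astype(np.float32) * off_diag,
    name='regression_coeffs')

# One row of covariates per transition (num_observations - 1 transitions).
covariates = tf.constant(
    rng.random([num_observations - 1, num_covariates]).astype(np.float32))


def get_transition_logits():
    """Hypothetical helper: logits[t] = coeffs[0] + sum_k covariates[t, k] * coeffs[k + 1].

    Returns a [num_observations - 1, num_states, num_states] tensor built only
    from TF ops, so gradients can flow back to regression_coeffs.
    """
    intercept = regression_coeffs[0]   # [S, S]
    slopes = regression_coeffs[1:]     # [K, S, S]
    # 'tk,kij->tij' sums over covariates k for every time step t.
    return intercept + tf.einsum('tk,kij->tij', covariates, slopes)


# Cross-check against the per-step loop from the question.
loop_logits = tf.stack([
    regression_coeffs[0]
    + tf.reduce_sum(
        regression_coeffs[1:] * tf.reshape(covariates[i], [num_covariates, 1, 1]),
        axis=0)
    for i in range(num_observations - 1)
])

with tf.GradientTape() as tape:
    logits = get_transition_logits()
    # Toy scalar objective standing in for -hmm.log_prob(observations).
    loss = -tf.reduce_sum(tf.nn.log_softmax(logits, axis=-1)[..., 0])

grad = tape.gradient(loss, regression_coeffs)
print(logits.shape)      # (4, 2, 2)
print(grad is not None)  # True
```

Because the helper only reads `regression_coeffs` and `covariates`, it can be called inside `compute_loss` as in the answer. In this sketch (as in the question) the off-diagonal mask is applied only at initialization, so the diagonal logits will drift during training; if they are meant to stay at zero, multiplying the helper's result by `off_diag` would pin them.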