Using a mixture distribution with hidden Markov models

Hello, I am currently trying to fit a hidden Markov model on a sequence of observations that consist of angles and speeds. The observations are taken from a dataset in which the amount of time between them is not constant. However, I want to model a sequence of states with a constant amount of time between them. In order to do this, I am trying to model “missing” observations, i.e. states for which an observation does not exist. I am trying to do this with a mixture distribution where the first element of each event indicates whether the observation is missing or not. Missing observations therefore correspond to [0, 0, 0], while non-missing observations correspond to [1, angle, speed], where angle is in [-pi, pi] and speed is in [0, infinity). However, I am getting an error I don’t understand when trying to compute the log probability of the hidden Markov model. Here is my current code:

# assume 2 states
# Number of hidden states in the HMM.
num_states = 2

# Randomly initialize the initial state distribution as well as the transition probabilities
# softmax turns the random draws into valid distributions; for the transition
# matrix, axis=1 normalizes each row (the "from" state) so it sums to 1.
# NOTE(review): `rng` is not defined in this snippet — presumably a NumPy
# Generator such as np.random.default_rng(); confirm.
initial_probs = tf.Variable(scipy.special.softmax(rng.random([num_states])), name='initial_probs', dtype=tf.float32)
transition_probs = tf.Variable(scipy.special.softmax(rng.random([num_states, num_states]), axis=1), name='transition_probs', dtype=tf.float32)

# Initialize locations and concentrations of Von Mises distributions for turning angles
# (one per state; a concentration of 0 makes each Von Mises start out uniform on the circle)
vm_locs = tf.Variable(np.zeros(num_states), dtype=tf.float32)
vm_cons = tf.Variable(np.zeros(num_states), dtype=tf.float32)

# Initialize shapes and rates of Gamma distributions for speed
# (one per state; shape = rate = 1 starts each Gamma as an Exponential(1))
gamma_shapes = tf.Variable(np.ones(num_states), dtype=tf.float32)
gamma_rates = tf.Variable(np.ones(num_states), dtype=tf.float32)

# Per-state observation model (batch shape [num_states]): a 50/50 mixture of a
# "missing" component and an "observed" [flag, angle, speed] component.
# NOTE(review): tfd.Mixture evaluates the log_prob of every value under EVERY
# component (see mixture.py `_log_prob` in the traceback quoted below). The
# first component is a Categorical over {0, 1}, so real-valued samples from
# the second component (e.g. a negative Von Mises angle) are not valid labels
# for it — most likely the source of "Received a label value of -1".
mixed_dists = tfd.Mixture(
    cat=tfd.Categorical(probs=[[0.5, 0.5]] * num_states),
    components=[
        # "Missing": probs [1.0, 0.0] means each of the 3 coordinates is always 0.
        tfd.Independent(tfd.Categorical(probs=[[[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]] * num_states, dtype=tf.float32), reinterpreted_batch_ndims=1),
        # "Observed": flag (always 1, from probs [0.0, 1.0]), turning angle, speed.
        tfd.Blockwise([
            tfd.Categorical(probs=[[0.0, 1.0]] * num_states),
            tfd.VonMises(loc=vm_locs, concentration=vm_cons),
            tfd.Gamma(concentration=gamma_shapes, rate=gamma_rates)
        ], dtype_override=tf.float32)
    ]
)

# HMM over 15 time steps; mixed_dists supplies one observation distribution
# per hidden state.
hmm3 = tfd.HiddenMarkovModel(
    initial_distribution = tfd.Categorical(probs=initial_probs),
    transition_distribution = tfd.Categorical(probs=transition_probs),
    observation_distribution = mixed_dists,
    num_steps = 15
)

hmm3.log_prob(hmm3.sample()) # error here: raises the InvalidArgumentError quoted below

The error I am getting is this:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-638-bade53a804d5> in <module>
----> 1 hmm3.log_prob(observations[:15])
~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in log_prob(self, value, name, **kwargs)
   1294         values of type `self.dtype`.
   1295     """
-> 1296     return self._call_log_prob(value, name, **kwargs)
   1297 
   1298   def _call_prob(self, value, name, **kwargs):

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in _call_log_prob(self, value, name, **kwargs)
   1276     with self._name_and_control_scope(name, value, kwargs):
   1277       if hasattr(self, '_log_prob'):
-> 1278         return self._log_prob(value, **kwargs)
   1279       if hasattr(self, '_prob'):
   1280         return tf.math.log(self._prob(value, **kwargs))

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\hidden_markov_model.py in _log_prob(self, value)
    487       # working_obs :: num_steps batch_shape 1 underlying_event_shape
    488 
--> 489       observation_probs = observation_distribution.log_prob(working_obs)
    490       # observation_probs :: num_steps batch_shape num_states
    491 

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in log_prob(self, value, name, **kwargs)
   1294         values of type `self.dtype`.
   1295     """
-> 1296     return self._call_log_prob(value, name, **kwargs)
   1297 
   1298   def _call_prob(self, value, name, **kwargs):

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in _call_log_prob(self, value, name, **kwargs)
   1276     with self._name_and_control_scope(name, value, kwargs):
   1277       if hasattr(self, '_log_prob'):
-> 1278         return self._log_prob(value, **kwargs)
   1279       if hasattr(self, '_prob'):
   1280         return tf.math.log(self._prob(value, **kwargs))

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\mixture.py in _log_prob(self, x)
    281   def _log_prob(self, x):
    282     x = tf.convert_to_tensor(x, name='x')
--> 283     distribution_log_probs = [d.log_prob(x) for d in self.components]
    284     cat_log_probs = self._cat_probs(log_probs=True)
    285     final_log_probs = [

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\mixture.py in <listcomp>(.0)
    281   def _log_prob(self, x):
    282     x = tf.convert_to_tensor(x, name='x')
--> 283     distribution_log_probs = [d.log_prob(x) for d in self.components]
    284     cat_log_probs = self._cat_probs(log_probs=True)
    285     final_log_probs = [

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in log_prob(self, value, name, **kwargs)
   1294         values of type `self.dtype`.
   1295     """
-> 1296     return self._call_log_prob(value, name, **kwargs)
   1297 
   1298   def _call_prob(self, value, name, **kwargs):

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in _call_log_prob(self, value, name, **kwargs)
   1276     with self._name_and_control_scope(name, value, kwargs):
   1277       if hasattr(self, '_log_prob'):
-> 1278         return self._log_prob(value, **kwargs)
   1279       if hasattr(self, '_prob'):
   1280         return tf.math.log(self._prob(value, **kwargs))

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\independent.py in _log_prob(self, x, **kwargs)
    283   def _log_prob(self, x, **kwargs):
    284     return self._reduce(
--> 285         self._sum_fn(), self.distribution.log_prob(x, **kwargs))
    286 
    287   def _unnormalized_log_prob(self, x, **kwargs):

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in log_prob(self, value, name, **kwargs)
   1294         values of type `self.dtype`.
   1295     """
-> 1296     return self._call_log_prob(value, name, **kwargs)
   1297 
   1298   def _call_prob(self, value, name, **kwargs):

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\distribution.py in _call_log_prob(self, value, name, **kwargs)
   1276     with self._name_and_control_scope(name, value, kwargs):
   1277       if hasattr(self, '_log_prob'):
-> 1278         return self._log_prob(value, **kwargs)
   1279       if hasattr(self, '_prob'):
   1280         return tf.math.log(self._prob(value, **kwargs))

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow_probability\python\distributions\categorical.py in _log_prob(self, k)
    293     k, logits = _broadcast_cat_event_and_params(
    294         k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
--> 295     return -tf.nn.sparse_softmax_cross_entropy_with_logits(
    296         labels=k, logits=logits)
    297 

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow\python\util\dispatch.py in wrapper(*args, **kwargs)
    204     """Call target, and fall back on dispatchers if there is a TypeError."""
    205     try:
--> 206       return target(*args, **kwargs)
    207     except (TypeError, ValueError):
    208       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow\python\ops\nn_ops.py in sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name)
   4226       of the labels is not equal to the rank of the logits minus one.
   4227   """
-> 4228   return sparse_softmax_cross_entropy_with_logits(
   4229       labels=labels, logits=logits, name=name)
   4230 

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow\python\util\dispatch.py in wrapper(*args, **kwargs)
    204     """Call target, and fall back on dispatchers if there is a TypeError."""
    205     try:
--> 206       return target(*args, **kwargs)
    207     except (TypeError, ValueError):
    208       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow\python\ops\nn_ops.py in sparse_softmax_cross_entropy_with_logits(_sentinel, labels, logits, name)
   4159       # The second output tensor contains the gradients.  We use it in
   4160       # _CrossEntropyGrad() in nn_grad but not here.
-> 4161       cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
   4162           precise_logits, labels, name=name)
   4163       cost = array_ops.reshape(cost, labels_shape)

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow\python\ops\gen_nn_ops.py in sparse_softmax_cross_entropy_with_logits(features, labels, name)
  11247       return _result
  11248     except _core._NotOkStatusException as e:
> 11249       _ops.raise_from_not_ok_status(e, name)
  11250     except _core._FallbackException:
  11251       pass

~\anaconda3\envs\shark_research_project\lib\site-packages\tensorflow\python\framework\ops.py in raise_from_not_ok_status(e, name)
   6895   message = e.message + (" name: " + name if name is not None else "")
   6896   # pylint: disable=protected-access
-> 6897   six.raise_from(core._status_to_exception(e.code, message), None)
   6898   # pylint: enable=protected-access
   6899 

~\anaconda3\envs\shark_research_project\lib\site-packages\six.py in raise_from(value, from_value)

InvalidArgumentError: Received a label value of -1 which is outside the valid range of [0, 2).  Label values: 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 0 0 0 0 0 0 0 1 0 -1 1 0 -1 [Op:SparseSoftmaxCrossEntropyWithLogits]

The strangest part is that when I create the hidden Markov model with 14 steps and call hmm3.log_prob(hmm3.sample()), I don’t get any errors. Please let me know if I should clarify anything or include more information. Any help would be appreciated!

/cc @Christopher_Suter

@markdaoust I think we could add a new probability tag so that the TFP team could be subscribed.

2 Likes

I believe the problem is that your Mixture is over 2 distributions with incompatible event spaces. The way Mixture is implemented, it evaluates the log_prob of inputs under all the component distributions, and then weights them according to mixture probs. So any input event must be able to be passed to all mixture components. I suspect, but haven’t confirmed, that samples coming from your “real” distribution, which may include negative floats (like the VonMises), are being cast to integers somewhere along the line so they can be passed into the Categorical log prob (which uses the sparse cross entropy function).

The solution is to have your “fallback” distribution actually live over the same space as your real observation distribution. Here’s a mixed_dists that will work in your code snippet, by using Deterministic:

mixed_dists = tfd.Mixture(
    cat=tfd.Categorical(probs=[[0.5, 0.5]] * num_states),
    components=[
        # "Missing observation" component: a point mass at [0, 0, 0] for each
        # hidden state. Unlike a Categorical, Deterministic lives on the same
        # real-valued event space as the "observed" component, so samples from
        # either component are valid inputs to both components' log_prob.
        # The loc is sized by num_states so this component's batch shape matches
        # `cat` and the Blockwise component for any number of states.
        # NOTE: a later reply in this thread recommends [0., 0., 1.] as the
        # fallback value, to avoid a NaN log_prob from the Mixture in training.
        tfd.Independent(
            tfd.Deterministic(loc=tf.zeros([num_states, 3])),
            reinterpreted_batch_ndims=1),
        # "Observed" component: [flag (always 1), turning angle, speed].
        tfd.Blockwise([
            tfd.Categorical(probs=[[0.0, 1.0]] * num_states),
            tfd.VonMises(loc=vm_locs, concentration=vm_cons),
            tfd.Gamma(concentration=gamma_shapes, rate=gamma_rates)
        ], dtype_override=tf.float32)
    ]
)

This way, samples from the second component will always be valid inputs to the first component (they’ll just have 0 prob density, almost always).

BTW I think the 14 vs 15 step thing was a red herring. If you use a different RNG seed you’ll probably be able to get the same error on any (sufficiently long) series. It just depends on whether your VonMises has a chance to yield negative samples.

HTH!

2 Likes

Thanks, this solved the issue I was running into! Once again I’ve learned something new after having bashed my head against a wall for a while trying to use the wrong tool! Unfortunately, I think this solution has brought forward another issue. When I try to fit the HMM using the Adam optimizer to maximize the log probability, I get nan gradients for the gamma concentrations after the first iteration. You mentioned that the input to log_prob is evaluated under all mixture components, which leads me to believe that the issue is with the gamma concentrations being updated to 1.1 on the first training step, after which observations of the form [0, 0, 0] have a log probability of -Inf under the gamma distribution. Then again, I could be totally off, I am not very familiar with Tensorflow. Here is the relevant code:

# Log-likelihood of the data under the HMM; the training loss is its negative.
def log_prob():
    """Return ``hmm3.log_prob(observations)``.

    Takes no arguments so it can be called as a closure inside
    ``tf.GradientTape``; ``hmm3`` and ``observations`` are looked up in the
    enclosing (global) scope each time it is called.
    """
    data_log_likelihood = hmm3.log_prob(observations)
    return data_log_likelihood

# Adam optimizer (learning rate 0.1); train_op uses it to apply the gradients
# of the negative log-likelihood to the model variables.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

# Make sure probabilities sum to 1
def normalize_probs(probs):
    """Return ``|probs|`` rescaled so it sums to 1 along the last axis.

    For a 1-D vector (e.g. the initial distribution) the whole vector is
    normalized; for a 2-D matrix (e.g. transition probabilities) each row is
    normalized independently. Higher ranks are normalized over the last axis.

    Args:
        probs: tensor/variable of (possibly negative) unnormalized weights.

    Returns:
        A tensor of the same shape whose last-axis slices sum to 1.
    """
    abs_probs = tf.math.abs(probs)
    # keepdims=True keeps the reduced axis so the division broadcasts per row,
    # with no reshape that would depend on a known static first dimension.
    return abs_probs / tf.reduce_sum(abs_probs, axis=-1, keepdims=True)

def wrap_to_pi(A):
    """Map angle(s) ``A`` in radians onto the half-open interval [-pi, pi).

    Only ``-`` and ``%`` are used, so this works on Python floats, NumPy
    arrays and TensorFlow tensors/variables alike.
    """
    full_turn = 2 * np.pi
    shifted = A - np.pi
    return shifted % full_turn - np.pi

# Run a step of the optimizer
@tf.function(autograph=False)
def train_op():
    """Run one Adam step on the negative log-likelihood, then project the
    parameters back onto their valid ranges.

    Returns:
        A pair ``((neg_log_prob, *trainable), grads)``: the loss computed
        before the update, the six model variables, and the gradients of the
        loss with respect to those variables (in the same order).
    """
    with tf.GradientTape() as tape:
        neg_log_prob = -log_prob()
    # Named `trainable` (not `vars`) so the builtin vars() is not shadowed.
    trainable = [initial_probs, transition_probs, vm_locs, vm_cons, gamma_shapes, gamma_rates]
    grads = tape.gradient(neg_log_prob, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    # The Adam step is unconstrained, so project back afterwards: re-normalize
    # the probabilities (transition matrix row by row), wrap the angles to
    # [-pi, pi), and make the concentration/shape/rate parameters non-negative.
    # NOTE(review): a reply in this thread suggests tfp.util.TransformedVariable
    # as a cleaner alternative to this post-hoc clipping.
    initial_probs.assign(normalize_probs(initial_probs))
    transition_probs.assign(normalize_probs(transition_probs))
    vm_locs.assign(wrap_to_pi(vm_locs))
    vm_cons.assign(tf.math.abs(vm_cons))
    gamma_shapes.assign(tf.math.abs(gamma_shapes))
    gamma_rates.assign(tf.math.abs(gamma_rates))
    return (neg_log_prob, *trainable), grads

# Train on the observations
num_train_steps = 10
log_every = 1  # print progress every `log_every` steps; 1 means every step
loss_history = []
for step in range(num_train_steps):
    # ts == (neg_log_prob, *variables). grads is kept (not printed) so the last
    # step's gradients can be inspected afterwards, e.g. to look for NaNs.
    ts, grads = train_op()
    loss, ip, tp, vl, vc, gs, gr = [t.numpy() for t in ts]
    loss_history.append(loss)
    if step % log_every == 0:
        print("step {}: log prob {}\nInitial probs: {}\nTransition probs:\n{}\nVon Mises locs: {}\nVon Mises cons: {}\nGamma shapes: {}\nGamma rates: {}\n".format(step, -loss, ip, tp, vl, vc, gs, gr))

I would appreciate any insight you might have on how to fix this new error, and thanks a lot for the help you’ve provided so far!

I made some modifications to your code in this colab.

Things I changed:

  1. [Major-ish] Reparameterize in terms of logits instead of probs – this is (a) generally numerically favorable (internally your probs will be logit-ed anyway) and (b) saves us having to manually normalize to 1 all the time.
  2. [Major-ish] Use tfp.util.TransformedVariable to constrain positive params and vm_loc to lie in [-pi, pi]. This allows us to eliminate the variable clipping.
  3. [Major-ish] Use [0., 0., 1.] as the default value for the deterministic missing data fallback. This prevents a nan log_prob from the Mixture.
  4. [Minor] Use hmm3.trainable_variables instead of variable lists.

With those changes, I see some valid-seeming optimization steps. Please note, though, that I haven’t thought very much about the actual model here – I’m just pattern-matching on common numerical and configuration-style pitfalls. In particular, (1) I have not thought about whether this approach to missing data makes sense in the HMM context (it very well might — I just haven’t thought about it at all!) and (2) one generally wants to be a bit cautious about optimizing on a non-Euclidean manifold, which the VonMises loc is doing (it’s on a circle!) — Amari, 1998 would have us think carefully about metrics and such, and adapt our gradients appropriately. I think the constrained version I wrote here is probably ok as long as the actual optimum is not at -pi/pi, which are effectively “at infinity” in the unconstrained parameter space…If the optimum is there, we can probably just add a phase shift to that angle variable to move it away from infinity.

HTH!

3 Likes

Ok had an idea and modified one more thing – switched to using tfp.math.minimize instead of manual optimization. This has some niceties like automatically using a tf.while_loop and tf.function compiling things. It returns losses by default, but I wrote a custom trace_fn to trace all quantities (vars, grads, etc) just to demonstrate how that could be used. You can write whatever you want in trace_fn and it’ll preserve those values and return them, per-step, at the end of optimization.

3 Likes

Oh wow, thanks a lot! I managed to figure out the change to [0, 0, 1] for a missing variable myself a little while after my first follow up, but I wasn’t sure if there was a better way to do it. There is a lot in your optimizations that I’m not too familiar with, but most of it makes sense and the code is overall much cleaner. I had been wondering about how to properly constrain those variables during optimization! Thanks again!

1 Like

done: tf-probability

1 Like

Hello! I’m having trouble understanding the difference between tfd.Mixture and tfd.MixtureSameFamily in practice. I set up both mixtures with the same parameters, but they return different probabilities. The following program can be run directly.

from tensorflow import keras

from keras import layers

import numpy as np

import tensorflow as tf

from tensorflow_probability import distributions as tfd

# Mixture parameters: a batch of one row holding 3 component means, scales
# and mixture weights.
mus = tf.convert_to_tensor([[11, 23, 34]], dtype=tf.float32)
sigs = tf.convert_to_tensor([[22, 12, 32]], dtype=tf.float32)
out_pi = tf.convert_to_tensor([[0.2, 0.3, 0.5]], dtype=tf.float32)

print('weight', out_pi)
print('sigs', sigs)
print('mus', mus)

output_dim = 1
num_mixes = 3

# the first method: tfd.Mixture over a list of MultivariateNormalDiag components.
# out_pi holds mixture *weights*, so it must be passed as probs=. Passing it as
# logits= (as originally posted) softmaxes it to roughly [0.289, 0.320, 0.391],
# which is why this mixture disagreed with the MixtureSameFamily below.
cat = tfd.Categorical(probs=out_pi)
print('cat-----------', cat)

component_splits = [output_dim] * num_mixes
mus1 = tf.split(mus, num_or_size_splits=component_splits, axis=1)
sigs1 = tf.split(sigs, num_or_size_splits=component_splits, axis=1)
print('sigs1-----------', sigs1)
print('mus1-----------', mus1)

coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale in zip(mus1, sigs1)]
print('coll-----------------', coll)

mixture1 = tfd.Mixture(cat=cat, components=coll)
ex4 = mixture1.mean()

# the second method: tfd.MixtureSameFamily over scalar Normals
mus = mus.numpy().squeeze()
sigs = sigs.numpy().squeeze()
out_pi = out_pi.numpy().squeeze()

mixture_cdf = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=out_pi),
    components_distribution=tfd.Normal(loc=mus, scale=sigs))  # One for each component.

ex1 = mixture_cdf.mean()

# case study
rrrrr = [[2], [50], [30]]

# the third method: weighted sum of the component CDFs, computed by hand
CCdf = (tfd.Normal(mus[0], sigs[0], name='Normal').cdf(rrrrr) * out_pi[0]
        + tfd.Normal(mus[1], sigs[1], name='Normal').cdf(rrrrr) * out_pi[1]
        + tfd.Normal(mus[2], sigs[2], name='Normal').cdf(rrrrr) * out_pi[2])

print('ex', ex1.numpy(), ex4.numpy())
print('----------')
# NOTE(review): mixture1 has 1-D events (output_dim=1) while mixture_cdf has
# scalar events, so the two prob() results should agree in value but are
# expected to differ in shape.
print('cdf_mixture_cdf', mixture_cdf.prob(rrrrr))
print('mixture1', mixture1.prob(rrrrr))
print('----------')
print('cdf_mixture_cdf', mixture_cdf.cdf(rrrrr))
# As reported in the post, mixture1.cdf raises
# "NotImplementedError: _is_increasing not implemented." — presumably because
# the MultivariateNormalDiag components provide no cdf. Report it instead of
# aborting, so the hand-computed CCdf below is still printed; use the
# MixtureSameFamily cdf above for a 1-D mixture.
try:
    print('mixture1', mixture1.cdf(rrrrr))
except NotImplementedError as err:
    print('mixture1 cdf not available:', err)

print('CCdf', CCdf)