Help with VICReg loss terms

Can anyone help with a TensorFlow implementation of the VICReg loss terms? Thank you in advance.

I have been able to implement them. For anyone interested, here it is:

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Args:
        x: A rank-2 tensor of shape (n, n).

    Returns:
        A 1-D tensor holding the n * (n - 1) off-diagonal entries of ``x``
        in row-major order.

    Raises:
        ValueError: If ``x`` is not rank 2, or its static dimensions differ.
    """
    x = tf.convert_to_tensor(x)
    # Raise instead of ``assert``: asserts are stripped under ``python -O``.
    if x.shape.rank is not None and x.shape.rank != 2:
        raise ValueError(f"Expected a rank-2 tensor, got rank {x.shape.rank}")
    n, m = x.shape[0], x.shape[1]
    if n is not None and m is not None and n != m:
        raise ValueError(f"Not a square tensor, dimensions found: {n} and {m}")
    # Keep the static size when known (preserves static output shape);
    # otherwise fall back to the runtime size.
    size = n if n is not None else tf.shape(x)[0]
    # Dropping the last element of the flattened n x n matrix leaves
    # n*n - 1 = (n - 1)(n + 1) values. Viewed as (n - 1, n + 1), every row
    # starts with a diagonal entry, so removing column 0 keeps exactly the
    # off-diagonal entries.
    flat = tf.reshape(x, [-1])[:-1]
    rows = tf.reshape(flat, [size - 1, size + 1])[:, 1:]
    return tf.reshape(rows, [-1])

def invariance_loss(z_a, z_b):
    """Invariance term of VICReg: mean squared error between the two branches.

    Args:
        z_a: Embeddings from the first branch (presumably (batch, dim)).
        z_b: Embeddings from the second branch, same shape and dtype as ``z_a``.

    Returns:
        A scalar tensor: the squared difference averaged over every element.
    """
    # tf.keras.metrics.mean_squared_error averages only over the last axis and
    # returns one value per sample. Reduce over all elements so this term is a
    # scalar, like the variance and covariance terms it is summed with.
    return tf.math.reduce_mean(tf.math.squared_difference(z_a, z_b))

def variance_loss(z_a, z_b, *, gamma=1.0, eps=1e-4):
    """Variance term of VICReg, applied separately to the two branches.

    For each branch:
    - compute the per-dimension standard deviation over the batch (axis 0);
    - hinge-penalize dimensions whose std falls below ``gamma``;
    - average that penalty over dimensions.

    The two branch penalties are then averaged.

    Args:
        z_a: Embeddings from the first branch, shape (batch, dim).
        z_b: Embeddings from the second branch, shape (batch, dim).
        gamma: Target standard deviation per dimension (default 1.0).
        eps: Added to the variance before the square root for numerical
            stability (default 1e-4).

    Returns:
        A scalar tensor.
    """
    def _branch_penalty(z):
        z = tf.convert_to_tensor(z)
        n = tf.cast(tf.shape(z)[0], z.dtype)
        # tf.math.reduce_variance divides by N. Rescale by N / (N - 1) for the
        # unbiased estimator, as used by the VICReg reference code (torch.var)
        # and by the N - 1 covariance normalization in this module. The clamp
        # keeps a batch of one finite.
        var = tf.math.reduce_variance(z, axis=0) * n / tf.math.maximum(n - 1.0, 1.0)
        std = tf.math.sqrt(var + eps)
        return tf.math.reduce_mean(tf.nn.relu(gamma - std))

    return 0.5 * (_branch_penalty(z_a) + _branch_penalty(z_b))

def covariance_loss(z_a, z_b):
    """Covariance term of VICReg, applied separately to the two branches.

    For each branch:
    - build the (dim, dim) sample covariance matrix of the embeddings over
      the batch;
    - sum the squares of its off-diagonal entries;
    - divide by ``dim``.

    The two branch penalties are added (not averaged).

    Args:
        z_a: Embeddings from the first branch, shape (batch, dim).
        z_b: Embeddings from the second branch, shape (batch, dim).

    Returns:
        A scalar tensor.
    """
    def _branch_penalty(z):
        z = tf.convert_to_tensor(z)
        # Batch size and embedding dimension are read from the tensor itself
        # rather than from globals N and D. This keeps the function
        # self-contained and correct for a smaller final batch.
        n = tf.cast(tf.shape(z)[0], z.dtype)
        d = tf.cast(tf.shape(z)[1], z.dtype)
        centered = z - tf.math.reduce_mean(z, axis=0)
        # Unbiased (N - 1) normalization; the clamp avoids 0/0 for a batch of one.
        cov = tf.linalg.matmul(centered, centered, transpose_a=True) / tf.math.maximum(
            n - 1.0, 1.0
        )
        return tf.math.reduce_sum(tf.math.square(off_diagonal(cov))) / d

    return _branch_penalty(z_a) + _branch_penalty(z_b)

where N and D are the batch size and the embedding dimension, respectively.

1 Like