How to generate a batched, differentiable N-D grid

Hi,
I am currently trying to create the following function:
Args:

  • starts: A tensor of shape (batch, N), where N is the dimensionality of the target grid, specifying the smallest value of the grid in each dimension.
  • stops: A tensor of shape (batch, N), where N is the dimensionality of the target grid, specifying the largest value of the grid in each dimension.
  • nums: A list of length N specifying the number of points of the target grid in each dimension.

Returns: A tensor of shape (batch, nums[0], nums[1], …, nums[N-1], N) containing an evenly spaced N-dimensional grid for each batch element b, where dimension k ranges from starts[b,k] to stops[b,k] in nums[k] steps.
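To make the expected output concrete, here is a minimal NumPy sketch of the specification above. It is not differentiable and is only meant as a reference for checking results; the name `reference_grid` is hypothetical.

```python
import numpy as np

def reference_grid(starts, stops, nums):
    """Hypothetical NumPy reference: returns shape (batch, *nums, N)."""
    starts = np.asarray(starts, dtype=np.float64)
    stops = np.asarray(stops, dtype=np.float64)
    grids = []
    for start, stop in zip(starts, stops):
        # One 1-D axis per dimension k, with nums[k] points from start[k] to stop[k].
        axes = [np.linspace(s, e, n) for s, e, n in zip(start, stop, nums)]
        # indexing="ij" keeps the axis order (nums[0], nums[1], ...).
        mesh = np.meshgrid(*axes, indexing="ij")
        grids.append(np.stack(mesh, axis=-1))
    return np.stack(grids, axis=0)
```

For example, `reference_grid([[1.0, 1.0]], [[2.0, 2.0]], [3, 3])` has shape (1, 3, 3, 2), and its first coordinate along the first grid axis is [1.0, 1.5, 2.0].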

This is very similar to tfg.geometry.representation.grid.generate from TensorFlow Graphics.
However, that function does not work when the batch size is not specified, and its documentation states that it is not differentiable. For example, this fails:

import tensorflow as tf
import tensorflow_graphics as tfg

startsInput = tf.keras.Input(shape=(2,))
stopsInput = tf.keras.Input(shape=(2,))
tfg.geometry.representation.grid.generate(starts=startsInput, stops=stopsInput, nums=[3, 3])

This call fails because the function calls tf.unstack(starts), but the first dimension of starts is None, since the batch size is unknown at this point. However, the same call with concrete data works as expected:

tfg.geometry.representation.grid.generate(starts=[[1.0, 1.0]], stops=[[2.0, 2.0]], nums=[3, 3])
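The failure can be reproduced without TensorFlow Graphics. In this sketch (the function name `unstack_batch` is hypothetical), a tf.function with an input signature of shape [None, 2] stands in for a Keras input with an unknown batch size. Tracing it should raise a ValueError, because tf.unstack has to return a Python list and therefore needs the size of axis 0 when the graph is built:

```python
import tensorflow as tf

# A batch dimension of None mimics tf.keras.Input(shape=(2,)).
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)])
def unstack_batch(starts):
    # tf.unstack cannot infer how many tensors to return from shape (None, 2).
    return tf.unstack(starts)

try:
    unstack_batch.get_concrete_function()
    print("traced successfully")
except ValueError as err:
    print("tracing failed:", err)
```

With a known batch size, for example tf.unstack(tf.zeros([3, 2])), the same call returns a list of three tensors, which is why the call with concrete data works.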

Is there a way to create such a function that works with an undefined batch size and is differentiable?

I think I found a solution: in current TensorFlow versions, tf.linspace accepts batched start and stop tensors (together with its axis argument) and is differentiable. This results in the following function:
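A quick check of both properties, as a minimal sketch with arbitrary values: with start and stop tensors of shape (B,), axis=1 yields a tensor of shape (B, num), and gradients flow back to the start values.

```python
import tensorflow as tf

starts = tf.Variable([[0.0, 0.0], [1.0, 2.0]])  # shape (B, 2) with B = 2
stops = tf.constant([[1.0, 2.0], [3.0, 4.0]])    # shape (B, 2)

with tf.GradientTape() as tape:
    # One linspace per batch element; axis=1 puts the 3 steps after the batch axis.
    xs = tf.linspace(starts[:, 0], stops[:, 0], 3, axis=1)  # shape (2, 3)
    loss = tf.reduce_sum(xs)

grad = tape.gradient(loss, starts)
# xs holds [[0, 0.5, 1], [1, 2, 3]].
# grad is [[1.5, 0], [1.5, 0]]: the three points depend on start with weights
# 1, 0.5 and 0, and the second column of starts is not used.
```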

from typing import Tuple

import tensorflow as tf

def generate_2d_grid(starts, stops, nums: Tuple[int, int], name="grid_generate"):
    """Generates a batch of 2D grids, similar to tf.linspace but in 2D.

    Args:
      starts: A tensor of shape `[B, 2]` containing the start coordinates of the
        grid.
      stops: A tensor of shape `[B, 2]` containing the stop coordinates of the
        grid.
      nums: A tuple `(w, h)` of Python ints containing the number of points in
        each dimension. Must be known at graph-construction time.
      name: A name for this op that defaults to "grid_generate".

    Returns:
      A tensor of shape `[B, w, h, 2]` containing the 2D grid. Starts and stops are inclusive. The last dimension contains the x and y coordinates of the grid points.
    """
    with tf.compat.v1.name_scope(name):
        # shape (B,2)
        starts = tf.convert_to_tensor(starts)
        # shape (B,2)
        stops = tf.convert_to_tensor(stops)
        # shape (B,w)
        w_range=tf.linspace(starts[:,0],stops[:,0],nums[0],axis=1)

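        # shape (B,w): the y start value repeated along the w axis
        second_spatial_dim_start=tf.einsum('i,j->ij',starts[:,1],tf.ones((nums[0],),dtype=starts.dtype))
        # shape (B,w,2): bottom edge of the grid
        lower_stacked=tf.stack([w_range,second_spatial_dim_start],axis=2)
        # shape (B,w): the y stop value repeated along the w axis
        second_spatial_dim_end=tf.einsum('i,j->ij',stops[:,1],tf.ones((nums[0],),dtype=stops.dtype))
        # shape (B,w,2): top edge of the grid
        upper_stacked=tf.stack([w_range,second_spatial_dim_end],axis=2)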
        
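        # list of h tensors of shape (B,w,2), linearly interpolated between the edges;
        # max(..., 1) avoids a division by zero when h == 1
        alphas=[i/max(nums[1]-1,1) for i in range(nums[1])]
        ranges=[(1.0-alpha)*lower_stacked+alpha*upper_stacked for alpha in alphas]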
        # shape (B,w,h,2)
        stacked=tf.stack(ranges,axis=2)
        return stacked
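Since the original question asked for N dimensions, the same idea may generalize without the per-row interpolation: compute one batched tf.linspace per dimension, reshape it so that it only varies along its own grid axis, broadcast it to the full grid shape, and stack the results. This is a minimal sketch, not a library function; the names `generate_nd_grid` and `grid_fn` are hypothetical. The batch size is read with tf.shape, so it can be unknown when the graph is built.

```python
from typing import Sequence

import tensorflow as tf

def generate_nd_grid(starts, stops, nums: Sequence[int]):
    """Hypothetical helper: returns a tensor of shape (B, nums[0], ..., nums[N-1], N)."""
    starts = tf.convert_to_tensor(starts)
    stops = tf.convert_to_tensor(stops, dtype=starts.dtype)
    n = len(nums)
    # Dynamic batch size, so this also works when the static batch dimension is None.
    batch = tf.shape(starts)[0]
    target_shape = tf.concat([[batch], tf.constant(nums, dtype=tf.int32)], axis=0)
    coords = []
    for k in range(n):
        # shape (B, nums[k]): evenly spaced values for dimension k, per batch element
        lin = tf.linspace(starts[:, k], stops[:, k], nums[k], axis=1)
        # reshape to (B, 1, ..., nums[k], ..., 1) so it varies only along grid axis k
        shape = [-1] + [1] * n
        shape[k + 1] = nums[k]
        lin = tf.reshape(lin, shape)
        # repeat along all other grid axes: shape (B, nums[0], ..., nums[N-1])
        coords.append(tf.broadcast_to(lin, target_shape))
    # the last axis holds the N coordinates of each grid point
    return tf.stack(coords, axis=-1)

# Usage with an unknown batch size, as with tf.keras.Input(shape=(2,)):
grid_fn = tf.function(
    lambda s, e: generate_nd_grid(s, e, [3, 4]),
    input_signature=[tf.TensorSpec([None, 2], tf.float32)] * 2,
)
```

For N = 2 this should produce the same layout as generate_2d_grid above: shape (B, w, h, 2), with x in channel 0 and y in channel 1. tf.reshape, tf.broadcast_to and tf.stack all have gradients, so the result should stay differentiable with respect to starts and stops.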