Block masking in TensorFlow

I’m trying to implement the masking generation function for BEiT:

The part I am struggling with is the assignment of EagerTensors.

I have consulted references that show how to approach such assignments, but this one does not seem to fit them.

Any particular approaches I should try out or look into for this case?

Is every single masked patch random inside a single image there?

P.S. I was looking at:


A single mask can be applied to a batch too.
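
For example, a minimal sketch (shapes and names are just illustrative, not from any particular codebase) of broadcasting one shared patch mask over a whole batch:

    import tensorflow as tf

    # Illustrative shapes: one (num_patches,) boolean mask shared by every image in the batch.
    batch, num_patches, dim = 8, 196, 768
    embeddings = tf.random.normal((batch, num_patches, dim))
    shared_mask = tf.random.uniform((num_patches,)) > 0.6   # True where a patch is masked
    mask_token = tf.zeros((dim,))                           # stand-in for a learned token

    # tf.where broadcasts the (1, num_patches, 1) condition against the whole batch.
    masked = tf.where(shared_mask[None, :, None], mask_token, embeddings)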

Thanks for sharing this. Will take a look.

@Bhack there’s actually no masking involved in the link you sent.

So, the question is pretty much still open.

Yes, as they are just reloading the Microsoft weights, so there is no training protocol there.

What is your specific issue? Isn’t it just the standard image tokenization used in many vision transformers, where some tokens are masked?

My issue is with the block-wise masking strategy, where apparently tensor assignment is needed (refer to my initial post). Had it been randomized, it would have been easier, and we implemented that a while back (here).
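
For comparison, a rough sketch of that simpler randomized variant (the sizes are hypothetical), where shuffled indices are enough and no assignment is needed:

    import tensorflow as tf

    # Hypothetical sizes: mask a fixed number of patches chosen uniformly at random.
    num_patches, num_masking_patches = 196, 75
    idx = tf.random.shuffle(tf.range(num_patches))[:num_masking_patches]
    # Build the 0/1 mask from the chosen indices without any in-place assignment.
    mask = tf.reduce_max(tf.one_hot(idx, num_patches, dtype=tf.int32), axis=0)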

To exactly mimic that implementation, are you looking for slice assignment?

Yes. Please take note of this part before sharing existing references:

I have consulted references that show how to approach such assignments, but this one does not seem to fit them.

If there’s no way other than doing something like this, then it’s a different choice.

Oh, in that case: historically we have plenty of slice assignment tickets. Just to mention a few that are still open:

I’ve not checked the paper in detail to see what kind of indices are selected to perform the masking. Couldn’t it be covered by tf.tensor_scatter_nd_update after populating those indices?
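
Something along these lines, a minimal sketch of setting one rectangular block in the mask with tf.tensor_scatter_nd_update once the indices have been populated (the block position and grid size are made up, and the overlap/retry logic of the reference implementation is not reproduced here):

    import tensorflow as tf

    height, width = 14, 14
    mask = tf.zeros((height, width), dtype=tf.int32)

    # Made-up block: an h x w rectangle starting at (top, left).
    top, left, h, w = 4, 5, 3, 6
    rows, cols = tf.meshgrid(tf.range(top, top + h), tf.range(left, left + w), indexing="ij")
    indices = tf.stack([tf.reshape(rows, [-1]), tf.reshape(cols, [-1])], axis=1)  # (h*w, 2)

    mask = tf.tensor_scatter_nd_update(mask, indices, tf.ones(h * w, dtype=tf.int32))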

The indexing conditions are in the source code I provided.

If you know a workaround with scatter, would you mind providing a minimal working example?

E.g. I think that the embedding in the Hugging Face transformers library, even though it uses PyTorch ops, does not require/use slice assignment:
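
Roughly the pattern I mean, sketched with TF ops (illustrative shapes, not the actual Hugging Face code):

    import tensorflow as tf

    batch, num_patches, dim = 2, 196, 768
    embeddings = tf.random.normal((batch, num_patches, dim))
    mask_token = tf.zeros((1, 1, dim))                                # stand-in for a learned token
    bool_masked_pos = tf.random.uniform((batch, num_patches)) > 0.6   # True where a patch is masked

    # Blend the mask token in wherever bool_masked_pos is True; no slice assignment involved.
    w = tf.cast(bool_masked_pos, embeddings.dtype)[..., None]         # (batch, num_patches, 1)
    embeddings = embeddings * (1.0 - w) + mask_token * w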

I think you’re mistaken then.

bool_masked_pos in forward() is nothing but the mask yielded by the class I showed in my initial post.

It is true that bool_masked_pos is only the “application” of the masking, but the responsibility for preparing the mask still lies with the external caller.

I don’t see all the details of the reference implementation in the paper, but with the concrete reference implementation you shared, with all those attempts, conditional loops, etc., you could try to use a tf.Variable to mimic it. You will probably need to refactor it further for graph mode/tf.function, though:
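
For instance, something like this (a rough sketch of the tf.Variable idea with made-up sizes, not the full graph-mode refactor):

    import tensorflow as tf

    height, width = 14, 14
    mask = tf.Variable(tf.zeros((height, width), dtype=tf.int32), trainable=False)

    # Variables support slice assignment, so a whole block can be set at once...
    top, left, h, w = 4, 5, 3, 6
    mask[top: top + h, left: left + w].assign(tf.ones((h, w), dtype=tf.int32))

    # ...or a single element, as in the per-patch loop of the reference implementation.
    mask[0, 0].assign(1)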

Absolutely. And in case no other reference implementations are available, I guess the implementation done by the actual authors comes to the rescue.

There isn’t much about it in the paper apart from the figure on block-wise masking, which is why the original implementation is an important reference point.

Thanks for sharing your implementation. Will check it out.

Having a tf.function/graph version is quite trivial with a few changes/substitutions using TF ops.

But a jit_compile=True version will require a new design and probably some compromises.

Let me know if you have a jit_compile=True version.

What’s trivial for you may not be trivial to someone else :slight_smile:

Let me know when you have the same thing I’ve posted but with TF ops instead of NumPy ops.
I will help you make the required changes for tf.function.

This is already working with tf.function with minimal changes:

    # `mask` is expected to be a 2D tf.Variable of shape (self.height, self.width)
    # so that the element-wise .assign below is valid; `random` and `math` are the
    # Python standard-library modules imported at module level.
    @tf.function
    def _mask(self, mask, max_mask_patches):
        delta = 0
        for attempt in tf.range(10):
            # Sample a candidate block size and aspect ratio, as in the reference implementation.
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, int(self.height - h))
                left = random.randint(0, int(self.width - w))
                num_masked = tf.math.count_nonzero(mask[top: top + h, left: left + w])
                # Accept the block only if it adds new patches without exceeding
                # the remaining budget (overlap check).
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j].assign(1)
                                delta += 1
                if delta > 0:
                    break
        return delta
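
And a possible way to drive it, mirroring the driver loop of the original MaskingGenerator (`masking_generator`, `num_masking_patches`, and the 14x14 grid below are just illustrative names/values, assuming an instance of the class that defines `_mask` above):

    import tensorflow as tf

    # Hypothetical setup: the patch grid lives in a tf.Variable so that .assign works.
    num_masking_patches = 75
    mask = tf.Variable(tf.zeros((14, 14), dtype=tf.int32), trainable=False)

    masked_so_far = 0
    while masked_so_far < num_masking_patches:
        delta = masking_generator._mask(mask, num_masking_patches - masked_so_far)
        if delta == 0:            # 10 attempts made no progress, stop early
            break
        masked_so_far += int(delta)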