Tying embedding weights to the output logits projection in TensorFlow


Is there a concise way in TensorFlow to tie the embedding weights to the output projection that produces the logits?

In Flax, we are able to do this:

class Parallel_Transformer(nn.Module):
  """Transformer whose output logits reuse the input token-embedding table.

  A single ``nn.Embed`` instance is used twice: once to look up the input
  tokens and once, through ``Embed.attend``, to project the final hidden
  states back onto the vocabulary. The output projection therefore shares
  (is tied to) the embedding weights.
  """
  dim: int
  num_tokens: int
  depth: int
  dim_head: int = 64
  heads: int = 8
  ff_mult: int = 4

  # @nn.compact is required: Flax only allows submodules (nn.Embed,
  # Transformer, nn.LayerNorm) to be created inline inside a method that is
  # wrapped with @nn.compact (otherwise they must be defined in setup()).
  @nn.compact
  def __call__(self, x):
    # NOTE(review): x is presumably integer token ids — confirm with callers.
    embed = nn.Embed(self.num_tokens, self.dim, embedding_init=nn.initializers.normal(stddev=0.02))
    x = embed(x)
    x = Transformer(dim=self.dim, depth=self.depth, heads=self.heads, dim_head=self.dim_head, ff_mult=self.ff_mult)(x)
    x = nn.LayerNorm(epsilon=1e-5, use_bias=False)(x)
    # attend() multiplies x by the transposed embedding table, producing
    # logits over num_tokens with the same parameters used for the lookup.
    out = embed.attend(x)
    return out

In PyTorch, we are able to do the same with:

def Parallel_Transformer(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    """Build a transformer as an ``nn.Sequential`` with tied input/output embeddings.

    The embedding table (shape ``num_tokens x dim``) is reused as the weight of
    the final bias-free ``nn.Linear(dim, num_tokens)``, so the logits are
    computed with the same parameters that embed the input tokens.

    Returns:
        nn.Sequential: the embedding, ``depth`` residual transformer blocks,
        and the tied output projection, in that order.
    """
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        # The blocks must be unpacked into nn.Sequential's positional
        # arguments; a bare generator among other arguments is a syntax error.
        *(
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
            for _ in range(depth)
        ),
        # NOTE(review): the Flax and TensorFlow versions apply a final
        # LayerNorm before the output projection; this one does not — confirm.
        nn.Linear(dim, num_tokens, bias=False),
    )

    # Tie weights: nn.Linear stores its weight as (out_features, in_features)
    # = (num_tokens, dim), the same shape as the embedding table, so the one
    # Parameter can be shared directly.
    net[-1].weight = net[0].weight

    # Initialize after tying so the single shared tensor gets N(0, 0.02).
    nn.init.normal_(net[0].weight, std=0.02)
    return net

Here is my current TensorFlow implementation:

class Parallel_Transformer(tf.keras.Model):
    """Keras transformer whose output logits are tied to the token embedding.

    Instead of a separate ``Dense(num_tokens)`` output layer, the final hidden
    states are multiplied by the transposed embedding matrix. This is the Keras
    counterpart of ``nn.Embed.attend`` in Flax and of
    ``linear.weight = embedding.weight`` in PyTorch.
    """

    def __init__(self, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
        # Zero-argument super(): this class is Parallel_Transformer, not PaLM.
        super().__init__()
        self.dim = dim
        self.num_tokens = num_tokens
        self.depth = depth
        self.dim_head = dim_head
        self.heads = heads
        self.ff_mult = ff_mult

        # N(0, 0.02) init, matching the Flax and PyTorch versions.
        self.embedding = tf.keras.layers.Embedding(
            num_tokens,
            dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
        )

        # Create the transformer once so its weights are tracked by the model
        # and trained. Building it inside call() would create new, untracked
        # weights on every forward pass.
        self.transformer = ParallelTransformer(self.dim, self.depth, self.heads, self.dim_head, self.ff_mult)

        # epsilon=1e-5 and no bias (center=False) to match the Flax version.
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, center=False)

    def to_out(self, x):
        """Project hidden states to vocabulary logits with the embedding matrix.

        Args:
            x: hidden states whose last axis has size ``dim``.

        Returns:
            Logits with the same leading axes as ``x`` and a last axis of size
            ``num_tokens``.
        """
        # embedding.embeddings has shape (num_tokens, dim); contract over dim.
        # No separate output kernel exists, so the projection is tied.
        return tf.einsum("...d,nd->...n", x, self.embedding.embeddings)

    def call(self, x):
        embed = self.embedding(x)
        x = self.transformer(embed)
        x = self.norm(x)
        out = self.to_out(x)
        return out

I am unsure how to properly tie the output projection (`to_out`) to the embedding weights in TensorFlow.

Any advice would be greatly appreciated.

Thank you,