# Tying embedding weights to the output logits projection

Hi,

Is there a concise way in TensorFlow to tie the embedding weights to the projection that produces the output logits?

In Flax, we are able to do this:

```python
import flax.linen as nn

class Parallel_Transformer(nn.Module):
    dim: int
    num_tokens: int
    depth: int
    ff_mult: int = 4

    @nn.compact
    def __call__(self, x):
        embed = nn.Embed(self.num_tokens, self.dim,
                         embedding_init=nn.initializers.normal(stddev=0.02))
        x = embed(x)
        x = nn.LayerNorm(epsilon=1e-5, use_bias=False)(x)
        # attend() reuses the embedding matrix as the output projection
        out = embed.attend(x)
        return out
```

In PyTorch, we are able to do the same with:

```python
import torch.nn as nn

def Parallel_Transformer(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        # one transformer block per layer goes here, e.g.
        # *[Block(dim, dim_head, heads, ff_mult) for _ in range(depth)],
        nn.LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False),
    )

    # tie the output projection to the embedding weights
    net[-1].weight = net[0].weight

    nn.init.normal_(net[0].weight, std=0.02)
    return net
```
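The key line is `net[-1].weight = net[0].weight`, which makes both modules reference the same `Parameter`. A minimal standalone sketch of just that step (the layer sizes are made up):

```python
import torch
import torch.nn as nn

# Tie a Linear head to an Embedding by assigning the embedding's Parameter
# to the head; both modules then share a single tensor.
embed = nn.Embedding(10, 4)          # weight shape: (10, 4)
head = nn.Linear(4, 10, bias=False)  # weight shape: (10, 4)
head.weight = embed.weight           # now the same Parameter object

# any update to one is visible through the other
with torch.no_grad():
    embed.weight.zero_()
```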

I am currently working on a TensorFlow implementation here:

```python
import tensorflow as tf

class Parallel_Transformer(tf.keras.Model):
    def __init__(self, dim, num_tokens, depth, ff_mult=4):
        super().__init__()

        self.dim = dim
        self.num_tokens = num_tokens
        self.depth = depth
        self.ff_mult = ff_mult

        self.embedding = tf.keras.layers.Embedding(
            num_tokens, dim, embeddings_initializer='uniform')

        self.norm = tf.keras.layers.LayerNormalization()

        # separate, untied output projection -- this is the part in question
        self.to_out = tf.keras.layers.Dense(num_tokens, use_bias=False)

    def call(self, x):
        x = self.embedding(x)
        x = self.norm(x)
        out = self.to_out(x)
        return out
```

I am unsure how to properly tie `self.to_out` to `self.embedding` in TensorFlow.
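One approach I have considered is to drop the separate `Dense` layer and instead multiply by the transpose of the embedding matrix (the `embeddings` variable of the `Embedding` layer), analogous to `embed.attend` in Flax. A minimal sketch of just the tying step, with made-up sizes and a hypothetical class name, in case it helps frame the question:

```python
import tensorflow as tf

class TiedOutput(tf.keras.Model):
    """Hypothetical minimal model showing only the weight-tying step."""
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(
            num_tokens, dim,
            embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.02))
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        x = self.embedding(x)  # builds self.embedding.embeddings on first call
        x = self.norm(x)
        # project back onto the vocabulary with the same embedding matrix,
        # so no separate Dense layer (and no second weight matrix) is needed
        return tf.matmul(x, self.embedding.embeddings, transpose_b=True)
```

Is this the idiomatic way, or is there a cleaner mechanism in Keras?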

Any advice would be greatly appreciated.

Thank you,

Enrico