Conversion of pytorch code to tensorflow

I have attempted to convert the code below to tensorflow, but I am receiving shape errors. How can I convert this code to tensorflow?

class E_MHSA(nn.Module):
    """Multi-head self-attention with optional pooling of the key/value tokens.

    Queries are always computed from every input token. When ``sr_ratio > 1``
    the token axis is average-pooled by ``sr_ratio ** 2`` before the key and
    value projections, which shrinks the attention matrix from ``N x N`` to
    ``N x (N // sr_ratio ** 2)``.

    Args:
        dim: Channel size of the input tokens.
        out_dim: Channel size of the output tokens; defaults to ``dim``.
        head_dim: Channels per attention head; ``num_heads = dim // head_dim``.
        qkv_bias: Whether the q/k/v linear layers use a bias.
        qk_scale: Attention scale; a falsy value (``None`` or ``0``) falls back
            to ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        sr_ratio: Reduction ratio; values above 1 enable key/value pooling.

    Raises:
        ValueError: If ``dim < head_dim``, which would give zero heads.
    """

    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        if self.num_heads < 1:
            # Without this check the failure only shows up later, as a
            # ZeroDivisionError inside forward().
            raise ValueError(
                f"dim ({dim}) must be at least head_dim ({head_dim}) "
                "so that there is at least one attention head"
            )
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            # NORM_EPS is not defined in this snippet; it must exist at module
            # level for sr_ratio > 1 to work.
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``.

        Returns:
            Tensor of shape ``(B, N, out_dim)``.
        """
        B, N, C = x.shape
        # Kept as an explicit int(), as in the original code.
        ch_per_head = int(C // self.num_heads)

        # (B, N, C) -> (B, heads, N, ch_per_head)
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads, ch_per_head).permute(0, 2, 1, 3)

        # Keys and values come from the pooled tokens when reduction is on,
        # otherwise from the input itself.
        kv_src = x
        if self.sr_ratio > 1:
            # AvgPool1d pools over the last axis, so move the token axis there
            # for the pooling and back afterwards.
            kv_src = self.sr(x.transpose(1, 2)).transpose(1, 2)
            # NOTE(review): self.norm is created in __init__ but is never
            # applied here -- confirm whether the pooled tokens should be
            # normalised before the k/v projections.

        # k is laid out as (B, heads, ch_per_head, N') so that q @ k needs no
        # extra transpose; v is (B, heads, N', ch_per_head).
        k = self.k(kv_src)
        k = k.reshape(B, -1, self.num_heads, ch_per_head).permute(0, 2, 3, 1)
        v = self.v(kv_src)
        v = v.reshape(B, -1, self.num_heads, ch_per_head).permute(0, 2, 1, 3)

        attn = (q @ k) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, ch_per_head) -> (B, N, C), merging the heads again.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

@gitesh_chawda

Welcome to the TF Forum!

To convert the PyTorch code to TensorFlow, use ONNX: PyTorch → ONNX → TF

You can follow the steps below:

  1. Installations
!pip install -U onnx
!pip install -U onnx-tf
  1. Imports
import torch.onnx
  1. Export PyTorch to ONNX - The resulting .onnx file (LSTM.onnx in the example below) is a binary protocol buffer
    (see the Protocol Buffers documentation) which contains both the network structure and the parameters of the model you exported (in this case, an LSTM).
    Eg:
    torch.onnx.export(model, dummy_input, "LSTM.onnx", verbose=True, input_names=input_names, output_names=output_names)

4. Use the code below to convert the resulting ONNX model to TensorFlow:

import onnx
from onnx_tf.backend import prepare

# Read the ONNX graph that was exported from PyTorch.
loaded_graph = onnx.load("./LSTM.onnx")
# Wrap it with the onnx-tf backend and write the result to disk.
prepare(loaded_graph).export_graph('./LSTM.pb')

Let us know if the above solution works for you

1 Like

It is also good to realize that ONNX can run inference and that the saved models are not re-trainable, so for the most part you are fine using ONNX or PyTorch if you started with PyTorch.

I prefer TensorFlow, but there are more resources for PyTorch, IMHO.

# Repro script: build a random (batch, seq, feature) input, run it through
# E_MHSA once, then attempt the ONNX export.
batch_size = 2
seq_len = 5
feat_dim = 64

dummy_input = torch.randn(batch_size, seq_len, feat_dim)
# feat_dim matches dim=64; the output projection maps to out_dim=128.
model = E_MHSA(dim=64, out_dim=128, head_dim=32)
output = model(dummy_input)
print(output.shape)

# NOTE(review): the first argument below is the class E_MHSA, not the instance
# `model` created above. The exporter calls `model.modules()` on it, which
# fails on a class with "missing 1 required positional argument: 'self'".
# NOTE(review): input_names / output_names are plain strings here; the
# torch.onnx.export documentation describes them as lists of strings, e.g.
# ["input"] and ["output"] -- confirm against the installed version.
torch.onnx.export(E_MHSA, dummy_input, "attention.onnx", verbose=True, input_names="input_names", output_names="output_names")

I am getting this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-31-be0b10e5e25f> in <module>
----> 1 torch.onnx.export(E_MHSA, dummy_input, "attention.onnx", verbose=True, input_names="input_names", output_names="output_names")

5 frames
/usr/local/lib/python3.9/dist-packages/torch/onnx/utils.py in disable_apex_o2_state_dict_hook(model)
    135     if not isinstance(model, torch.jit.ScriptFunction):
    136         model_hooks = {}  # type: ignore[var-annotated]
--> 137         for module in model.modules():
    138             for key, hook in module._state_dict_hooks.items():
    139                 if type(hook).__name__ == "O2StateDictHook":

TypeError: modules() missing 1 required positional argument: 'self'