Thanks,
I am currently trying to implement this paper: https://arxiv.org/abs/2103.04257. The PyTorch implementation is pretty straightforward, but I have some issues with TensorFlow. I defined the model in the following way, but I don't think it is correct: the result is quite different from PyTorch, also in terms of trainable parameters (~11M in PyTorch vs ~3M in TF).
def Define_Model(img_shape, num_channel):
    """Build a teacher/student ResNet-18 pair that shares one input tensor.

    Two ResNet-18 backbones are instantiated on the same input: a teacher
    loaded with ImageNet weights and marked non-trainable, and a student
    with no pretrained weights. Three intermediate outputs are tapped from
    each network and exposed as the outputs of a combined Keras model; the
    training loss is attached with ``add_loss`` via ``Feature_Loss``.

    Args:
        img_shape: Side length of the input; used as both height and width.
        num_channel: Number of input channels.

    Returns:
        A tuple ``(model, t_net, s_net)``: the compiled combined model, the
        teacher network and the student network.
    """
    # The second value returned by Classifiers.get is not used here.
    ResNet18, _ = Classifiers.get('resnet18')

    # Single input tensor shared by teacher and student.
    input_shape = (img_shape, img_shape, num_channel)
    input_tensor = tf.keras.Input(shape=input_shape)

    # Teacher starts from ImageNet weights, student from scratch.
    t_net = ResNet18(weights='imagenet', include_top=False,
                     input_tensor=input_tensor, input_shape=input_shape)
    s_net = ResNet18(weights=None, include_top=False,
                     input_tensor=input_tensor, input_shape=input_shape)

    # Prefix every layer name so the two networks can be told apart.
    # NOTE(review): both networks are built on the same input tensor, so any
    # layer they share (presumably the input layer) receives both prefixes --
    # confirm this is harmless for your use case.
    for layer in t_net.layers:
        layer._name = 't_net_' + layer.name
    for layer in s_net.layers:
        layer._name = 's_net_' + layer.name

    # Mark every teacher layer as non-trainable.
    for layer in t_net.layers:
        layer.trainable = False

    # Intermediate layers tapped from each network (names before prefixing).
    tap_names = ['stage1_unit2_conv2', 'stage2_unit2_conv2',
                 'stage3_unit2_conv2']
    out_1 = [t_net.get_layer('t_net_' + name).output for name in tap_names]
    out_2 = [s_net.get_layer('s_net_' + name).output for name in tap_names]

    # NOTE(review): the combined model is defined only by the input and the
    # tapped outputs, so layers after the stage-3 tap are presumably not part
    # of `model`; that would explain a parameter count well below the full
    # ResNet-18 -- confirm with model.summary().
    model = tf.keras.Model(inputs=input_tensor, outputs=[out_1, out_2])

    # The loss is attached directly to the model, hence loss=None in compile.
    model.add_loss(Feature_Loss(input_tensor, out_1, out_2))
    # `learning_rate` replaces the deprecated `lr` keyword.
    # NOTE(review): 0.4 looks very high for Adam -- confirm against the
    # optimizer settings used in the reference implementation.
    model.compile(Adam(learning_rate=0.4), loss=None)

    return model, t_net, s_net
Daniele