I have a custom model class whose training method is implemented on top of `train_step`:
class MyFancyModel(tfk.models.Model):
    ...
    def train_model(self, data, steps, message=None):
        @tf.function
        def train_step(model_and_inputs):
            model, data = model_and_inputs
            return model.train_step((data,))

        from tqdm import trange
        history = {}
        for i in trange(steps):
            _history = train_step((self, data))
            # accumulate each step's metrics into per-key lists
            for k, v in _history.items():
                history.setdefault(k, []).append(float(v))
        return history
Recently, I realized I should probably use `self.train_on_batch` rather than my own local wrapper around `train_step`, so I rewrote the method as:
    def train_model(self, data, steps, message=None):
        from tqdm import trange
        history = {}
        for i in trange(steps):
            _history = self.train_on_batch(data, return_dict=True)
            for k, v in _history.items():
                history.setdefault(k, []).append(float(v))
        return history
I thought that would be that, but I noticed that the new version's training results are slightly worse. I'm scratching my head trying to figure out what the salient difference between these two implementations could be, and I'd appreciate any relevant insight into the `keras.models.Model` internals.
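For reference, here is a minimal standalone version of the two call patterns side by side. The toy model and random data below are placeholders, not my actual class or dataset; I'm just isolating the two ways of invoking a training step:

```python
import numpy as np
import tensorflow as tf

tfk = tf.keras

# Throwaway stand-in for my actual model class.
model = tfk.Sequential([tfk.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

# Pattern 1: hand-rolled tf.function wrapping Model.train_step.
@tf.function
def step(model_and_inputs):
    m, batch = model_and_inputs
    return m.train_step(batch)

logs_a = {k: float(v) for k, v in step((model, (x, y))).items()}

# Pattern 2: the built-in train_on_batch.
logs_b = model.train_on_batch(x, y, return_dict=True)

print(logs_a)
print(logs_b)
```

Both calls return a metrics dict, so any divergence between the two patterns should show up directly in these logs when run in a loop.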
Cheers!