Callback for getting the data size in on_train_begin

I need to get the shape of the input X data when training begins, but I do not see the data passed into:

def on_train_begin(self, logs=None):

This is a bummer, because it’s really easy in PyTorch. Is there a workaround?

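Hi @BeardedDork, Keras does not pass the training data to callbacks, but you can give the callback its own reference to the data (for example, by storing it as self.x_train in the callback's __init__) and then read the shape in on_train_begin: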

def on_train_begin(self, logs=None):
    shapeofxtrain = self.x_train.shape
    print(shapeofxtrain)

Please refer to this gist for a working code example. Thank you.
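A self-contained sketch of the same idea, assuming TensorFlow 2.x with tf.keras. Because Keras does not hand the training data to callbacks, the callback receives the array in its constructor and records its shape when fit() starts. The class name DataShapeCallback, the train_shape attribute, and the toy data sizes are hypothetical, chosen only for illustration.

```python
import numpy as np
import tensorflow as tf


class DataShapeCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback that is given the training inputs up front.

    Keras does not pass the training data to callbacks, so this callback
    keeps its own reference and reads the shape in on_train_begin.
    """

    def __init__(self, x_train):
        super().__init__()
        self.x_train = x_train
        self.train_shape = None  # filled in when training starts

    def on_train_begin(self, logs=None):
        self.train_shape = self.x_train.shape
        print("x_train shape:", self.train_shape)


# Toy data: 100 samples with 8 features each (made-up sizes for illustration).
x_train = np.random.rand(100, 8).astype("float32")
y_train = np.random.rand(100, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# Hand the same array to both fit() and the callback.
shape_cb = DataShapeCallback(x_train)
model.fit(x_train, y_train, epochs=1, verbose=0, callbacks=[shape_cb])
```

The same pattern works for validation data. If you train from a tf.data.Dataset instead of an array, there is no single .shape to read, so you would inspect something like the dataset's element_spec instead.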
