Useful TensorFlow/Keras callbacks for model training
catasaurus
Posted on March 13, 2022
Here are some callbacks that I have found very useful when training machine learning models with Python and TensorFlow:
Number one: Early stopping
Keras early stopping (https://keras.io/api/callbacks/early_stopping/) has to be my favorite callback. With it you can tell the model to stop training once a monitored metric is no longer improving. An example of usage:
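import tensorflow as tf

earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # metric to watch
    min_delta=0.001,            # smallest change that counts as an improvement
    patience=5,                 # epochs without improvement before stopping
    verbose=1,                  # print a message when training stops
    restore_best_weights=True,  # roll back to the weights of the best epoch
)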
This stops training once val_loss has failed to improve by at least 0.001 for 5 consecutive epochs, then restores the weights from the best epoch. Because it monitors val_loss, model.fit() needs validation data (here through validation_split; the 0.2 is just an example value). As with any callback, pass it to model.fit() through the callbacks argument:
model.fit(some_data_X, some_data_y, epochs=some_number, validation_split=0.2, callbacks=[earlystopping, some_other_callback])
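To make min_delta and patience more concrete, here is a rough pure-Python sketch of the stopping rule. It is a simplified model of the behavior described above, not the actual Keras implementation, and the function name simulate_early_stopping and the loss values are made up for illustration.

```python
def simulate_early_stopping(val_losses, min_delta=0.001, patience=5):
    """Return (stopped_epoch, best_epoch) as 0-based epoch indices.

    stopped_epoch is None if training would run through every epoch.
    Simplified illustration only; not the Keras source code.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last sufficiently large improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improved by more than min_delta: remember it and reset the counter.
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Hypothetical validation losses, one per epoch.
losses = [0.90, 0.70, 0.60, 0.5995, 0.61, 0.60, 0.62, 0.6001, 0.63, 0.64]
stopped_epoch, best_epoch = simulate_early_stopping(losses)
print(stopped_epoch, best_epoch)
```

Note that the drop from 0.60 to 0.5995 is smaller than min_delta, so it does not count as an improvement and the patience counter keeps running. With restore_best_weights=True, the weights from best_epoch are the ones you end up with.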
Number two: Learning rate scheduler
Keras learning rate scheduler (https://keras.io/api/callbacks/learning_rate_scheduler/) can be very useful if you are having problems with your learning rate. With it you can reduce or increase the learning rate during training based on the epoch number and the current learning rate. An example:
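def scheduler(epoch, lr):
    # Multiply the current learning rate by exp(-0.5) (about 0.61) every epoch.
    # float() converts the tensor to a plain Python float, which some Keras versions require.
    return float(lr * tf.math.exp(-0.5))

learningratecallback = tf.keras.callbacks.LearningRateScheduler(scheduler)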
The scheduler function is where you define your logic for how the learning rate should decrease or increase. learningratecallback just wraps your function in a tf.keras.callbacks.LearningRateScheduler(). Don't forget to include it in model.fit()!
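If you want to see what a schedule does before training with it, you can trace it in plain Python. The sketch below assumes the documented behavior that the callback calls your function once at the start of every epoch with the epoch index and the current learning rate. The helper trace_schedule, the 3-epoch hold, and the starting rate of 0.01 are all made up for illustration, and math.exp stands in for tf.math.exp so the snippet runs without TensorFlow.

```python
import math


def scheduler(epoch, lr):
    # Keep the initial rate for the first 3 epochs (an arbitrary choice),
    # then shrink it by a factor of exp(-0.5) every epoch.
    if epoch < 3:
        return lr
    return lr * math.exp(-0.5)


def trace_schedule(schedule, initial_lr, epochs):
    """Feed each epoch's rate back into the schedule, one call per epoch."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = schedule(epoch, lr)
        rates.append(lr)
    return rates


rates = trace_schedule(scheduler, initial_lr=0.01, epochs=6)
for epoch, lr in enumerate(rates):
    print(f"epoch {epoch}: lr = {lr:.6f}")
```

Because each call receives the previous result, the decay compounds: after the hold period the rate is multiplied by exp(-0.5) again on every epoch.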
Last but not least, number three: Custom callbacks
Custom callbacks (https://keras.io/guides/writing_your_own_callbacks/) are great if you need to do something during training that is not built in to Keras and TensorFlow. I won't go in depth as there is a lot you can do. Basically you have to define a class that inherits from keras.callbacks.Callback. There are many different methods that you can override, and they are called at different points of the training (or testing and prediction) cycle. A simple example would be:
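from tensorflow import keras

class Catsarecoolcallback(keras.callbacks.Callback):
    # Keras calls this hook with the epoch index and a dict of metrics.
    def on_epoch_end(self, epoch, logs=None):
        print('cats are cool!')

callback = Catsarecoolcallback()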
This (as you can probably tell) prints out cats are cool! every time an epoch ends.
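To show how the hooks fit together, here is a toy stand-in for the training loop in plain Python. ToyCallback, LossHistory, and toy_fit are hypothetical names invented for this sketch and are not part of Keras; only the hook names and their arguments (on_train_begin, on_epoch_begin, on_epoch_end, on_train_end) mirror the real callback API.

```python
class ToyCallback:
    """Stand-in for keras.callbacks.Callback: every hook is a no-op."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class LossHistory(ToyCallback):
    """Records the loss reported at the end of every epoch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        self.losses.append((epoch, (logs or {}).get("loss")))


def toy_fit(epoch_losses, callbacks):
    """Pretend to train: call each hook at the point a real loop would."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch, loss in enumerate(epoch_losses):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": loss}  # a real loop would compute this from the data
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
    for cb in callbacks:
        cb.on_train_end()


history = LossHistory()
toy_fit([0.9, 0.7, 0.6], callbacks=[history])
print(history.losses)
```

The same LossHistory body should work as a real callback if it inherits from keras.callbacks.Callback instead of the toy base class, since it only overrides hooks with matching signatures.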
Hope you learned something while reading this!