Checkpoints in Deep Learning

Updated: Aug 5, 2021

Suppose you are training a model on a dataset of 500,000 or more examples and, to reach good accuracy, you have chosen to train for 10,000 steps. With a dataset this large, the first thought is how long the code will take to run. The second is what happens if a failure occurs partway through. Days or even weeks of work are lost, and you have to restart training from scratch and spend another week or month on it.

Isn’t that tedious and irritating? It surely is. Losing weeks of work to a small fault such as a power outage, network issue, hardware failure, OS fault, GPU failure or any other unexpected error is enough to test anyone’s temper. At times like this, you want to resume training from a particular point, or continue training from a particular saved state.


For situations like this we have checkpoints, which save our progress. Checkpointing makes it possible to recover quickly from such failures, which reduces overall training time and ensures that progress is not lost.

Checkpoints are snapshots of your working model taken during training and stored in non-volatile memory. In machine learning and deep learning experiments, they are what you use to save the current state of the model so that you can pick up from where you left off.

A checkpoint does not contain a description of the computation defined by the model; it holds the saved state, such as the values of the model’s parameters. Since checkpoints are needed to resume a session, they record what has been done so far, which also determines what is left to be done to complete the session.


The block diagram above explains how checkpoints are generated during training. First, the layers required by the neural network model are built, and the initial checkpoint is saved in the model directory as a model.ckpt file. After that, the main training begins. Training runs for the number of steps set by the coder, and the loss is reported roughly every 100 steps.

When the loss reaches a minimum and continuing beyond that point would risk overfitting, training stops and the model is evaluated on the validation dataset. The user can also specify manually after how many epochs a checkpoint should be created.

After creation and validation, the checkpoint is saved as model.ckpt at the checkpoint_path specified by the coder. Training then resumes from the step where it left off and continues until the next checkpoint is due or the training steps run out.
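The flow described above can be sketched in plain Python. This is only a toy illustration, not the TensorFlow implementation: the function names, the JSON file format and the single-number "weight" are all assumptions made for the example. It saves a checkpoint every fixed number of steps and writes each file atomically, so a crash in the middle of a save cannot leave a half-written checkpoint behind.

```python
import json
import os
import tempfile


def save_checkpoint(state, path):
    """Write the state to a temporary file first, then rename it into place."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)  # atomic rename on the same filesystem


def train(total_steps, checkpoint_every, path):
    """Toy loop: the 'weight' is a single number nudged toward 1.0 each step."""
    weight = 0.0
    saved_steps = []
    for step in range(1, total_steps + 1):
        weight += 0.1 * (1.0 - weight)  # stand-in for a real optimizer update
        # Save at the chosen interval, and once more at the very last step.
        if step % checkpoint_every == 0 or step == total_steps:
            save_checkpoint({"step": step, "weight": weight}, path)
            saved_steps.append(step)
    return saved_steps


# Hypothetical location, used only for this example.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt.json")
steps_saved = train(total_steps=250, checkpoint_every=100, path=ckpt_path)
print(steps_saved)  # prints [100, 200, 250]
```

A larger checkpoint_every means less time spent saving but more work lost if a failure occurs between two saves.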


Checkpoints play a very important role when training a long-running machine learning model, especially when there is a strong possibility of interruption. However, in some cases checkpoints add unnecessary overhead to a session: they take up storage and lengthen training, since saving a checkpoint takes a fair amount of time. So it is advisable to create a checkpoint only when the model performs at its best at a particular epoch.

For example, suppose the training log shows a validation accuracy of 0.6487 at epoch 49 and 0.6728 at epoch 46. If we save the model at epoch 49, we keep weights that perform worse than those from epoch 46. So instead of saving the model at epoch 49 we should save it at epoch 46, i.e., at the epoch where the model performs best.

So checkpoints need not be created after every epoch or at arbitrary intervals. A checkpoint is worth creating when the weights give the highest accuracy and lowest loss seen so far. If the weights at a given epoch do not improve on the best accuracy or loss, they are not saved: no checkpoint is created and training continues.

With this method only the essential weights are saved, which saves both time and storage.
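The "save only when the model improves" rule can be sketched as a small function. This is an illustration under assumptions: the function name is hypothetical, and apart from the two accuracies quoted in the text (epochs 46 and 49) the values in the history are made up for the example.

```python
def epochs_to_checkpoint(val_accuracy_by_epoch):
    """Return the epochs at which a 'save best only' policy would write a file."""
    best = float("-inf")
    saved = []
    for epoch, acc in sorted(val_accuracy_by_epoch.items()):
        if acc > best:  # strictly better than anything seen so far
            best = acc
            saved.append(epoch)
    return saved


# Epochs 46 and 49 use the accuracies from the text; the rest are invented.
history = {45: 0.6510, 46: 0.6728, 47: 0.6602, 48: 0.6555, 49: 0.6487}
saved_epochs = epochs_to_checkpoint(history)
print(saved_epochs)  # prints [45, 46]
```

The last entry of the result is the epoch whose weights would remain on disk, here epoch 46 rather than epoch 49.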


The checkpoint technique records the execution state so that a session can be resumed from an intermediate point, which prevents processor time from being wasted.

When a session is established or training starts, the code creates a new checkpoint whenever the main code instructs it to. Checkpoints are then created periodically, and all of them are stored locally in the specified folder.

Checkpoint creation varies significantly between frameworks. We will now look at some of the most popular checkpointing methods and how to create, save and resume training from checkpoints with them.

Keras & TensorFlow

In Keras, checkpoints are implemented using callbacks. With a callback we can customize the behaviour of a Keras model during training.

In TensorFlow, checkpoints can be implemented through Keras (tf.keras), so let’s study how both libraries create and use checkpoints.

Callbacks can be used to:

  • Get a view of the internal statistics of the model during training.

  • Regularly save your model.

  • Stop and evaluate the model as and when required.

  • And much more.

The first technique for creating checkpoints is the ModelCheckpoint callback. We create a ModelCheckpoint object and pass it to the .fit() function during training through the callbacks argument. Callbacks in general can also be passed to the .evaluate() and .predict() functions of the model. ModelCheckpoint saves the model or its weights at a frequency specified by the coder.

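import tensorflow as tf

# model, x_traindata, y_traindata, x_testdata and y_testdata are assumed to be defined already
Checkpoint_create = tf.keras.callbacks.ModelCheckpoint(
    filepath = '/tmp/weights.hdf5',  # newer Keras versions may require a .keras or .h5 extension
    monitor = 'val_loss',
    verbose = 1,
    save_best_only = True,
    save_weights_only = False,
    mode = 'auto',
    save_freq = 'epoch' )  # the older period argument is deprecated in favour of save_freq

# callbacks must be passed as a list
model.fit( x_traindata, y_traindata, batch_size = 64, epochs = 50, verbose = 'auto', validation_data = (x_testdata, y_testdata), callbacks = [Checkpoint_create] )
model.evaluate( x_testdata, y_testdata, batch_size = 64, verbose = 'auto', callbacks = [Checkpoint_create] )
model.predict( x_testdata, batch_size = 64, verbose = 0, callbacks = [Checkpoint_create] )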

In the ModelCheckpoint function the arguments are as follows:

  • filepath: String argument. The path, including the extension, of the file in which the model or its weights will be saved.

  • monitor: The name of the metric to monitor during training, given as a single string. Commonly used metrics are loss, val_loss, accuracy and val_accuracy.

  • verbose: Verbosity mode. 0 = silent, 1 = print a message whenever the callback takes an action.

  • save_best_only: If True, the file is saved only when the model is considered the “best” so far according to the monitored metric.

  • save_weights_only: If True, only the weights of the model are saved; otherwise the full model is saved.

  • mode: One of ‘auto’, ‘min’ or ‘max’. For accuracy-like quantities such as val_accuracy ‘max’ is needed, and for loss ‘min’ is needed. With ‘auto’ the direction is inferred from the name of the monitored quantity. When save_best_only is True, the mode decides whether the saved file is overwritten.

  • save_freq: Determines how often the model is saved. Either ‘epoch’ (the default), which saves after every epoch, or an integer, which saves after that many batches.

  • **kwargs: Other additional arguments. May include period, which is deprecated in favour of save_freq.
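How the mode argument described above turns into a comparison can be sketched as follows. This is a simplification under stated assumptions: the helper names are hypothetical, and the rule "names containing acc are maximized" only approximates what Keras does, which may differ between versions.

```python
def resolve_mode(monitor, mode="auto"):
    """Map 'auto' to 'min' or 'max' based on the monitored metric's name."""
    if mode not in ("auto", "min", "max"):
        raise ValueError(f"unknown mode: {mode!r}")
    if mode == "auto":
        # Assumed heuristic: accuracy-like metrics should go up, the rest down.
        return "max" if "acc" in monitor else "min"
    return mode


def is_improvement(current, best, mode):
    """True when the current value beats the best value in the given direction."""
    return current > best if mode == "max" else current < best


print(resolve_mode("val_accuracy"))  # prints max
print(resolve_mode("val_loss"))      # prints min
```

When the metric name is ambiguous, passing ‘min’ or ‘max’ explicitly avoids relying on the name-based guess.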

In the above code fragment, the weights are saved in one location, i.e., in the weights.hdf5 file. This means that every time a checkpoint is created the file is overwritten, so only the last checkpoint is kept in the file.

To create a separate file for each checkpoint we can write the filepath as

filepath = 'weights-{epoch}-{val_loss}-{accuracy}.hdf5'

i.e., in {} we provide the epoch number and the monitored quantities, so that each saved file gets its own name built from the values at that epoch.
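The placeholders in the file path follow Python's string formatting, so the resulting file name can be previewed without training anything. In this sketch the format specifiers (:02d, :.2f), the epoch number and the metric values are all made-up assumptions for illustration.

```python
# Template with the same placeholders as above, plus assumed format specifiers.
template = "weights-{epoch:02d}-{val_loss:.2f}-{accuracy:.2f}.hdf5"

# Hypothetical metric values for one epoch.
logs = {"loss": 0.4521, "val_loss": 0.5137, "accuracy": 0.81}

# Unused keys such as "loss" are simply ignored by str.format.
filename = template.format(epoch=7, **logs)
print(filename)  # prints weights-07-0.51-0.81.hdf5
```

Zero-padding the epoch number keeps the files in training order when they are sorted by name.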

Next, we will see a method used to reduce overfitting in your machine learning program. The method is called early stopping. It stops training when a monitored metric has stopped improving, which prevents the model from overfitting the dataset.

# assumes tensorflow is imported as tf, and model and the data arrays are already defined
Earlystop_cp = tf.keras.callbacks.EarlyStopping(
    monitor = 'loss',
    min_delta = 0,
    patience = 3,
    verbose = 0,
    mode = 'auto',
    baseline = None,
    restore_best_weights = False )

# pass the early stopping callback (not the checkpoint callback) as a list
model.fit( x_traindata, y_traindata, batch_size = 64, epochs = 50, verbose = 'auto', validation_data = (x_testdata, y_testdata), callbacks = [Earlystop_cp] )
model.evaluate( x_testdata, y_testdata, batch_size = 64, verbose = 'auto', callbacks = [Earlystop_cp] )
model.predict( x_testdata, batch_size = 64, verbose = 0, callbacks = [Earlystop_cp] )

In the EarlyStopping function the arguments are as follows:

  • monitor: The name of the metric to be monitored.

  • min_delta: The minimum change in the monitored quantity that counts as an improvement; any smaller change is treated as no improvement.

  • patience: The number of epochs with no improvement after which training will be stopped.

  • verbose: Verbosity mode. 0 = silent, 1 = print a message when training is stopped.

  • mode: One of ‘auto’, ‘min’ or ‘max’. In ‘min’ mode training stops when the monitored quantity has stopped decreasing, and in ‘max’ mode when it has stopped increasing. For val_accuracy ‘max’ is needed and for loss ‘min’ is needed. With ‘auto’ the direction is inferred from the name of the monitored quantity.

  • baseline: Defines the baseline value for the monitored quantity. Training stops if the model doesn’t show improvement over the baseline.

  • restore_best_weights: If True, the model weights from the epoch with the best value of the monitored quantity are restored when training stops; otherwise the weights from the last step of training are used.
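The interplay of patience and min_delta can be sketched in plain Python. This is a minimal illustration of the idea for a loss-like metric, not the Keras source; the function name and the loss values are assumptions.

```python
def early_stopping_epoch(losses, patience=3, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - min_delta:  # improved by more than min_delta
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # never ran out of patience


# Made-up loss curve: the best value (0.60) is reached at epoch 3.
losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.60, 0.59]
stop_epoch = early_stopping_epoch(losses, patience=3)
print(stop_epoch)  # prints 6
```

In this made-up curve, stopping at epoch 6 misses the slightly better loss at epoch 7, which is the trade-off a small patience value makes.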

Now let’s see a callback that can be used to stream events to a server. The RemoteMonitor callback sends the logs of each epoch to a server as HTTP POST requests, which is useful when training needs to be followed from another system. It’s an extra step that comes in handy only in rare cases, so there’s no need to include it in every training run.

remotemonitor_cp = tf.keras.callbacks.RemoteMonitor(
    root = 'https://localhost:9000',
    path = '/publish/epoch/end/',
    field = 'data',
    headers = None,
    send_as_json = False )

In the RemoteMonitor function the arguments are as follows:

  • root: String argument. The root URL of the target server.

  • path: String argument. The path, relative to root, to which the events will be sent.

  • field: String argument. The JSON field under which the data will be stored.

  • headers: Dictionary argument. Optional custom HTTP headers.

  • send_as_json: If True, the content type of the request will be “application/json”; otherwise the serialized JSON will be sent within a form.

We have many other callbacks, such as:

  • LearningRateScheduler: Gets an updated learning rate from a schedule function at the beginning of every epoch and applies it to the optimizer.

  • ReduceLROnPlateau: Reduces the learning rate when a metric has stopped improving, i.e., when its value has plateaued.

  • LambdaCallback: Used for creating simple custom callbacks. With it we can run our own code at times such as when an epoch begins or ends, when a batch begins or ends, and when training begins or ends.

  • TerminateOnNaN: Terminates training whenever a NaN loss value is encountered.

  • CSVLogger: Saves the epoch results to a CSV file.

  • ProgbarLogger: Prints the metrics to stdout (standard output).
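As an illustration of the first callback in the list, LearningRateScheduler expects a function that takes the epoch index and the current learning rate and returns the new rate. The schedule below is only a sketch: the 10-epoch threshold, the decay factor and the starting rate are assumed values, and the loop merely simulates what the callback would do at the start of each epoch.

```python
import math


def schedule(epoch, lr):
    """Keep the learning rate for the first 10 epochs, then decay it exponentially."""
    if epoch < 10:
        return lr
    return lr * math.exp(-0.1)


# Simulate 15 epochs by hand (epoch indices start at 0).
lr = 0.01
for epoch in range(15):
    lr = schedule(epoch, lr)
print(round(lr, 6))
```

In a real training run the function would be wrapped as tf.keras.callbacks.LearningRateScheduler(schedule) and passed to .fit() in the callbacks list.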

Next is the TensorBoard callback. TensorBoard is the visualization tool that comes with TensorFlow.

It can be used for:

  • Getting metrics summary plots

  • Training graph visualization

  • Activation histograms

  • And for sampled profiling.

Tensorboard_cp = tf.keras.callbacks.TensorBoard(
    log_dir = 'logs',
    histogram_freq = 1,
    write_graph = False,
    write_images = False,
    write_steps_per_second = True,
    update_freq = 'batch',
    profile_batch = 2,
    embeddings_freq = 4 )

In the TensorBoard function the arguments are as follows:

  • log_dir: The path of the directory where the log files are saved.

  • histogram_freq: The frequency (in epochs) at which the activation and weight histograms are computed for the layers of the model.

  • write_graph: If True, the graph is written to the log file so that it can be visualized in TensorBoard.

  • write_images: Whether to write the model weights so that they can be visualized as images in TensorBoard.

  • write_steps_per_second: If True, the training steps per second are logged, for both epoch and batch frequency logging.

  • update_freq: Can be ‘batch’, ‘epoch’ or an integer. ‘batch’ and ‘epoch’ write the losses and metrics to TensorBoard after each batch or epoch. If an integer N is specified, they are written after every N batches.

  • profile_batch: The batch (or range of batches) to profile in order to sample compute characteristics.

  • embeddings_freq: The frequency (in epochs) at which embedding layers are visualized.

  • embeddings_metadata: Dictionary argument which maps the embedding layers to the name of the file in which the metadata is saved.

  • **kwargs: Other additional arguments.

A custom callback can be created by overriding any of the following methods:

  • on_epoch_begin

  • on_epoch_end

  • on_predict_begin

  • on_predict_end

  • on_predict_batch_begin

  • on_predict_batch_end

  • on_train_begin

  • on_train_end

  • on_train_batch_begin

  • on_train_batch_end

  • on_test_begin

  • on_test_end

  • on_test_batch_begin

  • on_test_batch_end
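To show when some of the hooks listed above fire, here is a toy sketch that does not use Keras at all. The EpochRecorder class and the toy_fit function are hypothetical stand-ins; in real code the class would subclass tf.keras.callbacks.Callback and be passed to model.fit(), which calls the hooks itself.

```python
class EpochRecorder:
    """Toy callback that records which hooks were called and in what order."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}:loss={logs['loss']:.2f}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def toy_fit(epochs, callbacks):
    """Stand-in for model.fit(): it only shows when each hook fires."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": 1.0 / (epoch + 1)}  # made-up loss value
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
    for cb in callbacks:
        cb.on_train_end()


recorder = EpochRecorder()
toy_fit(epochs=2, callbacks=[recorder])
print(recorder.events)
```

A checkpointing callback would typically do its saving inside on_epoch_end, where the metrics for the finished epoch are available in logs.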

Keras makes it simple and fast to define checkpoints for your model. It lets you save checkpoints in many different ways using both custom and predefined callbacks. It takes just one line of code to create a checkpoint callback and a couple more to get it working. What could be easier?

In TensorFlow we can also use the official TensorFlow Model Implementation GitHub repository, whose models have built-in checkpointing. During training, such a model automatically creates checkpoints and the model.ckpt file.


Now the main step after checkpointing is to use those checkpoints in the model. Assuming the checkpoints were created using the ModelCheckpoint method, we can load them and evaluate our model as follows:

model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
model.load_weights('/tmp/weights.hdf5')  # assumed: the checkpoint file written by ModelCheckpoint earlier
# load the dataset after this using numpy.loadtxt, which will load the .csv file
# after that, create the input and output variables (Inp, Out) by splitting the dataset
Score = model.evaluate( Inp, Out, verbose = 0 )
print( "%s: %.2f%%" % (model.metrics_names[1], Score[1]*100 ))

tf.train.CheckpointManager can be helpful for managing multiple checkpoints. It keeps only the most recent checkpoints, up to max_to_keep, and deletes the older ones.

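# assumption: the object to track is the Keras model defined earlier
Ckpt = tf.train.Checkpoint( model = model )
Manage = tf.train.CheckpointManager(
    checkpoint = Ckpt,
    directory = '/tmp/model',
    max_to_keep = 10,
    keep_checkpoint_every_n_hours = None,
    checkpoint_name = 'checkpnt',
    step_counter = None,
    checkpoint_interval = None,
    init_fn = None )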
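The effect of max_to_keep can be sketched with ordinary files. This is not how tf.train.CheckpointManager is implemented; the function name, the file naming scheme and the file contents are assumptions chosen to show the rotation idea.

```python
import os
import tempfile


def save_with_rotation(directory, step, max_to_keep):
    """Write a checkpoint file for this step and delete all but the newest ones."""
    path = os.path.join(directory, f"ckpt-{step:06d}")
    with open(path, "w") as f:
        f.write(f"state at step {step}\n")  # placeholder for real model state
    # Zero-padded step numbers make alphabetical order equal to step order.
    existing = sorted(n for n in os.listdir(directory) if n.startswith("ckpt-"))
    for old in existing[:-max_to_keep]:  # assumes max_to_keep >= 1
        os.remove(os.path.join(directory, old))
    return path


ckpt_dir = tempfile.mkdtemp()  # throwaway directory for the example
for step in (100, 200, 300, 400):
    save_with_rotation(ckpt_dir, step, max_to_keep=2)

remaining = sorted(os.listdir(ckpt_dir))
print(remaining)  # prints ['ckpt-000300', 'ckpt-000400']
```

Keeping a few recent checkpoints rather than just one gives a fallback if the newest file turns out to be corrupted.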


Stop, pause and resume training

As previously mentioned, the biggest use of checkpoints is to resume training from a particular step after a failure, so that the training process is not set back. Within this there are several smaller applications.

Using checkpoints you can easily:

  1. Take a snapshot of the model during the training

  2. Load the model into the memory

  3. Resume training from where you left off.
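The three steps above can be sketched end to end in plain Python. Everything here is an assumption made for illustration: the JSON state file, the function names and the toy "weight" that grows by 0.5 per step stand in for a real model and optimizer.

```python
import json
import os
import tempfile


def load_checkpoint(path):
    """Step 2: load the saved state, or start fresh if no checkpoint exists."""
    if not os.path.exists(path):
        return {"step": 0, "weight": 0.0}
    with open(path) as f:
        return json.load(f)


def train(path, until_step):
    """Step 3: continue from the last saved step up to until_step."""
    state = load_checkpoint(path)
    for step in range(state["step"] + 1, until_step + 1):
        state = {"step": step, "weight": state["weight"] + 0.5}  # toy update
        with open(path, "w") as f:  # step 1: snapshot after every step
            json.dump(state, f)
    return state


state_path = os.path.join(tempfile.mkdtemp(), "state.json")
first = train(state_path, until_step=3)    # pretend training was interrupted here
resumed = train(state_path, until_step=5)  # second run picks up at step 4
print(first, resumed)
```

The resumed run ends in the same state as an uninterrupted run of five steps would, which is the property checkpointing is meant to provide.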

Saving the best model (improving accuracy)

In addition to this, checkpoints are used in online training to improve inference accuracy through continuous learning. This is typically done by lowering the learning rate by an order of magnitude, which can in turn improve the model accuracy. Here, the training process continues even after the inference graph is obtained. For instance, an interim model can be used for prediction serving while the model is still being trained on more recent data to keep the inference model fresh.

Model can be trained on different systems, if needed

Checkpoints are needed for moving training processes across different nodes or clusters. Such a move may be required in cases such as server maintenance, hardware failure or network issues.

Performs Transfer learning

Checkpoints are also used for performing transfer learning, where an intermediate model state is used as a seed, which is then trained for a different goal.


By reading this blog you should now have a solid understanding of model checkpointing, which can simply do wonders. Checkpointing is used to record, save, pause and resume a training session so that neither the work in progress nor the energy spent on it is lost. Using checkpoints makes our work easier and makes training deep learning models far more robust. So having a thorough knowledge of it is essential.

I hope you enjoyed reading this blog. To learn how to implement this checkpointing technique, stay tuned! I will come right back with a blog on practical implementation of checkpoints in deep learning soon. Till then, if you found this blog useful, please share it among your friends and on social media. For any queries, suggestions, or any other issue, please ping me in the comments.


