Requeuing and checkpointing
On AI-LAB, jobs are limited to a maximum runtime of 12 hours. Once this time limit is reached, your job will be cancelled, and any unsaved data will be lost. To prevent data loss, it’s important to implement checkpointing, which saves data at regular intervals. Additionally, you can set up automatic requeuing so that your job will restart automatically if it gets cancelled, removing the need to manually resubmit it.
This guide explains how to set up checkpointing in various frameworks like Python, PyTorch, and TensorFlow, as well as how to configure automatic requeuing for your jobs.
Disclaimer
Checkpointing and requeuing are used at your own risk. This guide is intended as a reference, but saving data via checkpointing remains entirely the responsibility of the user. There may be errors or inaccuracies in this guide. If you encounter any issues or discover mistakes, we encourage you to provide feedback so we can improve. You can submit your feedback here.
Checkpointing methods
Checkpointing allows you to periodically save the state of your job so that it can be resumed later, even after an interruption. Different frameworks have different methods for implementing checkpointing. Below are examples for Python, PyTorch, and TensorFlow.
Python checkpointing
In Python, you can use the pickle module to periodically save and load the state of your job.
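A minimal sketch of this approach could look like the following; the helper functions save_checkpoint and load_checkpoint and the file name checkpoint.pkl match the breakdown below, and the counter loop is just a stand-in for your own workload.

import os
import pickle
import time

CHECKPOINT_FILE = 'checkpoint.pkl'

def save_checkpoint(data, filename=CHECKPOINT_FILE):
    # Serialize the current state to disk with pickle
    with open(filename, 'wb') as f:
        pickle.dump(data, f)

def load_checkpoint(filename=CHECKPOINT_FILE):
    # Restore a previously saved state from disk
    with open(filename, 'rb') as f:
        return pickle.load(f)

if __name__ == '__main__':
    # Resume from the checkpoint if one exists, otherwise start fresh
    if os.path.exists(CHECKPOINT_FILE):
        data = load_checkpoint()
        print(f"Resuming from checkpoint: counter = {data['counter']}")
    else:
        data = {'counter': 0}

    try:
        # Infinite loop simulating a long-running process
        while True:
            data['counter'] += 1
            print(f"Counter: {data['counter']}")
            time.sleep(1)  # simulate some work

            # Save a checkpoint every 5 iterations
            if data['counter'] % 5 == 0:
                save_checkpoint(data)
    except KeyboardInterrupt:
        # Save the current state before exiting on Ctrl+C
        save_checkpoint(data)
        print("Checkpoint saved before exiting.")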
Breakdown of the key components:
First, the script checks whether a checkpoint file named checkpoint.pkl exists using os.path.exists(). If it does, it loads the checkpoint data with the load_checkpoint function and assigns it to data; if not, it initializes data as a dictionary containing a single key, counter, set to 0.
Then, it enters an infinite loop (simulating a long-running process), where it increments the counter key of the data dictionary, prints the current counter value, and simulates some work (in this case, a 1-second delay using time.sleep(1)).
Every 5 iterations (when data['counter'] % 5 == 0), it saves the checkpoint by calling save_checkpoint. If the process is interrupted with a keyboard interrupt (Ctrl+C), it saves the current checkpoint and prints a message before exiting.
PyTorch checkpointing
PyTorch provides native support for saving and loading model and optimizer states during training. More information about PyTorch checkpointing can be found here. Here's an example that shows how to checkpoint during training of a simple neural network:
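A condensed sketch of this pattern could look like the following; the small fully connected network, the random training data, and the file naming (a checkpoints/ directory with epoch_<n>.pt files) are illustrative choices, so adapt them to your own training loop.

import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim

# Checkpoint directory setup: create a directory for storing checkpoints
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# A small illustrative model, optimizer and loss
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Check for existing checkpoints and load the latest one if available
start_epoch = 0
checkpoints = glob.glob(os.path.join(checkpoint_dir, 'epoch_*.pt'))
if checkpoints:
    latest = max(checkpoints, key=os.path.getmtime)
    state = torch.load(latest)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    start_epoch = state['epoch'] + 1
    print(f"Resuming from {latest} at epoch {start_epoch}")

# Dummy training data (stand-in for your own dataset)
inputs = torch.randn(256, 10)
targets = torch.randn(256, 1)

num_epochs = 100
for epoch in range(start_epoch, num_epochs):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}: loss = {loss.item():.4f}")

    # Save the model state, optimizer state and current loss at the end of
    # each epoch to a uniquely named file based on the epoch number
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, os.path.join(checkpoint_dir, f'epoch_{epoch}.pt'))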
Breakdown of the key components:
Checkpoint directory setup: Creating a directory for storing checkpoints.
Checking for existing checkpoints: Checking for existing checkpoints and loading the latest one if available, so training resumes from the last saved epoch.
Saving checkpoints: Saving the model's state, the optimizer's state, and the current loss at the end of each epoch to a uniquely named file based on the epoch number.
TensorFlow checkpointing
TensorFlow offers built-in functionality for saving model weights during training using the ModelCheckpoint callback. More information about TensorFlow checkpointing can be found here. Here's an example:
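A minimal sketch, assuming TensorFlow 2.x and the tf.keras API, could look like the following; the small model and random data are placeholders, while checkpoint_path, cp_callback and model.load_weights(latest) correspond to the breakdown below.

import os

import numpy as np
import tensorflow as tf

# Path where checkpoints will be saved; the epoch number is part of the
# file name to differentiate between checkpoints
checkpoint_path = 'checkpoints/{epoch:d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

# ModelCheckpoint callback that saves the model's weights after every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
)

# A small illustrative model trained on random data
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(256, 10)
y = np.random.rand(256, 1)

# Before training, check for existing checkpoints and resume from the latest
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest:
    print(f'Restoring weights from {latest}')
    model.load_weights(latest)

model.fit(x, y, epochs=10, callbacks=[cp_callback])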
Breakdown of the key components:
checkpoint_path: Specify the path where checkpoints will be saved. You can include dynamic elements such as the epoch number in the file name to differentiate between checkpoints, e.g. checkpoints/{epoch:d}.ckpt.
cp_callback: Create a ModelCheckpoint callback, which will save the model's weights at specified intervals during training. You can customize various parameters such as the file path, verbosity, and whether to save only the weights or the entire model.
model.load_weights(latest): Before starting training, check whether there are existing checkpoints. If so, load the latest one to resume training from the last saved state. This ensures continuity in training even if the job is interrupted.
Requeuing jobs
After implementing checkpointing in your script, you have the option to set up automatic job requeuing in case your job gets cancelled. This is done by modifying a bash script that will automatically requeue the job if it's terminated due to exceeding the time limit. You can find an example script, requeue.sh, on AI-LAB at /ceph/course/claaudia/docs/requeue.sh.
requeue.sh
#!/bin/bash
#SBATCH --job-name=requeue_example
#SBATCH --time=00:01:00
#SBATCH --signal=B:SIGTERM@30
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=15
#SBATCH --mem=24G
#####################################################################################
# tweak this to fit your needs
max_restarts=4
# Fetch the current restarts value from the job context
scontext=$(scontrol show job ${SLURM_JOB_ID})
restarts=$(echo ${scontext} | grep -o 'Restarts=[0-9]*' | cut -d= -f2)
# If no restarts found, it's the first run, so set restarts to 0
iteration=${restarts:-0}
# Dynamically set output and error filenames using job ID and iteration
outfile="${SLURM_JOB_ID}_${iteration}.out"
errfile="${SLURM_JOB_ID}_${iteration}.err"
# Print the filenames for debugging
echo "Output file: ${outfile}"
echo "Error file: ${errfile}"
## Define a term-handler function to be executed ##
## when the job gets the SIGTERM (before timeout) ##
term_handler()
{
echo "Executing term handler at $(date)"
if [[ $restarts -lt $max_restarts ]]; then
# Requeue the job, allowing it to restart with incremented iteration
scontrol requeue ${SLURM_JOB_ID}
exit 0
else
echo "Maximum restarts reached, exiting."
exit 1
fi
}
# Trap SIGTERM to execute the term_handler when the job gets terminated
trap 'term_handler' SIGTERM
#######################################################################################
# Use srun to dynamically specify the output and error files.
# Run the step in the background and wait for it, so that bash can execute
# the SIGTERM trap as soon as the signal arrives (a foreground srun would
# delay the trap until the step finishes).
srun --output="${outfile}" --error="${errfile}" singularity exec --nv /ceph/container/pytorch/pytorch_24.09.sif python torch_bm.py &
wait
In this script, we will run a PyTorch script (located at /ceph/course/claaudia/docs/torch_bm.py) using the PyTorch container /ceph/container/pytorch/pytorch_24.09.sif. You can modify this to run any script or container you need. Key parameters to pay attention to are:
#SBATCH --time=00:01:00: This sets the time limit for your job. If your job exceeds this limit, it will be cancelled. If not specified, the default time limit is 12 hours.
#SBATCH --signal=B:SIGTERM@30: This tells Slurm to send a SIGTERM signal 30 seconds before the job reaches the time limit, giving the job time to handle termination gracefully.
max_restarts=4: This defines the maximum number of times your job will be automatically requeued if it gets cancelled. In this example, the job will be requeued up to four times before it is finally terminated.