Get Signal Before Timeout on Slurm

Slurm requires jobs to specify a maximum run duration. Jobs that exceed this duration will get killed. This can be a problem in cases where you do not know the exact run duration of your jobs in advance. Fortunately, Slurm can be configured to send a signal to the job a bit before the timeout. This allows the job to save a checkpoint with the current results and use cluster_utils’ exit_for_resume() to terminate. cluster_utils will then automatically restart the job, allowing it to load the previously saved checkpoint and resume the computations.

Below is a small example of how to configure cluster_utils so that timeout-warning signals are sent. Note that it is still up to you to catch that signal in your code and react accordingly. An example of how this can be done is shown below as well.

grid_search.toml:

optimization_procedure_name = "timeout_signal_test"
results_dir = "/tmp/slurm_timeout_signal_example"
generate_report = "when_finished"
script_relative_path = "examples/slurm_timeout_signal/main.py"
remove_jobs_dir = false
restarts = 1

[git_params]
branch = "master"

[environment_setup]

[cluster_requirements]
partition = "cpu-galvani"
request_cpus = 1
memory_in_mb = 1000
request_time = "00:04:00"  # jobs get killed after 4 minutes
signal_seconds_to_timeout = 30  # send signal at least 30 seconds before timeout

[fixed_params]

[[hyperparam_list]]
param = "x"
values = [0, 1]

[[hyperparam_list]]
param = "y"
values = [0, 1]

The relevant part here is the definition of signal_seconds_to_timeout in the [cluster_requirements] section. When it is set, Slurm is configured to send a USR1 signal to warn about the approaching timeout. The value is the approximate time in seconds before the timeout at which the signal is sent. Choose a value large enough for your job to actually save its intermediate data before the timeout is reached.
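As a rough guide for picking the value: a job that only checks for the signal once per loop iteration may, in the worst case, receive the signal right after such a check. It then needs one full iteration plus the time to write the checkpoint. The sketch below turns this into a small calculation. The function name, the assumed 2 seconds for writing a checkpoint, and the safety factor are illustrative assumptions and not part of cluster_utils.

```python
import math


def min_signal_seconds(
    iteration_seconds: float,
    checkpoint_seconds: float,
    safety_factor: float = 1.5,
) -> int:
    """Estimate a lower bound for signal_seconds_to_timeout (hypothetical helper).

    Worst case: the signal arrives just after the flag was checked, so one
    full iteration passes before the checkpoint can be written.
    """
    worst_case = iteration_seconds + checkpoint_seconds
    return math.ceil(worst_case * safety_factor)


# 10 s per iteration as in the example script; 2 s checkpoint time is assumed.
print(min_signal_seconds(10, 2))
```

With these assumed numbers the estimate stays below the 30 seconds configured above, so the example leaves some headroom.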

The main script of the job can then look something like this:

main.py:

"""Minimal example on how to use the timeout signal sent by Slurm."""

import json
import pathlib
import signal
import sys
import time

import cluster_utils

received_timeout_signal = False


def timeout_signal_handler(sig, frame):
    global received_timeout_signal

    print("Received timeout signal")
    # simply set a flag here, which is checked in the training loop
    received_timeout_signal = True


def main() -> int:
    """Main function."""
    params = cluster_utils.initialize_job()

    n_training_iterations = 60
    start_iteration = 0

    # register signal handler for the USR1 signal
    signal.signal(signal.SIGUSR1, timeout_signal_handler)

    checkpoint_file = pathlib.Path(params.working_dir) / "checkpoint.json"

    # load existing checkpoint
    if checkpoint_file.exists():
        print("Load checkpoint")
        with open(checkpoint_file) as f:
            chkpnt = json.load(f)
            start_iteration = chkpnt["iteration"]

    for i in range(start_iteration, n_training_iterations):
        print(f"Training iteration {i} with x = {params.x}, y = {params.y}")
        time.sleep(10)  # dummy sleep instead of an actual training

        if received_timeout_signal:
            print("Save checkpoint and exit for resume.")
            # save checkpoint
            with open(checkpoint_file, "w") as f:
                json.dump({"iteration": i + 1}, f)

            # exit and ask cluster_utils to restart this job
            cluster_utils.exit_for_resume()

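    # just return some dummy metric value here (use the loop bound, because
    # `i` is undefined if a resumed job has no iterations left to run)
    metrics = {"result": params.x + params.y, "n_iterations": n_training_iterations}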
    cluster_utils.finalize_job(metrics, params)

    return 0


if __name__ == "__main__":
    sys.exit(main())

A signal handler is registered with signal.signal(signal.SIGUSR1, timeout_signal_handler), so the given function is called when the process receives a USR1 signal. What this function does depends on the actual application. In the example, it simply sets a flag that is checked in each iteration of the dummy training loop. If the flag is set, a checkpoint is saved and the script terminates with exit_for_resume().

Note

This example is included in cluster_utils/examples/slurm_timeout_signal and can be directly run from there.

Warning

If you are using a wrapper script around your main.py (for example, for environment setup inside containers), the signal is only sent to the wrapper script and is not automatically forwarded to the main.py process. In this case, you need to catch the signal in the wrapper script as well and send it to the child process from there.
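If the wrapper is itself written in Python, forwarding could look roughly like the sketch below. It starts the child with subprocess.Popen, registers a USR1 handler that passes the signal on with Popen.send_signal, and returns the child's exit code so that the caller sees how the child terminated. The function name run_with_signal_forwarding is hypothetical, and this is only one possible approach under the assumption of a Unix system. A shell wrapper would need an equivalent trap instead.

```python
import signal
import subprocess
import sys


def run_with_signal_forwarding(cmd: list[str]) -> int:
    """Run `cmd` as a child process and forward SIGUSR1 to it (hypothetical helper)."""
    child = subprocess.Popen(cmd)

    def forward(sig, frame):
        # pass the timeout warning on to the actual job process
        child.send_signal(sig)

    signal.signal(signal.SIGUSR1, forward)

    # wait() resumes automatically after the handler has run
    return child.wait()


# Possible use at the end of a wrapper script (after the environment setup):
#   sys.exit(run_with_signal_forwarding([sys.executable, "main.py", *sys.argv[1:]]))
```

Passing sys.argv[1:] through in the usage line keeps the arguments given to the wrapper available to main.py.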