Get Signal Before Timeout on Slurm¶
Slurm requires jobs to specify a maximum run duration. Jobs that exceed this
duration will get killed. This can be a problem in cases where you do not know
the exact run duration of your jobs in advance. Fortunately, Slurm can be
configured to send a signal to the job a bit before the timeout. This allows
the job to save a checkpoint with the current results and use cluster_utils’
exit_for_resume()
to terminate. cluster_utils will then
automatically restart the job, allowing it to load the previously saved checkpoint and
resume the computations.
Below is a small example of how to configure cluster_utils so that timeout-warning signals are sent. Please note that it is still up to you to catch that signal in your code and react accordingly; an example of how this can be done is shown below as well.
grid_search.toml:
optimization_procedure_name = "timeout_signal_test"
results_dir = "/tmp/slurm_timeout_signal_example"
generate_report = "when_finished"
script_relative_path = "examples/slurm_timeout_signal/main.py"
remove_jobs_dir = false
restarts = 1
[git_params]
branch = "master"
[environment_setup]
[cluster_requirements]
partition = "cpu-galvani"
request_cpus = 1
memory_in_mb = 1000
request_time = "00:04:00" # jobs get killed after 4 minutes
signal_seconds_to_timeout = 30  # send signal at least 30 seconds before timeout
[fixed_params]
[[hyperparam_list]]
param = "x"
values = [0, 1]
[[hyperparam_list]]
param = "y"
values = [0, 1]
The relevant part here is the definition of signal_seconds_to_timeout
in the
[cluster_requirements]
section. When it is set, Slurm will be configured to send a USR1
signal to warn about the approaching timeout. The value is the approximate time in seconds before the timeout at which the signal is sent. Make sure to choose a value large enough for your job to actually save its intermediate data before the timeout is reached.
The main script of the job can then look something like this:
main.py:
"""Minimal example on how to use the timeout signal sent by Slurm."""
import json
import pathlib
import signal
import sys
import time
import cluster_utils
received_timeout_signal = False
def timeout_signal_handler(sig, frame):
global received_timeout_signal
print("Received timeout signal")
# simply set a flag here, which is checked in the training loop
received_timeout_signal = True
def main() -> int:
"""Main function."""
params = cluster_utils.initialize_job()
n_training_iterations = 60
start_iteration = 0
# register signal handler for the USR1 signal
signal.signal(signal.SIGUSR1, timeout_signal_handler)
checkpoint_file = pathlib.Path(params.working_dir) / "checkpoint.json"
# load existing checkpoint
if checkpoint_file.exists():
print("Load checkpoint")
with open(checkpoint_file) as f:
chkpnt = json.load(f)
start_iteration = chkpnt["iteration"]
for i in range(start_iteration, n_training_iterations):
print(f"Training iteration {i} with x = {params.x}, y = {params.y}")
time.sleep(10) # dummy sleep instead of an actual training
if received_timeout_signal:
print("Save checkpoint and exit for resume.")
# save checkpoint
with open(checkpoint_file, "w") as f:
json.dump({"iteration": i + 1}, f)
# exit and ask cluster_utils to restart this job
cluster_utils.exit_for_resume()
# just return some dummy metric value here
metrics = {"result": params.x + params.y, "n_iterations": i}
cluster_utils.finalize_job(metrics, params)
return 0
if __name__ == "__main__":
sys.exit(main())
A signal handler is registered with signal.signal(signal.SIGUSR1,
timeout_signal_handler)
. This means the given function is called whenever
the process receives a USR1
signal. What this function does depends on the
actual application. In this example, it simply sets a flag which is checked
in each iteration of the dummy training loop. If the flag is set, a
checkpoint is saved and the script terminates with
exit_for_resume()
.
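To check this handler logic without waiting for a real Slurm timeout, you can send SIGUSR1 to your own process. Below is a minimal, standalone sketch (it does not use cluster_utils; os.kill here simulates the signal that Slurm would send):

```python
"""Sketch: trigger the USR1 handler locally to test the flag-based logic."""
import os
import signal

received_timeout_signal = False


def timeout_signal_handler(sig, frame):
    # same pattern as in main.py: just set a flag
    global received_timeout_signal
    received_timeout_signal = True


signal.signal(signal.SIGUSR1, timeout_signal_handler)

# simulate Slurm by sending USR1 to our own process
os.kill(os.getpid(), signal.SIGUSR1)

print("handler ran:", received_timeout_signal)
```

This only works on POSIX systems (SIGUSR1 does not exist on Windows), which is not a restriction in practice since Slurm jobs run on Linux.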
Note
This example is included in cluster_utils/examples/slurm_timeout_signal
and can be directly run from there.
Warning
In case you are using a wrapper script around your main.py (which can, for example, be needed for some environment setup inside containers), the signal will only be delivered to the wrapper script and is not automatically forwarded to the main.py process. In this case, you need to catch the signal in the wrapper script as well and send it on to the child process from there.
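One way to do this in a Python wrapper is to install a handler that passes the signal on to the child process. The sketch below is a hypothetical example, not part of cluster_utils; the command line is a placeholder that you would adapt to however you launch your job:

```python
"""Sketch of a wrapper script that forwards SIGUSR1 to its child process."""
import signal
import subprocess
import sys


def run_with_signal_forwarding(cmd):
    """Run cmd as a child process, forwarding SIGUSR1 to it.

    Returns the child's exit code, so that special exit codes (such as the
    one used by exit_for_resume()) are still visible to cluster_utils.
    """
    child = subprocess.Popen(cmd)

    def forward_usr1(sig, frame):
        # pass the timeout warning on to the child, where the actual
        # signal handler reacts to it
        child.send_signal(signal.SIGUSR1)

    signal.signal(signal.SIGUSR1, forward_usr1)

    return child.wait()


# e.g.: sys.exit(run_with_signal_forwarding([sys.executable, "main.py"]))
```

Note that the wrapper propagates the child's exit code; this matters because cluster_utils decides whether to restart the job based on how the process exits.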