Making model training scripts robust to spot interruptions

The secret to cost-effective machine learning model training on the cloud is spot instances. Spot instances are a steeply discounted alternative to the standard on-demand instances on services like AWS EC2 and GCP GCE. The catch is that spot instances may be interrupted at any time: taken away from you and given to a customer willing to pay the full, non-interruptible on-demand price.

In exchange for this added risk, cloud providers discount spot instance costs substantially, typically by around 66%.

This makes spot instances ideal for expensive compute workloads that are robust to failure. Machine learning training jobs (especially deep learning ones) definitely fit the bill when it comes to “expensive”. And by incorporating model checkpointing into your training scripts, you can easily make them failure-tolerant as well: just restart your training run from the most recent saved checkpoint whenever an interruption occurs.
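The basic pattern is simple: periodically write your model weights to disk, and at startup reload the most recent save instead of starting from scratch. Here is a minimal, self-contained PyTorch sketch of the idea (the toy model, file name, and save-every-epoch cadence are illustrative placeholders, not taken from the demos discussed below):

import os

import torch
from torch import nn

# Stand-in model and training step; swap in your own.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def train_one_epoch():
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

CHECKPOINT = 'latest_net.pth'  # illustrative path
NUM_EPOCHS = 50

start_epoch = 0
if os.path.exists(CHECKPOINT):
    # A previous (possibly interrupted) run left a checkpoint behind: pick up from there.
    state = torch.load(CHECKPOINT)
    model.load_state_dict(state['weights'])
    start_epoch = state['epoch'] + 1

for epoch in range(start_epoch, NUM_EPOCHS):
    train_one_epoch()
    torch.save({'weights': model.state_dict(), 'epoch': epoch}, CHECKPOINT)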

At Spell, we recommend all our customers use spot instances and reentrant training scripts whenever possible. In a previous blog post, “Reduce cloud GPU model training costs by 66% using spot instances”, I go over the economic argument for using spot instances for model training, along with some additional considerations, including benchmarks showing savings of up to $200 on a large cloud training job (24 hours on a V100x4). For anyone unfamiliar with spot instances, I recommend giving that article a read!

In this article, I will introduce a new feature of the Spell platform that lets you go even further with spot instances: auto-resumption.

To follow along with an interactive tutorial in code, check out the reentrancy demo in our spellml/examples repository.

Getting the most from spot compute — reentrancy

The first step to using spot instances for model training is adding what we call reentrancy to your training scripts. A reentrant training script is one which saves model checkpoints as it goes along, and can be parameterized to restart training from an existing checkpoint if needed.

This is already considered a best practice in the machine learning community, and making your training scripts reentrant is super-easy. Here’s a copy of the relevant code lines from our reentrancy demo:

NUM_EPOCHS = 50

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--from-checkpoint', type=str, dest='checkpoint', default='')
args = parser.parse_args()
# ...
if args.checkpoint:
    # Resuming: skip the epochs the checkpoint already covers and restore its weights.
    first_remaining_epoch = int(args.checkpoint.split('_')[0]) + 1
    EPOCHS = range(first_remaining_epoch, NUM_EPOCHS)
    model.load_state_dict(torch.load(f'/spell/checkpoints/{args.checkpoint}'))
else:
    # Fresh run: train every epoch from scratch.
    EPOCHS = range(NUM_EPOCHS)
for epoch in EPOCHS:
    train()
    if epoch % 5 == 0:
        # Write a checkpoint every five epochs; Spell backs these up to SpellFS.
        torch.save(model.state_dict(), f'/spell/checkpoints/{epoch}_net.pth')

This training script saves a checkpoint file to disk every five epochs. Spell automatically backs these files up to our virtual filesystem, SpellFS, even if the run is interrupted, allowing you to easily recover your training progress at a later time:

$ spell run \
    --machine-type t4 \
    --github-url https://github.com/spellrun/spell-examples.git \
    "python spot/train_reentrant.py --from-checkpoint '20_net.pth'"

The spot instances demo Jupyter notebook in our examples repository has some visual examples showing this in action.

Getting the most from spot compute — resumability

We implemented this feature in April and it’s already proven to be quite popular with our users. In August we added a new feature that takes this one step further: the --auto-resume flag to spell run. Quoting from our docs:

If your script is written such that it idempotently resumes wherever it left off given a prior run’s disk state, then you can go one step further using Spell’s “Auto Resume” feature. […] When a run is interrupted by the respective cloud service and auto resume is enabled, Spell will create a new run with identical parameters, restore the interrupted run’s saved disk state, and queue it up. When a machine becomes available again (usually when the demand lowers), the resumed run will execute, continuing the computation of the interrupted run.

This is a super powerful feature: subject to a little up-front work on your part, it guarantees that your runs on spot instances actually finish training, without requiring any further manual intervention. This means that you can use spot instances everywhere, even for infrastructure-critical runs, potentially making your model training jobs substantially cheaper across the board.

We call training scripts that support this feature resumable. Here’s a PyTorch code snippet from our resumability demo script showing how you might implement this:

import argparse
import os
import re

import torch

NUM_EPOCHS = 50

parser = argparse.ArgumentParser()
parser.add_argument('--resume', action='store_true')
args = parser.parse_args()
# ...
if args.resume:
    if not os.path.exists("/spell/checkpoints/") or len(os.listdir("/spell/checkpoints/")) == 0:
        # No checkpoints on disk yet: train from scratch.
        EPOCHS = range(NUM_EPOCHS)
    else:
        # Find the checkpoint with the highest epoch number and resume from it.
        checkpoint_epoch = max(
            [int(re.findall("[0-9]{1,2}", fp)[0]) for fp in os.listdir("/spell/checkpoints/")]
        )
        model.load_state_dict(torch.load(f'/spell/checkpoints/{checkpoint_epoch}_net.pth'))
        first_remaining_epoch = checkpoint_epoch + 1
        EPOCHS = range(first_remaining_epoch, NUM_EPOCHS)
else:
    EPOCHS = range(NUM_EPOCHS)
for epoch in EPOCHS:
    train()
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'/spell/checkpoints/{epoch}_net.pth')

This training script writes model checkpoint files to the /spell/checkpoints/ path incrementally. If the --resume flag is specified and the /spell/checkpoints/ directory already exists and contains checkpoints, the script uses a regex to find the checkpoint file with the largest epoch number and initializes the model from it.
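Note that, to keep the demo simple, only the model weights are checkpointed. For resumption to be truly idempotent you would typically also save optimizer (and any LR scheduler) state alongside the weights. Here is a hedged sketch of what that might look like, reusing the demo’s /spell/checkpoints/ layout (the dict keys and helper names are illustrative, not part of the demo script):

import os
import re

import torch

CHECKPOINT_DIR = '/spell/checkpoints/'

def save_checkpoint(model, optimizer, epoch):
    # Bundle everything needed to resume exactly where training left off.
    torch.save(
        {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch},
        os.path.join(CHECKPOINT_DIR, f'{epoch}_net.pth'),
    )

def load_latest_checkpoint(model, optimizer):
    # Returns the first epoch still to run (0 if no checkpoints exist yet).
    if not os.path.exists(CHECKPOINT_DIR):
        return 0
    epochs = []
    for fp in os.listdir(CHECKPOINT_DIR):
        match = re.match(r'[0-9]+', fp)
        if match:
            epochs.append(int(match.group()))
    if not epochs:
        return 0
    checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, f'{max(epochs)}_net.pth'))
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'] + 1

The training loop would then start with first_epoch = load_latest_checkpoint(model, optimizer) and call save_checkpoint(model, optimizer, epoch) every few epochs, much as the demo does with plain state dicts.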

Spell guarantees that this script will complete, even if run on a spot instance. Here’s how it works. First, we submit this run to Spell with the --auto-resume flag set:

$ spell run --machine-type v100-spot \
    --github-url 'https://github.com/spellrun/spell-examples.git' \
    --mount uploads/segmented-bob-ross-images:/mnt/segmented-bob-ross-images \
    --auto-resume \
    "python spot/train_resumable.py --resume"

After letting the job run for a few minutes, I manually terminated the backing EC2 instance in the AWS web console. Spell detects that the machine has gone down and kicks into backup save mode: it spins up a new CPU instance, attaches the terminated machine’s disk, and backs the run’s files up to our virtual filesystem, SpellFS. The run logs show this process in action:

Aug 12, 2020, 12:12:23: running: Unexpected error or interruption during run, attempting save 
✨ Run is saving 
Aug 12, 2020, 12:12:33: saving: Unable to successfully save on machine, attempting backup save on new machine instead
✨ Run is backup_saving
Aug 12, 2020, 12:12:33: backup_saving: Starting CPU to perform backup save on
Aug 12, 2020, 12:18:23: backup_saving: Mounting existing volume on backup save machine
Aug 12, 2020, 12:19:07: backup_saving: Scanning for modified or new files from the run
Aug 12, 2020, 12:19:07: backup_saving: Saving '/spell/0_net.pth
Aug 12, 2020, 12:19:07: backup_saving: Saving '/spell/5_net.pth

Backup saving occurs for all runs on spot instances, not just ones with auto-resumption enabled. However, the next step is new! Immediately after backup saving completes, Spell checks whether the --auto-resume flag was set and, if so, automatically schedules a new run with the same command-line arguments:

Aug 12, 2020, 12:19:12: backup_saving: Auto-resume enabled for run: creating new run to resume interrupted run 
Aug 12, 2020, 12:19:12: backup_saving: Successfully created new run 151 to resume interrupted run
Aug 12, 2020, 12:19:12: backup_saving: Use 'spell logs -f 151' to view and follow the logs of this new run

This new run receives a copy of its predecessor’s disk state as input. Our resumable training script checks the /spell/checkpoints/ directory, finds some checkpoints, loads the highest-numbered one, and resumes training from there:

Aug 12, 2020, 12:27:18: running: Finished epoch 6.
Aug 12, 2020, 12:27:27: running: Finished epoch 7.
Aug 12, 2020, 12:27:35: running: Finished epoch 8.
Aug 12, 2020, 12:27:43: running: Finished epoch 9.
Aug 12, 2020, 12:27:52: running: Finished epoch 10.
Aug 12, 2020, 12:27:52: running: Saved model to 10_net.pth.

This lets us complete our training process with no further effort on our end. 🔥

To try this code out yourself, check out the spot instance resources in our example repo on GitHub.
