Optimizing PyTorch models for fast CPU inference using Apache TVM

Apache TVM is a relatively new Apache project that promises big performance improvements for deep learning model inference. It belongs to a new category of technologies called model compilers: it takes a model written in a high-level framework like PyTorch or TensorFlow as input and produces a binary bundle optimized for running on a specific hardware platform as output.

In this blog post, we'll put TVM through its paces. We'll discuss the basic concepts behind how it works, then try installing it and benchmarking it on a test model running on Spell.

To follow along in code, check out the GitHub repo.

Basic concepts

TVM is, at its core, a model compiler.

If you are familiar with compiled programming languages, you already know that they are almost strictly faster than interpreted ones (think C versus Python). This is because the compilation step allows for optimizations, both in the high-level representation of the code (for example, loop unrolling) and in its low-level execution (for example, coercing operands to and from types natively understood by the hardware), that can make code execution faster by an order of magnitude.

The goal of model compilation is very similar: take a model written in an easy-to-use high-level framework, like PyTorch, and compile its computational graph into a binary object optimized to run on one specific hardware platform. TVM supports a pretty broad range of different target platforms — some expected, some exotic. The Getting Started page in the TVM documentation used to feature a diagram of the supported backends.

The range of platforms that TVM supports is definitely a strength of the project. For example, the model quantization API in PyTorch only supports two target platforms: x86 and ARM. Using TVM, you can compile models that run on native macOS, NVIDIA CUDA—or even, via WASM, the web browser.

The process of producing the optimized model binary begins with translating the computational graph into TVM's internal high-level graph format, Relay. Relay is a usable high-level model API—you can even construct new models from scratch in it—but it mostly serves as a unified starting point for further model optimization.
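To make that concrete, here's a minimal sketch (not from this post's repo) of building a tiny Relay function by hand; the operator choice and shapes are arbitrary:

import tvm
from tvm import relay

# Declare symbolic inputs: a batch of four feature vectors and a weight matrix.
x = relay.var("x", shape=(4, 16), dtype="float32")
w = relay.var("w", shape=(8, 16), dtype="float32")

# Compose operators into a graph: a dense (matmul) layer followed by a ReLU.
y = relay.nn.relu(relay.nn.dense(x, w))

# Wrap the expression in a function and a module -- the same kind of IR that
# the PyTorch frontend produces when it imports a model.
func = relay.Function([x, w], y)
mod = tvm.IRModule.from_expr(func)
print(mod)  # prints the Relay IR as human-readable text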

TVM applies some high-level optimizations to the graph at the Relay level, then lowers it into a low-level IR called "Tensor Expressions" (TE) through a process it calls the "Relay Fusion Pass". At the TE level, the computational graph is split into a set of subgraphs that the TVM engine determines are good optimization targets.

The last and most important step in the TVM optimization process is tuning. During the tuning step, TVM makes guesses about the order of operations for computational tasks in the graph (the "schedule") that achieves the highest performance (fastest inference time) on the chosen hardware platform.

Interestingly enough, this is not a deterministic problem—there are simply too many valid orderings, and too much non-determinism in how fast any particular operation will execute on any given hardware platform, given every other computational process it is running. Instead, TVM constructs a search space over candidate schedules, then runs an XGBoost model with a custom loss function over that space to find the best one.
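To give a flavor of what a "schedule" is, here is a tiny hand-written Tensor Expression example (again, not from the repo). The compute rule declares what to calculate; the schedule primitives declare how to order, split, and parallelize the loops. Tuning automates the search over choices like these:

import tvm
from tvm import te

# "What" to compute: elementwise C = A + B over vectors of length n.
n = 1024
A = te.placeholder((n,), name="A", dtype="float32")
B = te.placeholder((n,), name="B", dtype="float32")
C = te.compute((n,), lambda i: A[i] + B[i], name="C")

# "How" to compute it: the schedule. These primitives restructure the loop
# nest without changing the result.
s = te.create_schedule(C.op)
outer, inner = s[C].split(C.op.axis[0], factor=64)
s[C].parallel(outer)
s[C].vectorize(inner)

# Inspect the loop structure this schedule produces.
print(tvm.lower(s, [A, B, C], simple_mode=True))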

If this seems extremely complicated, that's because it is. Luckily you don't have to know any of the details of how TVM works to use it, as its high-level API takes care of most of the details for you.

Installing TVM

To get a sense of the performance benefits of TVM, I compiled a simple PyTorch MobileNet model trained on CIFAR10 and tested its inference time before and after TVM compilation. The rest of this article will walk through this code. You can follow along with the code in the GitHub repository.

However, before you can use TVM you first have to install it. Unfortunately this is not at all an easy process. TVM does not currently distribute any wheels, and instead the documentation walks through installing TVM from source.

For the purposes of testing I am using a c5.4xlarge CPU instance on AWS via Spell. This is an x86 machine, so we'll need to install both TVM and a recent-enough version of the LLVM toolchain. Compiling TVM from source takes about 10 minutes, so this is a perfect use case for Spell's custom Docker image support—we can compile TVM and all of its dependencies into a Docker image once, then reuse that image for all our runs thereafter.

Here is the Dockerfile I used:

FROM ubuntu:18.04
WORKDIR /spell

# Conda install part
RUN apt-get update && \
    apt-get install -y wget git && rm -rf /var/lib/apt/lists/*
ENV CONDA_HOME=/root/anaconda/
RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-py37_4.8.3-Linux-x86_64.sh -fbp $CONDA_HOME \
    && rm -f Miniconda3-py37_4.8.3-Linux-x86_64.sh
ENV PATH=/root/anaconda/bin:$PATH
# NOTE: Spell runs will fail if pip3 is not available at the command line.
# conda injects pip onto the path, but not pip3, so we create a symlink.
RUN ln /root/anaconda/bin/pip /root/anaconda/bin/pip3
# TVM install part
COPY environment.yml /tmp/environment.yml
RUN conda env create -n spell -f=/tmp/environment.yml
COPY scripts/install_tvm.sh /tmp/install_tvm.sh
RUN chmod +x /tmp/install_tvm.sh && /tmp/install_tvm.sh

This uses the following conda environment.yml:

name: spell
channels:
  - conda-forge
dependencies:
  - numpy
  - pandas
  - tornado
  - pip
  - pip:
    # NOTE(aleksey): because of AskUbuntu#1334667, we need an old version of
    # XGBoost, as recent versions are not compatible with our base image,
    # Ubuntu 18.04. XGBoost is required in this environment because TVM uses
    # it as its search space optimization algorithm in the tuning pass.
    - xgboost==1.1.0
    - torch==1.8.1
    - torchvision
    - cloudpickle
    - psutil
    - spell
    - kaggle
    - tokenizers
    - transformers
    # NOTE(aleksey): this dependency on pytest is probably accidental, as
    # it isn't documented. But without it, the TVM Python package will not
    # import.
    - pytest

Here are the contents of install_tvm.sh. Note that TVM build-time variables are set in a config.cmake file, which I've manipulated here to point to the specific version of LLVM we're installing with apt-get:

#!/bin/bash
set -ex
# https://tvm.apache.org/docs/install/from_source.html#install-from-source
if [[ ! -d "/tmp/tvm" ]]; then
    git clone --recursive https://github.com/apache/tvm /tmp/tvm
fi
apt-get update && \
    apt-get install -y gcc libtinfo-dev zlib1g-dev \
        build-essential cmake libedit-dev libxml2-dev \
        llvm-6.0 \
        libgomp1  \
        zip unzip
if [[ ! -d "/tmp/tvm/build" ]]; then
    mkdir /tmp/tvm/build
fi
cp /tmp/tvm/cmake/config.cmake /tmp/tvm/build
mv /tmp/tvm/build/config.cmake /tmp/tvm/build/~config.cmake && \
    cat /tmp/tvm/build/~config.cmake | \
        # sed -E "s|set\(USE_CUDA OFF\)|set\(USE_CUDA ON\)|" | \
        sed -E "s|set\(USE_GRAPH_RUNTIME OFF\)|set\(USE_GRAPH_RUNTIME ON\)|" | \
        sed -E "s|set\(USE_GRAPH_RUNTIME_DEBUG OFF\)|set\(USE_GRAPH_RUNTIME_DEBUG ON\)|" | \
        sed -E "s|set\(USE_LLVM OFF\)|set\(USE_LLVM /usr/bin/llvm-config-6.0\)|" > \\
        /tmp/tvm/build/config.cmake
cd /tmp/tvm/build && cmake .. && make -j4
cd /tmp/tvm/python && /root/anaconda/envs/spell/bin/python setup.py install --user && cd ..

You can docker build on your local machine, use this Gist to build this image on an EC2 machine, or just reuse the public image I built for this demo to skip this compilation process altogether.
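Whichever route you take, a quick smoke test inside the resulting environment confirms that the build worked (the exact version strings will depend on the commit you built and the environment.yml above):

import tvm
import torch

print(tvm.__version__)    # e.g. "0.8.dev0" for a recent source build
print(torch.__version__)  # "1.8.1", pinned in the environment.yml above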

Compiling a model with TVM

With TVM installed we can move on to compiling our test model with it.

Note that TVM has both Python and CLI clients; I used the Python client for this project.

First things first, we need a trained model. In fact, not just any model will do—the relevant method, tvm.relay.frontend.from_pytorch (docs), takes a traced TorchScript module as input, and to get the most out of TVM on CPU we'll feed it a quantized model.

Quantization is the process of lowering operations in a model graph into a lower-precision representation (e.g. from fp32 to int8). This is a form of model performance optimization: the fewer bits an operand has, the faster it is to operate on. Quantization is a very involved technique, and is itself very new—at time of writing, its PyTorch implementation (the torch.quantization module) is still in beta. We've covered quantization in depth on this blog before, in the post "A developer-friendly guide to model quantization with PyTorch", so we'll omit those details here.
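As a quick illustration of the core idea (this snippet is not part of the model code), PyTorch can quantize a float tensor into int8 storage given a scale and zero point, and dequantize it back with a small rounding error:

import torch

x = torch.randn(4)

# Map fp32 values onto int8 buckets: q = round(x / scale) + zero_point.
xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)

print(x)                # original fp32 values
print(xq.int_repr())    # the underlying int8 storage
print(xq.dequantize())  # back to fp32, with a small rounding error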

From the code, here is the quantized model definition:

from collections import OrderedDict

import torch
import torch.nn as nn

def conv_bn(inp, oup, stride):
    return nn.Sequential(OrderedDict([
        ('q', torch.quantization.QuantStub()),
        ('conv2d', nn.Conv2d(inp, oup, 3, stride, 1, bias=False)),
        ('batchnorm2d', nn.BatchNorm2d(oup)),
        ('relu6', nn.ReLU6(inplace=True)),
        ('dq', torch.quantization.DeQuantStub())
    ]))

def conv_1x1_bn(inp, oup):
    return nn.Sequential(OrderedDict([
        ('q', torch.quantization.QuantStub()),
        ('conv2d', nn.Conv2d(inp, oup, 1, 1, 0, bias=False)),
        ('batchnorm2d', nn.BatchNorm2d(oup)),
        ('relu6', nn.ReLU6(inplace=True)),
        ('dq', torch.quantization.DeQuantStub())
    ]))

def make_divisible(x, divisible_by=8):
    import numpy as np
    return int(np.ceil(x * 1. / divisible_by) * divisible_by)

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(OrderedDict([
                ('q', torch.quantization.QuantStub()),
                # dw
                ('conv2d_1', nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False)),
                ('bnorm_2', nn.BatchNorm2d(hidden_dim)),
                ('relu6_3', nn.ReLU6(inplace=True)),
                # pw-linear
                ('conv2d_4', nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)),
                ('bnorm_5', nn.BatchNorm2d(oup)),
                ('dq', torch.quantization.DeQuantStub())
            ]))
        else:
            self.conv = nn.Sequential(OrderedDict([
                ('q', torch.quantization.QuantStub()),
                # pw
                ('conv2d_1', nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False)),
                ('bnorm_2', nn.BatchNorm2d(hidden_dim)),
                ('relu6_3', nn.ReLU6(inplace=True)),
                # dw
                ('conv2d_4', nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False)),
                ('bnorm_5', nn.BatchNorm2d(hidden_dim)),
                ('relu6_6', nn.ReLU6(inplace=True)),
                # pw-linear
                ('conv2d_7', nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)),
                ('bnorm_8', nn.BatchNorm2d(oup)),
                ('dq', torch.quantization.DeQuantStub())
            ]))

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

Here is the function which performs the quantization pass:

def prepare_model(model):
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model = torch.quantization.fuse_modules(
        model,
        [
            # NOTE(aleksey): 'features' is the attr containing the non-head layers.
            ['features.in_conv.conv2d', 'features.in_conv.batchnorm2d'],
            ['features.inv_conv_1.conv.conv2d_1', 'features.inv_conv_1.conv.bnorm_2'],
            ['features.inv_conv_1.conv.conv2d_4', 'features.inv_conv_1.conv.bnorm_5'],
            *[
                *[[f'features.inv_conv_{i}.conv.conv2d_1',
                   f'features.inv_conv_{i}.conv.bnorm_2'] for i in range(2, 18)],
                *[[f'features.inv_conv_{i}.conv.conv2d_4',
                   f'features.inv_conv_{i}.conv.bnorm_5'] for i in range(2, 18)],
                *[[f'features.inv_conv_{i}.conv.conv2d_7',
                   f'features.inv_conv_{i}.conv.bnorm_8'] for i in range(2, 18)]
            ]
        ]
    )
    model = torch.quantization.prepare_qat(model)
    return model
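
The remaining steps aren't shown in this excerpt: after preparing the model we train it (quantization-aware training), convert it into a true int8 model, and trace it, since TVM's PyTorch frontend consumes a traced TorchScript module. Roughly, assuming the model object and the X_ex example batch described below (the training loop itself is elided):

model = prepare_model(model)

# ... run the usual training loop here; the QAT observers collect statistics ...

# Swap the fake-quantization modules for real int8 kernels.
model.eval()
model = torch.quantization.convert(model)

# Trace the quantized model into TorchScript for TVM's PyTorch frontend.
traced_model = torch.jit.trace(model, X_ex)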

Once we've defined and trained our quantized model to convergence, we're ready to pass it through the TVM optimization engine. The first step in this process is converting the computational graph out of traced PyTorch and into Relay:

import tvm
from tvm.contrib import graph_executor
import tvm.relay as relay

TARGET = "llvm -mcpu=skylake-avx512"

def get_tvm_model(traced_model, X_ex):
    mod, params = relay.frontend.from_pytorch(
        traced_model, input_infos=[('input0', X_ex.shape)]
    )

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=TARGET, params=params)

    dev = tvm.device(TARGET, 0)
    module = graph_executor.GraphModule(lib["default"](dev))

    module.set_input("input0", X_ex)
    module.run()  # smoke test

    # mod and params are IR structs used downstream.
    # module is a Relay Python callable.
    return mod, params, module

The function begins by calling tvm.relay.frontend.from_pytorch. from_pytorch expects two things: the traced PyTorch module (traced_model here) and a list describing the names and shapes of the model inputs. In this code, X_ex is an example batch I sampled from the training loop's dataloader, so the input shape is derived from it, X_ex.shape.

Notice that the input has a name, input0. This parameter is required because Relay names its graph inputs, so TVM expects us to pick a name even though PyTorch has no such concept. The actual value doesn't matter, as long as you use the same name later when calling set_input.

The next call, relay.build, is what actually constructs the Relay computational graph. Its most important parameter is target: a string describing the hardware platform you are compiling for. It is important to make this string match your target platform as specifically as possible, though unfortunately I couldn't find a list of valid target strings anywhere in the documentation. I'm running this code on a c5.4xlarge instance on AWS using Spell, which is backed by a chip from the Intel Xeon Platinum 8000 series, hence the llvm -mcpu=skylake-avx512 target string.

lib and mod are handles to compiled native objects that aren't usable directly. Wrapping lib in a graph_executor.GraphModule gives it a Python-facing API, creating a module that we can call into directly from Python.
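For example, pulling a prediction back out of the compiled module looks roughly like this (a minimal sketch, reusing get_tvm_model from above):

mod, params, module = get_tvm_model(traced_model, X_ex)

# Feed a batch in by input name, execute the compiled graph, and copy the
# first output back out as a NumPy array.
module.set_input("input0", X_ex)
module.run()
out = module.get_output(0).numpy()  # .asnumpy() on older TVM releases
print(out.shape)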

The last and most important step is tuning:

from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner

def tune(mod, params, X_ex):
    number = 10
    repeat = 1
    min_repeat_ms = 0
    timeout = 10

    # create a TVM runner
    runner = autotvm.LocalRunner(
        number=number,
        repeat=repeat,
        timeout=timeout,
        min_repeat_ms=min_repeat_ms,
    )

    tuning_option = {
        "tuner": "xgb",
        "trials": 10,
        "early_stopping": 100,
        "measure_option": autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func="default"), runner=runner
        ),
        "tuning_records": "resnet-50-v2-autotuning.json",
    }

    tasks = autotvm.task.extract_from_program(
        mod["main"], target=TARGET, params=params
    )

    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
        tuner_obj = XGBTuner(task, loss_type="rank")
        tuner_obj.tune(
            n_trial=min(tuning_option["trials"], len(task.config_space)),
            early_stopping=tuning_option["early_stopping"],
            measure_option=tuning_option["measure_option"],
            callbacks=[
                autotvm.callback.progress_bar(
                    tuning_option["trials"], prefix=prefix
                ),
                autotvm.callback.log_to_file(tuning_option["tuning_records"]),
            ],
        )

    with autotvm.apply_history_best(tuning_option["tuning_records"]):
        with tvm.transform.PassContext(opt_level=3, config={}):
            lib = relay.build(mod, target=TARGET, params=params)

    dev = tvm.device(str(TARGET), 0)
    optimized_module = graph_executor.GraphModule(lib["default"](dev))

    optimized_module.set_input("input0", X_ex)
    optimized_module.run()  # dry run test

    return optimized_module

This code uses the XGBoost library to perform an optimization run over the Relay model, finding a schedule for this computational graph that's as close to optimal as feasible given the selected time constraints. There's a lot of boilerplate involved in this code, but you don't have to understand what every individual line does to get the gist of it.

Notice that, for the sake of time, we are running just 10 tuning trials. For production use cases, TVM's Getting Started with Python guide recommends 1,500 trials for CPUs and 3,000 or so for GPUs.

Benchmarking the resulting model

I ran a batch of data through two different versions of this model on CPU and measured the average inference time over several runs. The first was the base PyTorch model, with neither quantization nor compilation applied. The second was the fully optimized model: a MobileNet that had been quantized, compiled, and tuned using the code from the previous section. You can see the benchmark code here.
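The measurement itself is nothing fancy; here is a simplified sketch of it (the benchmark helper below is hypothetical, and model, X_ex, and optimized_module are the objects from the previous sections):

import time
import numpy as np
import torch

def benchmark(run_fn, n_iters=100, warmup=10):
    # Run the model a few times to warm up, then average wall-clock time.
    for _ in range(warmup):
        run_fn()
    timings = []
    for _ in range(n_iters):
        start = time.perf_counter()
        run_fn()
        timings.append(time.perf_counter() - start)
    return np.mean(timings)

with torch.no_grad():
    base_time = benchmark(lambda: model(X_ex))
tvm_time = benchmark(lambda: optimized_module.run())  # input was set in tune()
print(f"base: {base_time * 1000:.1f} ms, tvm: {tvm_time * 1000:.1f} ms")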

The results: inference times for the compiled version of the model were over 30 times faster than for the base model!

In fact, remarkably, the compiled model on CPU ran about as fast as the base model did on GPU (a g4dn.xlarge, i.e. an NVIDIA T4 instance). The CPU instance in question, a c5.4xlarge, costs $0.68/hour at time of writing; the GPU instance, a g4dn.xlarge, costs $0.526/hour. The performance gains from quantization and model compilation therefore make CPU and GPU serving roughly equally cost-effective, which is striking when you consider just how much slower the model was prior to optimization.

Note that not all of this performance boost can be attributed to TVM—some of it comes from the quantization that was applied to the model in PyTorch before the compilation steps. Again, I recommend reading my previous article on model quantization to learn more about this API.

Happy training! ✌️
