Using Grafana for dashboarding and alerting with Spell model servers

Spell model servers are equipped with a Grafana dashboard frontend and a Prometheus database in the backend, a combination of tools that allows for sophisticated instrumentation and monitoring of your models' performance over time.

In my previous blog post, "Model observability in Spell model servers with Prometheus and Grafana", I introduced how these two tools work in general, and how to make them work with your Spell model servers in particular.

In this blog post we will build on that knowledge by writing some queries for a real model, dashboarding them using Grafana, and setting up Grafana alerts on this data to detect model drift. This lets your model engineers keep tabs on the models they have deployed without having to buy or build additional third-party monitoring solutions!

Note that this article assumes familiarity with Spell model servers, Prometheus, and Grafana.

For readers unfamiliar with these tools, I strongly suggest reading my previous article, "Model observability in Spell model servers with Prometheus and Grafana", before this one.

Setup

Before we can build a model server dashboard we must first have a model server to dashboard. For the purposes of this article, I built a simple instrumented serving script for our spellml/cnn-cifar10 demo model. This is a PyTorch model that takes images as input and attempts to classify the image into one of the 10 classes included in the CIFAR10 dataset. Here's the code:

# serve_instrumented.py
import torch
from torch import nn
import torchvision

import numpy as np
import base64
from PIL import Image
import io
import os

from spell.serving import BasePredictor
import spell.serving.metrics as m

class CIFAR10Model(nn.Module):
    def __init__(
        self,
        conv1_filters=32, conv1_dropout=0.25,
        conv2_filters=64, conv2_dropout=0.25,
        dense_layer=512, dense_dropout=0.5
    ):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(*[
            nn.Conv2d(3, conv1_filters, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(conv1_filters, conv2_filters, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(conv1_dropout)
        ])
        self.cnn_block_2 = nn.Sequential(*[
            nn.Conv2d(conv2_filters, conv2_filters, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(conv2_filters, conv2_filters, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(conv2_dropout)
        ])
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(*[
            nn.Linear(conv2_filters * 8 * 8, dense_layer),
            nn.ReLU(),
            nn.Dropout(dense_dropout),
            nn.Linear(dense_layer, 10)
        ])

    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.cnn_block_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X

class Predictor(BasePredictor):
    def __init__(self):
        self.clf = CIFAR10Model()
        self.clf.load_state_dict(torch.load("/model/checkpoints/model_final.pth", map_location="cpu"))
        self.clf.eval()
        self.transform_test = torchvision.transforms.Compose([
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
        self.labels = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
        self.pid = str(os.getpid())
        self.inference_time_hist = m.prometheus.Histogram(
            'inference_time', 'Model inference time',
            labelnames=['pid'], labelvalues=[self.pid],
        )
        self.prediction_confidence_hists = []
        self.prediction_counters = []
        for label in self.labels:
            # one per-class confidence histogram and prediction counter,
            # distinguished by the 'label' label value
            self.prediction_confidence_hists.append(
                m.prometheus.Histogram(
                    'prediction_confidence',
                    'Model (self-reported) prediction confidence',
                    buckets=[round(0.1 * x, 1) for x in range(0, 11)],
                    labelnames=['pid', 'label'], labelvalues=[self.pid, label],
                )
            )
            self.prediction_counters.append(
                m.prometheus.Counter(
                    'prediction_value',
                    'Predicted value',
                    labelnames=['pid', 'label'],
                    labelvalues=[self.pid, label]
                )
            )

    def predict(self, payload):
        img = base64.b64decode(payload['image'])
        img = Image.open(io.BytesIO(img), formats=[payload['format']])
        img_tensor = self.transform_test(img)
        img_tensor_batch = img_tensor[np.newaxis]

        # log model inference time
        with self.inference_time_hist.time():
            scores = self.clf(img_tensor_batch)

        # convert the raw logits to class probabilities so that the reported
        # confidence falls within the histogram's [0, 1] buckets
        probs = torch.softmax(scores, dim=1)
        class_match_idx = probs.argmax().item()
        prediction_confidence = probs.max().item()

        # log prediction counter and prediction confidence
        self.prediction_confidence_hists[class_match_idx].observe(prediction_confidence)
        self.prediction_counters[class_match_idx].inc()

        class_match = self.labels[class_match_idx]

        return {'class': class_match}

This sample code uses the Prometheus client library to log a model inference time histogram, plus one counter (for the total number of predictions rendered) and one histogram (for the prediction confidence) per CIFAR10 class (so ten of each).
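
To see where the metric names used by the dashboard queries later in this post come from, here is a minimal standalone sketch using the upstream prometheus_client library directly (which is what spell.serving.metrics exposes as m.prometheus). The process ID and observed values are purely illustrative:

# metrics_sketch.py
from prometheus_client import Counter, Histogram, CollectorRegistry, generate_latest

# a throwaway registry so this sketch doesn't touch the global default one
registry = CollectorRegistry()

# the same label scheme as the serving script above
prediction_counter = Counter(
    'prediction_value', 'Predicted value',
    labelnames=['pid', 'label'], registry=registry
)
inference_time_hist = Histogram(
    'inference_time', 'Model inference time',
    labelnames=['pid'], registry=registry
)

prediction_counter.labels(pid='1234', label='Ship').inc()
inference_time_hist.labels(pid='1234').observe(0.004)

# prints the Prometheus exposition format; note the generated
# prediction_value_total, inference_time_bucket, inference_time_count,
# and inference_time_sum series that the dashboard queries below refer to
print(generate_latest(registry).decode('utf8'))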

I simulated traffic against this model server endpoint using the following test script (watch -n 1 python test_model_server.py):

# test_model_server.py
from PIL import Image
from io import BytesIO
import requests
import base64
import os
import random

# pick a random image from the local directory
images = [file for file in os.listdir() if file.endswith(".jpg")]
img_fp = random.choice(images)

# parse it into the b64 encoded format understood by the model server
img = Image.open(img_fp)
img = img.convert("RGB")
buf = BytesIO()
img.save(buf, format="JPEG")
img_str = base64.b64encode(buf.getvalue())

# make the prediction request
resp = requests.post(
    "https://dev-aws-east-1.spell-external.dev.spell.services/spell-external/cifar10/predict",
    headers={"Content-Type": "application/json"},
    json={
        "image": img_str.decode("utf8"),
        "format": "JPEG"
    }
)

# print the result
print(resp.json())

Dashboarding

Dashboards are Grafana's "killer app". You can access your existing dashboards, or create new ones, by clicking on the "Manage" item under the "Dashboards" button on the left navbar. Readers following along will notice that Spell also populates a default dashboard for you; we'll get back to that in a bit. Clicking through that menu should take you to the dashboard view/create menu:

Click on the "New Dashboard" button on this page to drop into Grafana's dashboard builder interface. A Grafana dashboard consists of a number of individual elements called panels. Open the panel editor interface by clicking on the "Add new panel" button on the initial default-empty panel that Grafana creates for you:

There are a lot of bells and whistles on this page, but the editing experience is (for the most part) relatively straightforward. We'll start by plugging in a query counting the total number of predictions (using the bottom panel), and adjusting the time-scale to "last 30 minutes" instead of the default "last 6 hours" (using the top right menu). Then, set a "Panel title" and "Description" and hit "Apply". Here is the finished product:

And here's what it looks like in situ on the dashboard page:

We can immediately tell that the model is predicting one class (Ship) a lot more often than the others. Good to know!

You can continue to add panels using the "Add panel" button on the menu bar at the top right. The best way to learn how to build dashboards in Grafana is honestly just to play around with it for a bit.

One important feature to be aware of is the fact that Grafana supports different visualization types. These are under the "Visualization" option on the right sidebar in the panel editing interface:

As you can see, the default visualization (and the one we just created) is of type "Graph". But Grafana supports a broad variety of other types that may be more appropriate for the specific data in question.

Here's the complete dashboard:

With the following queries:

# Current Predictions Per Second Per Class
sum by (label) (rate(prediction_value_total{service="model-serving-10494"}[1m]))
# Predictions Per Server Process
sum by (pid) (prediction_value_total{service="model-serving-10494"})
# Predictions Made
sum by (pod) (inference_time_count{service="model-serving-10494"})
# Predictions By Label
sum by (label) (prediction_value_total)
# P90 Model Inference Time
histogram_quantile(0.9, sum by (job, le) (rate(inference_time_bucket{job="model-serving-10494"}[24h])))
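
If you want to iterate on PromQL like this before wiring it into a panel, one option is to run it as an instant query against the Prometheus HTTP API. The sketch below is only that, a sketch: it assumes you have network access to the cluster's Prometheus instance (for example via a kubectl port-forward), and the URL is a placeholder:

# query_check.py
import requests

PROMETHEUS_URL = "http://localhost:9090"  # placeholder, e.g. a local port-forward
QUERY = 'sum by (label) (rate(prediction_value_total{service="model-serving-10494"}[1m]))'

resp = requests.get(f"{PROMETHEUS_URL}/api/v1/query", params={"query": QUERY})
resp.raise_for_status()

# an instant query returns a vector: one {labels, [timestamp, value]} entry per series
for series in resp.json()["data"]["result"]:
    print(series["metric"].get("label"), series["value"][1])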

Alerting

There are technically three alerting mechanisms that can be used with the Prometheus and Grafana stack. The first is Prometheus alerts, which flow through a separate Prometheus-project-administered service called Prometheus Alertmanager. The second is Grafana legacy alerts, a service that has been packaged into Grafana for many years. The last option is Grafana 8 alerts, a brand-new, ground-up rewrite of Grafana's alerting service that landed in the June 2021 Grafana 8.0 release.

For users just getting started with model server alerting on Spell, we recommend Grafana legacy alerts. The Prometheus Alertmanager is powerful and expressive, but requires hand-editing config files, which is not a very ergonomic process on Spell (this may change in the future). As for new-style alerts, Spell currently ships Grafana 7.1.0 on our clusters; although this will change in the future, as of the time of writing new-style Grafana alerts are still explicitly marked "alpha" in the Grafana documentation and release notes.

In order to send an alert, Grafana needs to have a place to send it to. This is configured using the "Notification channels" tab on the "Alerting" page:

For the purposes of demonstration let's set up a Slack alert. This notification channel type uses Slack's incoming webhooks feature; see the guide Create Grafana Alert To Slack for how this is set up on the Slack side.
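
Before pointing Grafana at the webhook, it's worth sanity-checking it directly. A Slack incoming webhook accepts a simple JSON payload; the URL below is a placeholder for the one Slack generates for you:

# slack_webhook_check.py
import requests

# placeholder: paste in the incoming webhook URL generated by Slack
WEBHOOK_URL = "https://hooks.slack.com/services/T000/B000/XXXX"

resp = requests.post(WEBHOOK_URL, json={"text": "Test message from Grafana setup"})
print(resp.status_code, resp.text)  # Slack responds with 200 and "ok" on success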

Alerts are configured on the "Alert" tab in the panel editor. Note that only the "Graph" visualization type supports alerts; other visualization types will not have this tab available in their edit menu! Here is what that interface looks like:

Alerts are defined using a SQL-like syntax, by aggregating event values over a time interval within the panel's data stream. For demo purposes I've configured an alert that will fire whenever the maximum of the last five minutes of data in the P90 Model Inference Time panel is over 0.0042. This is a trivial alert, as the entire observed history of this particular model is above this value, but it serves to demonstrate how alerts work in Grafana.
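
To make the threshold concrete, here is a rough sketch of the same condition evaluated by hand: take the last five minutes of the P90 query's output from the Prometheus HTTP API and check whether its maximum exceeds 0.0042. As before, the Prometheus URL is a placeholder:

# alert_condition_check.py
import time
import requests

PROMETHEUS_URL = "http://localhost:9090"  # placeholder, e.g. a local port-forward
P90_QUERY = (
    'histogram_quantile(0.9, sum by (job, le) '
    '(rate(inference_time_bucket{job="model-serving-10494"}[24h])))'
)
THRESHOLD = 0.0042

end = time.time()
resp = requests.get(
    f"{PROMETHEUS_URL}/api/v1/query_range",
    params={"query": P90_QUERY, "start": end - 5 * 60, "end": end, "step": "15s"},
)
resp.raise_for_status()

# max() over the five-minute window, which is what the Grafana rule evaluates
values = [
    float(value)
    for series in resp.json()["data"]["result"]
    for _, value in series["values"]
]
print("would fire:", max(values, default=0.0) > THRESHOLD)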

After setting up the Slack integration, saving this alert, and waiting a few minutes for the Grafana alert to clear its pending period, checking the test Slack channel I created shows that the alert has landed!

Persistence

Note that, as of the time of writing, dashboard definitions are not persisted across cluster updates (executions of the spell kube-cluster update command) or node failures (hard failures of the cluster's cloud machines). We strongly recommend backing up your dashboard definitions using Grafana's export/import tool, so that you can restore them if needed.
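
If you'd like to script that backup, Grafana's HTTP API returns the same dashboard JSON that the export tool produces. The sketch below assumes a Grafana API key and the URL of your cluster's Grafana instance, both placeholders here:

# backup_dashboards.py
import json
import requests

GRAFANA_URL = "https://grafana.example.com"  # placeholder: your Grafana instance's URL
API_KEY = "REPLACE_ME"                       # placeholder: a Grafana API key
headers = {"Authorization": f"Bearer {API_KEY}"}

# list every dashboard, then fetch and save each one's JSON definition
dashboards = requests.get(
    f"{GRAFANA_URL}/api/search", params={"type": "dash-db"}, headers=headers
).json()

for item in dashboards:
    payload = requests.get(
        f"{GRAFANA_URL}/api/dashboards/uid/{item['uid']}", headers=headers
    ).json()
    with open(f"{item['uid']}.json", "w") as f:
        json.dump(payload["dashboard"], f, indent=2)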

That concludes this demo. Happy training!
