Churn Prediction with LightGBM and Artificial Neural Networks

Customer is the king. To sustain and excel at modern businesses, the importance of customer satisfaction cannot be emphasized enough. It should not come off as a surprise that customer satisfaction can do wonders that even an excellent product might not.

There are three ways to increase revenue through customers: acquiring more customers, upselling to existing customers, and retaining customers. If we look at the underlying economics, the third is vital. Studies show that it costs five times more to acquire a customer than to retain one, which is why churn prediction has become an essential part of business these days.

Churn prediction deals with predicting which of our customers are more likely to churn out of our product in the (near) future. This is followed by developing strategies and offering benefits so that these customers do not unsubscribe or leave our services. 

In businesses, the number of customers, and therefore the volume of data, is huge. Local machines often cannot handle data of this size. Even when they can, the processing is slow, which is frustrating for developers. Data scientists often have to experiment in order to derive insights, weed out irrelevant data, and establish relationships between variables. Imagine running a piece of code for half an hour only to find out that the results are unsatisfactory and another iteration is required.

This is why it makes sense to work with these large datasets on a GPU-enabled MLOps platform such as Spell. It keeps processing fast and lets developers focus on their work rather than worrying about processing power and integration.

In this article, we will use the Spell platform to look at how we can apply machine learning and neural networks to a customer dataset with roughly one million labeled customers and more than 18 million usage log rows. The size of this data is consistent with industry standards.

The data

Without further ado, let us connect to a workspace in Spell and analyze the data we have.

  1. Log in to Spell’s MLOps platform and go to the Workspaces tab on the left.
  2. Click on 'Create Workspace', specify the name of your workspace, and specify the environment by selecting the required Machine Type. For this analysis, we have used V100 as our Machine Type.

As with any other AI problem, the first task is to gather the data. We will use a dataset from Spell's public resource folder. There are 4 pieces of data, stored in the 4 different files:

  • The members file contains data about the customer demographic attributes, such as gender, age, city, and registration method. 
  • The transactions file contains details about customer transactions, such as transaction dates, payment method, payment amount, and other membership details.
  • The user_logs file contains data on customer behavior and sheds light on how long each customer listens to songs.

    This file has variables that tell us the number of songs which were played less than 25% of their length, played between 25% and 50% of their length, played between 50% and 75% of their length, and played between 75% and 100% of their length. Apart from that, it also has the number of unique songs played and the total number of seconds played.
  • Finally, the train file tells us whether the customer churned or not.

The code

In this tutorial, we will be using both LightGBM and Artificial Neural Networks for classification.

First, let's load the required libraries.

import numpy as np
import pandas as pd

from datetime import datetime

import matplotlib.pyplot as plt
import seaborn as sb

from sklearn import model_selection
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

import lightgbm as lgb

from keras import regularizers
from keras.layers import Dense
from keras.models import Sequential

Next, read in the files:

train = pd.read_csv('train_v2.csv')
user_logs = pd.read_csv('user_logs_v2.csv')
members = pd.read_csv('members_v3.csv')
transactions = pd.read_csv('transactions_v2.csv')

Let’s look at the dimensions of the data to get an idea of the volume that we are dealing with.

print(train.shape)  # (970960, 2)
print(user_logs.shape)  # (18396362, 9)
print(members.shape)  # (6769473, 6)
print(transactions.shape)  # (1431009, 9)

As we can see, we are dealing with a lot of data here. The user_logs dataset alone has about 18.4 million rows. This is where Spell’s GPU-enabled Jupyter notebooks come in handy. The V100 instance is able to read this data and operate on it within a couple of seconds, something a typical local machine would struggle to do.

Feature Engineering

We first look at the user logs data. The data for each customer is spread over multiple rows, one per date.

We group this data on customer ID (msno) and sum every numeric column: the number of songs that were played less than 25% of their length (num_25), between 25% and 50% of their length (num_50), between 50% and 75% of their length (num_75), and more than 75% of their length (num_100), along with the number of unique songs and the total seconds played.

user_logs_sum_data = user_logs.groupby('msno').sum()
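
If user_logs_v2.csv does not fit comfortably in memory, the same per-customer sums can be built chunk by chunk. The sketch below is one possible way to do that, not the approach used in this article. The helper name sum_logs_in_chunks, the chunk size, and the tiny in-memory CSV are assumptions made for illustration.

```python
import io

import pandas as pd


def sum_logs_in_chunks(csv_source, chunksize=1_000_000):
    """Sum the numeric columns per msno without loading the whole file at once."""
    partial_sums = []
    for chunk in pd.read_csv(csv_source, chunksize=chunksize):
        partial_sums.append(chunk.groupby('msno').sum(numeric_only=True))
    # A customer can appear in several chunks, so the partial sums are summed again.
    return pd.concat(partial_sums).groupby(level='msno').sum()


# Tiny in-memory stand-in for user_logs_v2.csv (made-up values)
sample_csv = io.StringIO(
    "msno,date,num_25,total_secs\n"
    "a,20170301,1,10.0\n"
    "b,20170301,2,20.0\n"
    "a,20170302,3,30.0\n"
)
chunked_sums = sum_logs_in_chunks(sample_csv, chunksize=2)
print(chunked_sums)
```

Because sums are additive, summing the per-chunk results gives the same totals as a single groupby over the full file.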

Next, we group on customers to find the count of rows pertaining to each customer. Since each row per customer corresponds to a particular unique date, this grouping gives us the number of days for which the customer used the app.

user_logs_count_data = pd.DataFrame(user_logs.groupby('msno').date.count().reset_index())
user_logs_count_data.columns = ['msno', 'used_days']

Finally, the two data frames are merged, and the date column is removed:

user_logs_new_data = pd.merge(user_logs_sum_data, user_logs_count_data, how = 'inner', left_on = 'msno', right_on = 'msno')

del user_logs_new_data['date']
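
The sum, count, and merge steps can also be collapsed into one pass with pandas named aggregation. This is a sketch on made-up rows rather than the article's exact code; the toy_logs frame is an assumption, and it counts distinct dates (nunique) instead of rows, which gives the same result when each customer has one row per date.

```python
import pandas as pd

# Made-up rows standing in for user_logs (only a few of its columns)
toy_logs = pd.DataFrame({
    'msno': ['a', 'a', 'b'],
    'date': [20170301, 20170302, 20170301],
    'num_25': [1, 3, 2],
    'total_secs': [10.0, 30.0, 20.0],
})

# Named aggregation: each output column is (source column, aggregation)
toy_features = (
    toy_logs.groupby('msno')
    .agg(
        num_25=('num_25', 'sum'),
        total_secs=('total_secs', 'sum'),
        used_days=('date', 'nunique'),  # distinct dates per customer
    )
    .reset_index()
)
print(toy_features)
```

With this form, the date column never ends up in the result, so there is nothing to delete afterwards.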

The final user logs data is then merged with the train dataset. Note that we will keep extending the train dataset with the features we generate from each of the other datasets.

all_data = pd.merge(train, user_logs_new_data, how='left', left_on='msno', right_on='msno')

Next, we turn our attention to the transactions dataset. 

We first remove transaction_date and membership_expire_date, as they won’t be required in our analysis. We then group the data by customer ID (msno) and find the mean values of payment_plan_days, plan_list_price, and actual_amount_paid. Finally, the result is merged with the train data.

del transactions['transaction_date']
del transactions['membership_expire_date']

transactions_mean_data = transactions.groupby('msno').mean()
transactions_mean_data = transactions_mean_data[['payment_plan_days', 'plan_list_price', 'actual_amount_paid']]

all_data = pd.merge(all_data, transactions_mean_data, how='left', left_on='msno', right_on='msno')
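
With a left merge, customers that have no transactions silently receive missing values. A quick way to see how many rows found a match is the indicator option of pd.merge. The sketch below uses made-up frames (toy_train, toy_tx) as an assumption for illustration.

```python
import pandas as pd

toy_train = pd.DataFrame({'msno': ['a', 'b', 'c'], 'is_churn': [0, 1, 0]})
toy_tx = pd.DataFrame({'msno': ['a', 'b'], 'actual_amount_paid': [149.0, 99.0]})

# indicator=True adds a _merge column: 'both' for matched rows, 'left_only' otherwise
checked = pd.merge(toy_train, toy_tx, how='left', on='msno', indicator=True)
matched_share = (checked['_merge'] == 'both').mean()
print(f"{matched_share:.1%} of train rows found a match")

# Drop the helper column before using the frame as model input
checked = checked.drop(columns='_merge')
```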

Next, we look at the members data, which holds the customer demographic attributes such as gender, age, city, and registration method. In the gender variable, we replace male with 1 and everything else (female or missing) with 0. Furthermore, using the registration_init_time variable, we create a new variable, num_days, which holds the number of days between registration_init_time and the reference date of 31 March 2017.

gender_new = [1 if x == 'male' else 0 for x in members.gender ]
members['gender_new'] = gender_new

current = datetime.strptime('20170331', "%Y%m%d").date()
# np.nan keeps the column numeric; the string "NAN" would turn it into an object column
members['num_days'] = members.registration_init_time.apply(lambda x: (current - datetime.strptime(str(int(x)), "%Y%m%d").date()).days if pd.notnull(x) else np.nan)
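
On several million rows, a row-by-row apply can be slow. As a sketch of an alternative, the same day count can be computed with vectorized pandas date handling. The toy_members frame and the name reference_date are assumptions for illustration.

```python
import pandas as pd

# Made-up registration dates in the YYYYMMDD form used by the members file
toy_members = pd.DataFrame(
    {'registration_init_time': [20170301, 20160331, float('nan')]}
)

reference_date = pd.Timestamp('2017-03-31')
registered = pd.to_datetime(
    toy_members['registration_init_time'].astype('Int64').astype(str),
    format='%Y%m%d',
    errors='coerce',  # missing or malformed dates become NaT
)
# Subtracting gives timedeltas; .dt.days turns them into day counts (NaN where missing)
toy_members['num_days'] = (reference_date - registered).dt.days
print(toy_members)
```

Missing registration dates stay missing here, so they are still picked up by a later fillna call.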

Finally, we remove the variables which are not required, and merge the obtained dataset to the train dataset.

members_all_data = members

del members_all_data['city']
del members_all_data['bd']
del members_all_data['registered_via']
del members_all_data['registration_init_time']
del members_all_data['gender']

all_data = pd.merge(all_data, members_all_data, how='left', left_on='msno', right_on='msno')

Note that all the variable transformations, additions and removals that we are doing in this article are for simplicity and illustration purposes only. There are multiple ways in which the given variables can be treated. In that regard, please feel free to experiment your way through the given datasets and draw out more and different features if you can.

We fill the missing values in all_data with -1, then separate the features (X) from the target (Y).

all_data = all_data.fillna(-1)

cols = [c for c in all_data.columns if c not in ['is_churn','msno']]
X = all_data[cols]
Y = all_data['is_churn']

Now, we split the data into training and testing data. Within the training data, we further split it into actual training and validation data. 

test_size = 0.3
seed = 7
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X,Y, test_size=test_size, random_state = seed)

validation_size = 0.2
seed = 7
X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(X_train,Y_train, test_size=validation_size, random_state = seed)
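
Churn labels are usually imbalanced, so a purely random split can leave the subsets with slightly different churn rates. As an optional variation, not what this article does, train_test_split accepts a stratify argument that preserves the class ratio. The toy arrays below are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

toy_X = np.arange(200).reshape(100, 2)
toy_y = np.array([1] * 10 + [0] * 90)  # 10% positives, made-up labels

X_tr, X_te, y_tr, y_te = train_test_split(
    toy_X, toy_y, test_size=0.3, random_state=7, stratify=toy_y
)
# Both subsets keep the original share of positives
print(y_tr.mean(), y_te.mean())
```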

Now that we have all the data ready, we are ready to proceed with modelling. We intend to employ a LightGBM model and set up its parameters as follows. 

lgb_params = {
    'learning_rate': 0.01,
    'application': 'binary',
    'max_depth': 40,
    'num_leaves': 3300,
    'verbosity': -1,
    'metric': 'binary_logloss'
}

d_train = lgb.Dataset(X_train, label=Y_train)
d_valid = lgb.Dataset(X_validation, label=Y_validation)
watchlist = [d_train, d_valid]

Finally, we fit the LightGBM model.

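# LightGBM 4.0 removed early_stopping_rounds and verbose_eval from lgb.train;
# the equivalent callbacks below are available from version 3.3 onwards.
model = lgb.train(lgb_params, train_set=d_train, num_boost_round=1000, valid_sets=watchlist,
                  callbacks=[lgb.early_stopping(stopping_rounds=50), lgb.log_evaluation(period=10)])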

Note that this step might take some time. However, it is much faster than it would be on an ordinary CPU-only machine. For data of this size, training a model as complex as this one on a CPU might take hours.

Once the model is trained, we observe how it performs on test data.

lgb_pred = model.predict(X_test)
lgb_pred_01 = [0 if x < 0.5 else 1 for x in lgb_pred]
print(confusion_matrix(Y_test, lgb_pred_01))

Confusion Matrix looks something like this:

[[262833   2186]
 [ 12190  14079]]

Let’s try to calculate some metrics based on the confusion matrix obtained.

# scikit-learn lays the matrix out as [[TN, FP], [FN, TP]]
accuracy = (262833 + 14079)/(262833 + 14079 + 2186 + 12190)
# 0.9506467825657081

precision = 14079/(14079 + 2186)
# 0.8656009837073471

recall = 14079/(14079 + 12190)
# approximately 0.536

fpr = 2186/(2186 + 262833)
# approximately 0.0082

The model has more than 95% overall accuracy, although with only about 9% of the test customers churning, accuracy alone says little. The precision is about 86.6%: of the customers predicted to churn, 86.6% actually churned. The recall, which tells us how many of the churned customers the model correctly identified, is about 53.6%, so at the 0.5 threshold the model misses close to half of the churners. Finally, the false positive rate is below 1%, which tells us that very few loyal customers are wrongly flagged as likely to churn.
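
The standard definitions can be computed directly from the matrix. scikit-learn's confusion_matrix puts true labels in rows and predicted labels in columns, so for 0/1 labels the layout is [[TN, FP], [FN, TP]]. The helper name binary_metrics is an assumption for illustration.

```python
def binary_metrics(cm):
    """Compute metrics from a 2x2 confusion matrix laid out as [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    return {
        'accuracy': (tp + tn) / total,
        'precision': tp / (tp + fp),            # share of predicted churners who churned
        'recall': tp / (tp + fn),               # share of actual churners who were caught
        'false_positive_rate': fp / (fp + tn),  # share of loyal customers wrongly flagged
    }


lgb_metrics = binary_metrics([[262833, 2186], [12190, 14079]])
for name, value in lgb_metrics.items():
    print(f"{name}: {value:.4f}")
```

The same helper can be reused for any other model's confusion matrix, which keeps the comparison between models consistent.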

Artificial Neural Networks

Let's now train an Artificial Neural Network model on the same dataset, and see how it fares. We will use a simple feedforward neural network written in Keras.

lsize = 128

model = Sequential()
model.add(Dense(lsize, input_dim=int(X_train.shape[1]),activation='relu'))
model.add(Dense(int(lsize/2), activation='relu'))
model.add(Dense(int(lsize/4),kernel_regularizer=regularizers.l2(0.1), activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])

test_size = 0.3
seed = 123
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X,Y, test_size=test_size, random_state = seed)

# Fit the model
history = model.fit(X_train, Y_train, epochs=10, batch_size=1026,
                    validation_split=0.2, verbose=1)
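
Neural networks tend to train more reliably when the input features share a similar scale, and columns such as total seconds played are far larger than a 0/1 gender flag. The article feeds the raw features to the network. As an optional variation, the features could be standardized first, as in this sketch on made-up data (the toy arrays are assumptions for illustration).

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Two made-up features on very different scales
toy_train = rng.normal(loc=[100.0, 5000.0], scale=[10.0, 2000.0], size=(500, 2))
toy_test = rng.normal(loc=[100.0, 5000.0], scale=[10.0, 2000.0], size=(100, 2))

scaler = StandardScaler()
toy_train_scaled = scaler.fit_transform(toy_train)  # learn mean/std on training rows only
toy_test_scaled = scaler.transform(toy_test)        # reuse the training statistics

print(toy_train_scaled.mean(axis=0).round(3), toy_train_scaled.std(axis=0).round(3))
```

Fitting the scaler on the training rows only avoids leaking information from the test set into the model.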

Again, while this step might take some time, it is much faster than it would be on a CPU. Once the model is trained, let’s see how it performs on the test data:

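predictions = model.predict(X_test)

# Convert the predicted probabilities into 0/1 labels at a 0.5 threshold
prediction_list = []
for i in range(len(predictions)):
    if predictions[i][0] < 0.5:
        prediction_list.append(0)
    else:
        prediction_list.append(1)

confusion_matrix(Y_test, prediction_list)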
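
The thresholding step can also be written without a Python loop. The sketch below uses a made-up array in place of the model output; the helper name to_labels is an assumption for illustration.

```python
import numpy as np


def to_labels(probabilities, threshold=0.5):
    """Turn predicted churn probabilities into 0/1 labels."""
    flat = np.asarray(probabilities).reshape(-1)  # handles shape (n, 1) and (n,)
    return (flat >= threshold).astype(int)


# Made-up probabilities shaped like the output of a single sigmoid unit
toy_predictions = np.array([[0.10], [0.50], [0.93], [0.49]])
toy_labels = to_labels(toy_predictions)
print(toy_labels)
```

Exposing the threshold as a parameter also makes it easy to trade precision against recall by trying values other than 0.5.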

Confusion Matrix:

array([[264083,   1204],
       [ 14494,  11507]])

Let’s look at the metrics again.

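accuracy_nn = (264083 + 11507)/(264083 + 11507 + 1204 + 14494)
accuracy_nn  # 0.9461083189146137

precision_nn = 11507/(11507 + 1204)
precision_nn  # 0.9052788922980096

recall_nn = 11507/(11507 + 14494)
recall_nn  # approximately 0.4426

fpr_nn = 1204/(1204 + 264083)
fpr_nn  # approximately 0.0045

The Artificial Neural Network reaches a similar accuracy to the LightGBM model, with higher precision (90.5%, against 14079/(14079 + 2186), or about 86.6%, for LightGBM) and a lower false positive rate. Its recall is lower, however: it identifies about 44.3% of the churned customers, against 14079/(14079 + 12190), or about 53.6%, for LightGBM. Which model is preferable depends on whether missing a churner or contacting a loyal customer unnecessarily is more costly.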

This brings us to the end of this article. To try this tutorial out, sign in to Spell’s MLOps platform using this link and get $10 of free GPU credit once you’ve made an account. Also, please feel free to post your questions on this Slack link.
