To Churn or not to Churn

🎵 That’s the question 🎵

Sparkify Project (Image from the Udacity Data Scientist Nanodegree Capstone)

Sparkify is an imaginary music streaming service, similar to Spotify or Pandora. With this project, I am going to analyze user trends and predict whether a user will churn from this digital service.

Problem Statement🎯

Churn rate prediction is a common and challenging problem for both digital and brick-and-mortar companies. Understanding the root causes of churn and taking measures before it occurs is a good way to keep clients, and a very important task for any business. It is worth noting that all the findings can be translated to a traditional business, just by selecting the equivalent features that matter in the offline world.

Another aspect of this kind of prediction is measuring the performance of the model, to check whether it works well or not.

Additionally, the ability to manipulate large datasets effectively with the SQL interface of PySpark is a skill that is in very high demand today.

The goal of this project is to develop a model to predict whether a user will churn from the service and measure the performance of this model.

To achieve that goal, I will start with an exploration of the dataset and select the best features for the model. I will try different classification models, with hyperparameter tuning and metric evaluation.

If you think this is a naive problem, just imagine that your Sparkify service has 225 customers and each pays 10 €/month (27.000€/year of income). Suppose you assume every client will churn and spend extra money to retain each customer (let’s say 2€/client = 450€ in total). Since 173 clients were never going to churn, you will have wasted 346€ (76% of the money you spent), on top of the 6.240€/year that you lose forever from the 52 users who do churn. In total you lose 6.586€, almost 25% of your income, to a bad client-retention strategy.

Figure 0: Lost Money for a bad client retention strategy

And then imagine that you change your strategy, assume every client will stay, and decide not to spend any money on retention. If 52 clients go away, you will be losing 6.240€/year (more than 23% of your income).
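The two back-of-the-envelope scenarios above can be checked with a few lines of Python (variable names are just for illustration; all figures are the assumptions stated in the text):

```python
# Illustrative cost comparison of the two retention strategies.
MONTHLY_FEE = 10            # €/month per customer
USERS = 225
CHURNERS = 52
STAYERS = USERS - CHURNERS  # 173
RETENTION_COST = 2          # € spent per customer to retain them

yearly_income = USERS * MONTHLY_FEE * 12             # 27,000 €

# Strategy 1: assume everyone churns and spend on every user.
wasted_spend = STAYERS * RETENTION_COST              # 346 € spent on users who stay anyway
lost_revenue = CHURNERS * MONTHLY_FEE * 12           # 6,240 €/year from the actual churners
loss_strategy_1 = wasted_spend + lost_revenue        # 6,586 €

# Strategy 2: assume nobody churns and spend nothing on retention.
loss_strategy_2 = CHURNERS * MONTHLY_FEE * 12        # 6,240 €/year

print(loss_strategy_1, loss_strategy_2)  # 6586 6240
```

Either way you lose money; the point of a churn model is to spend the retention budget only on the users who actually need it.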

So, yes! Churn prediction is important for your 💸business💸.

Analysis 🔎

As a first step, and as usual with this kind of task, I will start with an EDA (Exploratory Data Analysis) to explore the dataset, understand which features are relevant and calculate statistics that describe our scenario. EDA is important to discover anomalies and characteristics of the data we are going to process later.

The minidataset has 18 fields and 286.500 records, and this is its schema:

|-- artist: string (nullable = true)
|-- auth: string (nullable = true) (values: Logged In/Cancelled)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true) (values: M/F)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true) (values: free/paid)
|-- location: string (nullable = true)
|-- method: string (nullable = true) (values: PUT/GET)
|-- page: string (nullable = true)
|-- registration: long (nullable = true) (unix timestamp)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true) (values: 307/404/200)
|-- ts: long (nullable = true) (unix timestamp)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

The categorical columns are: [‘artist’, ‘auth’, ‘firstName’, ‘gender’, ‘lastName’, ‘level’, ‘location’, ‘method’, ‘page’, ‘song’, ‘userAgent’, ‘userId’]

The numerical columns are: [‘itemInSession’, ‘length’, ‘registration’, ‘sessionId’, ‘status’, ‘ts’]
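The same split can be derived programmatically; in PySpark one would inspect `df.dtypes`, and a plain-Python sketch of that idea (with the schema transcribed above) looks like this:

```python
# Partition the schema fields into categorical (string) and numerical
# (long/double) columns, mirroring what df.dtypes returns in PySpark.
schema = [
    ("artist", "string"), ("auth", "string"), ("firstName", "string"),
    ("gender", "string"), ("itemInSession", "long"), ("lastName", "string"),
    ("length", "double"), ("level", "string"), ("location", "string"),
    ("method", "string"), ("page", "string"), ("registration", "long"),
    ("sessionId", "long"), ("song", "string"), ("status", "long"),
    ("ts", "long"), ("userAgent", "string"), ("userId", "string"),
]

categorical = [name for name, dtype in schema if dtype == "string"]
numerical = [name for name, dtype in schema if dtype in ("long", "double")]
```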

Methodology 📑

As a first step I made a deep exploratory analysis of the dataset and produced several graphical visualizations to get a first overview.

I then applied some cleansing and transformations to the initial dataset, in order to prepare it for the training phase.

After that comes one of the most crucial tasks: defining the churn action, that is, when a user cancels and leaves the Sparkify service. I have defined that a user churns when they reach the “Cancellation Confirmation” page.
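A minimal sketch of that definition, with a hypothetical `label_churn` helper over (userId, page) event tuples standing in for the PySpark aggregation:

```python
# Label a user as churned (1) if any of their events is a visit to the
# "Cancellation Confirmation" page, otherwise not churned (0).
def label_churn(events):
    """events: list of (userId, page) tuples -> {userId: 0 or 1}."""
    labels = {}
    for user_id, page in events:
        labels.setdefault(user_id, 0)
        if page == "Cancellation Confirmation":
            labels[user_id] = 1
    return labels

events = [("u1", "NextSong"), ("u1", "Cancellation Confirmation"),
          ("u2", "NextSong"), ("u2", "Home")]
print(label_churn(events))  # {'u1': 1, 'u2': 0}
```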

Then comes the “Feature Engineering” phase, where I prepared the features I think are important to predict whether a user will leave our service.

This is a binary classification problem (churn/not churn), so I have tried different classification algorithms to discover which of them gives the best results.

We have a highly imbalanced dataset (52 churned users out of a total of 225), so all the models I tried suffered from overfitting.

I trained different models and applied hyperparameter tuning and cross-validation techniques.

And finally, I measured how well (or badly) our models perform. As the subset of churned users is small, I used the F1 score as the metric to optimize the models. I will explain this decision in more detail below.

  1. Null values

There are no null values in userId or sessionId, but there are some empty strings in the userId field, which I have removed.

These are the columns with null/missing/empty values: {'artist': 58392, 'firstName': 8346, 'gender': 8346, 'lastName': 8346, 'length': 58392, 'location': 8346, 'registration': 8346, 'song': 58392, 'userAgent': 8346, 'userId': 8346}
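This cleansing step can be sketched in plain Python over a few toy records (in PySpark it would be a combination of `isNull` filters and a `filter` on the empty string; names below are illustrative):

```python
# Count missing values per column and drop rows whose userId is an
# empty string (toy records, not the real Sparkify data).
records = [
    {"userId": "10", "artist": "Queen", "length": 200.0},
    {"userId": "",   "artist": None,    "length": None},
    {"userId": "11", "artist": None,    "length": None},
]

def count_missing(rows, columns):
    # Treat both None and the empty string as missing.
    return {c: sum(1 for r in rows if r.get(c) in (None, "")) for c in columns}

cleaned = [r for r in records if r["userId"] != ""]
print(count_missing(records, ["userId", "artist", "length"]))
```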

2. Number of users

There is information on 225 users, 52 of whom churned. That means 23.11% of the total churned.

Figure 1: Churn users ratio
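The ratio in Figure 1, computed explicitly:

```python
# Churn ratio: churned users over total users, as a percentage.
total_users, churned = 225, 52
churn_ratio = round(churned / total_users * 100, 2)
print(churn_ratio)  # 23.11
```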

3. Gender rate

There are 104 female users and 121 male users in our dataset.

That means that 46% of the users are female and 54% are male.

Figure 2: Users ratio by Gender

4. Location

Users are grouped by state and by whether they churned or not.

Usually, the states with more churned users are also those with more users who stay in the service.

Figure 3: Distribution of churn by Location

5. Operating System

Most users connect to the Sparkify service with a Windows operating system.

This holds for both free and paid subscriptions.

Figure 4: Operating System by Subscription type

6. Pages visited by users

The most visited page is Next Song, followed by the Home page, for users who didn’t churn.

For users who cancelled the subscription, the most visited page was “Cancellation Confirmation”.

Figure 5: User Activities by churn

7. Time Span

The oldest record is from October 2018 (2018–10–01) and the most recent is from December 2018 (2018–12–03), so we have a dataset of approximately 2 months of information.
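The `ts` field is a unix timestamp; in the Sparkify logs it appears to be in milliseconds (an assumption worth verifying on your copy of the data), so converting it to a readable date looks like this:

```python
from datetime import datetime, timezone

# Convert a unix timestamp in milliseconds to a UTC date string.
def ts_to_date(ts_ms):
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d")

print(ts_to_date(1538352000000))  # 2018-10-01 (midnight UTC, 1 Oct 2018)
```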

8. Subscriptions

What about the level of the users? Is there a difference between male and female subscribers?

For both subscription levels, the number of males is slightly higher, as we can see in this plot:

Figure 6: Subscription type by Gender

9. Number of songs by hour

And what is the distribution of songs a user listens to during the day?

It seems that Sparkify users prefer to use the service in the afternoon.

Figure 7: Number of played Songs by hour

Is there a day of the week that users prefer?

In the graphic below, with days from 0 (Monday) to 6 (Sunday), we can see that Monday is the day when users use Sparkify the least and Saturday is their favourite day to listen to Sparkify music.

Figure 8: Number of played Songs by weekday

10. Membership lifetime

Is there a difference in membership time between the users who churned and who did not churn?

Figure 9: Membership lifetime by churn

From the boxplot above we can see that users who churn stay less time in the Sparkify service than those who don’t. Users who churn stay about 50 days on average; users who don’t churn stay approximately 75 days on average.
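Membership lifetime can be computed as the span between the registration timestamp and the user’s last event, again assuming both fields are unix timestamps in milliseconds:

```python
# Days between registration and the user's last event
# (both unix timestamps in milliseconds, as assumed above).
MS_PER_DAY = 86_400_000

def lifetime_days(registration_ms, last_ts_ms):
    return (last_ts_ms - registration_ms) / MS_PER_DAY

print(lifetime_days(0, 50 * MS_PER_DAY))  # 50.0
```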


As I mentioned earlier, the dataset is imbalanced, so the models are prone to overfitting.

To avoid that, we can try several techniques:

  1. Change the proportion of train/test datasets

I tried several combinations of training/test sets (60/40, 80/20) and even included a training/validation/test split. Finally I settled on a 70/30 proportion.
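In Spark this split is done with `df.randomSplit([0.7, 0.3], seed=...)`; a plain-Python sketch of the same idea over the 225-user feature table:

```python
import random

# Randomly split rows into train/test with a given proportion and seed,
# analogous to PySpark's df.randomSplit([0.7, 0.3], seed=42).
def random_split(rows, train_frac=0.7, seed=42):
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_frac else test).append(row)
    return train, test

users = list(range(225))
train, test = random_split(users)
```

Note that, like `randomSplit`, each row is assigned independently, so the split is only approximately 70/30.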

2. Parameters, hyperparameter tuning and cross validation

This is the initial configuration for the different models I have tested with:

Figure 10: Initial parameters

At the beginning I set higher values for the parameters, but long execution times made me lower the depth and number of iterations.

After training the models, I performed hyperparameter tuning, setting a range of parameters (numTrees: 15, 20, 25; maxBins: 2, 4; maxDepth: 2, 4) and searching for the best combination. I also applied 3-fold cross validation to improve the initial results: the training data is split into 3 folds, and the model is trained on 2 of them and validated on the remaining one, rotating through the folds.

Figure 11: Hyperparameter Tuning details and performance
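In Spark ML this is done with `ParamGridBuilder` and `CrossValidator`; here is a plain-Python sketch of what the parameter grid and the 3-fold splits look like (toy data, illustrative names):

```python
from itertools import product

# The same grid as above: numTrees 15/20/25, maxBins 2/4, maxDepth 2/4
# -> 3 * 2 * 2 = 12 parameter combinations to evaluate.
grid = [dict(numTrees=t, maxBins=b, maxDepth=d)
        for t, b, d in product([15, 20, 25], [2, 4], [2, 4])]

def k_folds(rows, k=3):
    """Yield (train, validation) lists for k-fold cross validation."""
    folds = [rows[i::k] for i in range(k)]
    for i in range(k):
        validation = folds[i]
        train = [r for j, f in enumerate(folds) if j != i for r in f]
        yield train, validation

splits = list(k_folds(list(range(9))))  # 3 (train, validation) pairs
```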

After this hyperparameter tuning I did not obtain a significant increase in the performance of the models, but it could be a path for further research.

I encountered some difficulties finding the right functions for every model in the official Spark documentation, so I had to google.

Hyperparameter tuning is also a highly time-consuming task, so I finally decided to split the notebook in two, in order to avoid executing all the cells from the very beginning.

3. Change features

I realized that time-related features introduced multicollinearity into the model, so I did not include them as features for training.

Figure 12: Features heatmap

4. Augment dataset

In this case it is not a matter of having few records, as we have nearly 300k in the minidataset, but we can “create” new features from the existing ones.

I included several features related to the datetime (day, month, year), and from the userAgent information I added the operating system of the user’s device.

Figure 13: Augmenting data with Operating System

5. Weight balancing

Applying more weight to the minority class is another interesting technique to use when we have an imbalanced set. As I obtained good performance without it, I have not tried this method.
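For reference, a common weighting scheme (a sketch, not the method used in this project) gives each class a weight of total/(n_classes × class_count), so the minority class counts more during training:

```python
# Inverse-frequency class weights: the minority class (churn) gets a
# weight > 1, the majority class (stay) a weight < 1.
def class_weights(counts):
    total = sum(counts.values())
    k = len(counts)
    return {c: total / (k * n) for c, n in counts.items()}

weights = class_weights({"churn": 52, "stay": 173})
print(weights)
```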

6. Check sizes after the join

After joining all the features during the Feature Engineering phase, it is easy to forget adding a dropDuplicates to the select command (or to omit it entirely), which leaves you with a bad set for training. I made this mistake, and I realized things were not going well when I got long training times and perfect model performance (overfitting).

That is why I added a check on the size after every query (the maximum number of distinct users is 225, so that is the maximum number of records a query should return), and I also checked after every query that the churn/not-churn proportion was maintained throughout the process. This double check keeps things under control 😉.

Figure 14: Double checking sizes and occurrences
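A sketch of that double check in plain Python (the `check_feature_table` helper name is illustrative; in PySpark it would be `count()` plus `dropDuplicates` on userId):

```python
# Sanity checks after each feature join: the user-level table should
# never exceed 225 distinct users, contain no duplicated users, and the
# churn ratio should stay around the expected 23%.
def check_feature_table(rows, max_users=225):
    user_ids = [r["userId"] for r in rows]
    assert len(user_ids) == len(set(user_ids)), "duplicated users after join"
    assert len(user_ids) <= max_users, "more rows than distinct users"
    return sum(r["churn"] for r in rows) / len(rows)

rows = [{"userId": "u1", "churn": 1}, {"userId": "u2", "churn": 0}]
ratio = check_feature_table(rows)  # churn proportion of the toy table
```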

Feature Engineering🤹🏼‍♀️

These are the features I included as indicators of user churn, as I think they represent the customer experience.

feature_engineering, chosen features:
|-- level: Level of the subscription
|-- n_artists: Number of artists a user has listened to
|-- n_songs: Number of songs a user has listened to
|-- n_songs_play: Number of songs a user has added to a playlist
|-- n_thumbs_up: Number of thumbs-up a user has clicked
|-- n_thumbs_down: Number of thumbs-down a user has clicked
|-- n_Errors: Number of errors a user experiences
|-- n_friends: Number of friends a user has added
|-- n_Rolls: Number of advertisements watched
|-- n_Help: Number of visits to the Help page
|-- total_sdays: Number of subscription days the user has been with Sparkify
|-- n_sessions: Number of sessions a user has connected
|-- IPAD: Whether the user connects to Sparkify with an iPad
|-- IPHONE: Whether the user connects to our service with an iPhone
|-- LINUX: Whether the user connects to our service with a Linux device
|-- MAC: Whether the user connects to our service with a Mac

I have not included gender as a feature in my model, as I think this characteristic does not determine the probability of leaving a music streaming service, and including it could introduce undesired biased trends, because the service currently has more male users than female ones.

After selecting the final features for the model, I joined them with an outer join. This brings in null values, as some properties are present in some users’ records and not in others; for instance, if a user has never added a friend, that field will contain a null. In these cases I imputed the null values with zero.
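That imputation step can be sketched in plain Python (in PySpark it is a `df.fillna(0)` on the joined feature table; the helper name here is illustrative):

```python
# After the outer join, features that never occurred for a user
# (e.g. a user who never added a friend) come back as None;
# replace them with a default of zero.
def fill_missing(rows, feature_names, default=0):
    return [{f: (r.get(f) if r.get(f) is not None else default)
             for f in feature_names} for r in rows]

rows = [{"n_friends": 3, "n_thumbs_up": None}, {"n_friends": None}]
filled = fill_missing(rows, ["n_friends", "n_thumbs_up"])
print(filled)
```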

We then have a combination of distinct features, each with a different range of values. To avoid a feature with larger values dominating the others, I scaled them with a standard scaler, that is, dividing each feature by its standard deviation.
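A plain-Python sketch of that scaling step for a single column (Spark’s `StandardScaler` applies the same idea per feature; by default it scales by the standard deviation without centering on the mean):

```python
import math

# Scale a column of values by its standard deviation so that the
# resulting column has unit standard deviation.
def scale_by_std(values):
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [v / std for v in values]

scaled = scale_by_std([2.0, 4.0, 6.0])
```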

The next step is to convert the features into a vector, as this is a requirement of the Spark ML library.

Model Validation and Evaluation🔮

I have tried several classification models:

  1. Logistic Regression

With the parameters:

maxIter=10 and regParam =0.0

This model took about 4 minutes to train, make predictions on the test set and print the confusion matrix and metrics.

2. Decision Trees

With seed=5, it took about 5 minutes to train, test and show the performance metrics.

3. Gradient Boosted Trees

This model is the slowest; it took 7 minutes, with maxDepth=5, maxIter=10 and seed=42.

4. Random Forest

This model took 5 minutes and I used a seed=5

5. Support Vector Machine

The last model took 5 minutes and I set maxIter=10 and regParam=0.01

The best performance was obtained with Random Forest, with an F1-score of 0.74. In the next section I will explain why I chose the F1-score as the most suitable metric for this problem.

All of them obtained good results and similar running times, although GBT has the highest running time, and Logistic Regression and Decision Trees the lowest.

To check that my model is robust, I performed a 3-fold cross validation. The validation performance is stable and does not fluctuate much, so I can argue that the model is robust against small perturbations in the training data. But it would be better to test it with the big dataset.

There are several studies where both Random Forest and Gradient Boosted Trees have been used in different environments to predict and also to extract the essential features. One of those studies applied Random Forest to analyse divorce trends in Germany.

Another investigation applies a variant of Gradient Boosted Trees (XGBoost: Extreme Gradient Boosting) and shows high prediction accuracy, good stability and high running speed for diabetes prediction.


Choosing metrics is a crucial task in machine learning, as it is the way to know whether our model works well or not.

The metric depends on the problem we need to solve, and in this case it is very important to detect exactly the users who are prone to churn, as the Sparkify company will invest money to retain them. If we flag a user as a potential leaver and this user was not going to churn, we lose money. And if we fail to detect a potential churner and that user churns, we also lose a client and therefore money. So in this scenario it is really important to have both good precision and good recall. The score that balances recall and precision equally is the F1 score, so I have used it to measure the models.
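The reasoning above maps directly onto the standard definitions: precision penalizes wasted retention spend (false positives), recall penalizes missed churners (false negatives), and F1 is their harmonic mean, so it is only high when both are:

```python
# Precision, recall and F1 from binary confusion-matrix counts.
# tp: churners correctly flagged, fp: stayers flagged by mistake,
# fn: churners we missed.
def f1_score(tp, fp, fn):
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

print(f1_score(tp=3, fp=1, fn=1))  # 0.75
```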


Let’s review some decisions I took during the implementation:

  • We have an imbalanced dataset (the minidataset), with a total of 52 churned and 173 not-churned users. When we split it into training/test sets, the proportion stays at 37/126 for the training set and 15/47 for the test set, and this leads to unstable performance.
  • There are some features (membership lifetime, number of played songs, time spent listening to music, etc.) that present multicollinearity, and this can introduce overfitting into our model. I decided not to include several of these features in the final model.
  • Gender: although it could be a behavioural feature, I decided not to add it, so that the final model is “gender agnostic”. I believe we can all help build an ethical machine learning environment without losing performance.
  • Random Forest and Gradient Boosted Trees obtained the best performance, as I will explain in more detail in the next sections, but they should be applied to the large dataset in order to evaluate the accuracy on a more balanced set.
  • To improve the understanding of the data and its relations, it is good practice to get the feature importance of a trained model. This gives us a vector with, for each feature index, the weight that feature has on the prediction. I used the featureImportances attribute that comes with the Random Forest and Gradient Boosted Trees models of Spark ML.
  • At the beginning of this project I asked myself several questions and got the answers from the data. You can explore them in my code, but I include some examples here:
Figure 15: Question about membership time and churn relation
Figure 16: Question about average song length
Figure 17: Question about user activities related to churn

Results 🥇

The problem has been partially solved with this mini dataset, as the models show decent accuracy in detecting churned users, but the approach should be applied to the large dataset to make some adjustments and refinements and get better results.

The models with the best performance were Random Forest and Gradient Boosted Trees.

These are some metrics of the trained models.

Figure 18: Random Forest with the best performance and best F1-Score
Figure 19: Gradient Boosted Tree performance

These are the important features for the Random Forest model:

Figure 20: Feature Importance for Random Forest model

And these are the important features for the GBT model:

Figure 21: Feature Importance for Gradient Boosted Trees model

As we can observe in both models, the number of subscription days (membership lifetime) is a crucial feature to determine whether a user will churn from our Sparkify music streaming service, along with the number of friends the user has added and the number of thumbs-down clicks. This information is really valuable, as we could prepare actions to prevent churn: for instance, campaigns with friend recommendations, or alerts when a user’s number of subscription days approaches the danger zone, so that we can give those users some privileges to keep them with Sparkify.


There is room for improvements in several ways:

  • Apply weights during training and testing on the minidataset, so that we can boost the minority class (churned users).
  • Extend the code to the large dataset; this set has a less imbalanced proportion, so the performance will likely be better.
  • Explore different features: we can search for new ones and add them to the model to see if we get better results.
  • Enlarge the hyperparameter tuning search with the large dataset; we could obtain better performance.
  • Apply clustering to the listened songs, to detect different styles of music and add this information to the features. If the Sparkify service has few songs of a specific style, maybe some users leave the service for that reason.
  • Analyse the occurrences of the advertising strategy to detect whether users leave the service for this reason. We can also investigate the occurrences of errors, as these could be another root cause for leaving the service.


To predict one of the most common problems in business today, I made a deep exploratory analysis of the dataset and produced several graphical visualizations to get a complete overview.

I then applied some cleansing and transformations to the initial dataset, in order to prepare it for the training phase.

This is a binary classification problem (churn/not churn), so I tried different classification algorithms to discover which of them provides the best results.

After that, I applied some tuning to the models (hyperparameter tuning), although the original models without tuning obtained better results.

After visualizing the features with a heatmap, I detected several features that were highly correlated, so it would be good to apply techniques like Principal Component Analysis (PCA), or simply drop the features whose correlation is above a threshold, but I could not invest more time in these tasks.

There have been several challenges in this project and the most important have been:

  1. Working with a big amount of data in Spark means investing a lot of time in executing the commands, and also learning a new language (for me).
  2. Working with an imbalanced dataset led to some overfitting problems that I had to overcome.

Show me the code! 👩‍💻

If you want to try it yourself, just take a look at the code here

I will refactor the code and apply the ML model to the large dataset to check whether I get better results, so stay alert 👀 for the second part of this post….
