Data Science 101: Creating the Perfect Confusion Matrix

Sam Dedes
4 min read · Feb 7, 2021


This article walks through the Python code for a confusion matrix visualization function and points to the colormaps available via Matplotlib for further customization.

Ha! You’ve done it! You’ve run your model countless times and through the blood, sweat and tears you’ve found a viable model. Now all that’s left is to present the results. What better way than to use your old friend, the confusion matrix? At this point we’re familiar with the meaning behind a confusion matrix, so let’s go about pairing the numbers with a decent visualization. After all, you simply can’t go to your client with the raw output of your IDE.

Raw Confusion Matrix: Good for Coders, Bad for Business

Begin by importing the necessary packages. In this case, pandas is being used only to format the data. Matplotlib provides a plethora of options when dealing with plots, while Seaborn has a particularly helpful “heatmap” that is perfect for visualizing confusion matrices.

# import packages
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

Convert the current confusion matrix into a pandas DataFrame. Here the “index” and “columns” represent the different classes your model predicts. To keep this general, they’re labelled ‘0’ and ‘1’ for a simple binary classification.

# confusion matrix data from model
cm = [[454, 17], [8, 521]]
# convert to pandas dataframe
cm = pd.DataFrame(cm, index=['0', '1'], columns=['0', '1'])
[Figure: Pandas dataframe of the confusion matrix]
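The code above takes `cm` as given. If you start from raw true and predicted labels instead, one way to build the same table is pandas' `crosstab`. This is a sketch under two assumptions: rows hold the actual class and columns the predicted class (the convention scikit-learn's `confusion_matrix` uses; the article doesn't state one), and `confusion_frame` is a hypothetical helper name.

```python
import pandas as pd


def confusion_frame(y_true, y_pred, labels):
    """Hypothetical helper: tally (actual, predicted) pairs into a square DataFrame.

    Rows are actual classes, columns are predicted classes.
    """
    table = pd.crosstab(pd.Series(y_true, name='Actual'),
                        pd.Series(y_pred, name='Predicted'))
    # crosstab omits classes that never occur; reindexing keeps the matrix square
    return table.reindex(index=labels, columns=labels, fill_value=0)


y_true = ['0', '0', '1', '1', '1']
y_pred = ['0', '1', '1', '1', '0']
cm = confusion_frame(y_true, y_pred, labels=['0', '1'])
print(cm.values.tolist())  # [[1, 1], [1, 2]]
```

The `reindex` step matters for small test sets: if a class is never predicted, `crosstab` alone would return a non-square table.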

This is beginning to look more like a confusion matrix, and this is where Seaborn’s heatmap does the heavy lifting. Simply pass the data to sns.heatmap() and turn annotations on with annot=True.

# plot confusion matrix
sns.heatmap(cm, annot=True)
# Returns: <AxesSubplot:>
[Figure: Seaborn heatmap]

With this one line of code, the classes are labelled along the left and bottom, each square displays its value, and a color bar even serves as an additional legend. However, notice that the larger counts are shown in scientific notation (Seaborn’s default annotation format, '.2g', keeps only two significant digits), and the matrix isn’t square. Luckily, a few keyword arguments fix both issues and add some other aesthetic adjustments along the way.
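To see why 454 turns into 4.5e+02, here is a small sketch that formats the same four counts with three format specs. `annotation_text` is a hypothetical helper that mimics applying a `fmt` spec to each cell value; it is not part of Seaborn’s API.

```python
def annotation_text(values, fmt='.2g'):
    """Hypothetical helper: format each value with a format spec, like Seaborn's `fmt`."""
    return [('{:' + fmt + '}').format(v) for v in values]


cells = [454, 17, 8, 521]
print(annotation_text(cells))           # ['4.5e+02', '17', '8', '5.2e+02']
print(annotation_text(cells, fmt=''))   # ['454', '17', '8', '521']
print(annotation_text(cells, fmt='d'))  # ['454', '17', '8', '521']
```

Both '' and 'd' print the full counts, but 'd' raises a ValueError for floats, so if you later plot normalized (fractional) values, a spec such as '.2f' is a better fit.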

# plot confusion matrix with various kwargs
labels = ['0', '1']
sns.heatmap(cm, annot=True, linecolor='grey', linewidths=1, fmt='', xticklabels=labels, yticklabels=labels, square=True)
# Returns: <AxesSubplot:>
[Figure: Seaborn heatmap with keyword arguments]
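The plot labels each class, but not which axis holds the actual class and which the prediction. A minimal sketch, assuming the rows of `cm` are actual classes and the columns are predictions (scikit-learn’s convention; the article doesn’t state one): draw on an explicit Axes and set the axis titles afterwards. 'True label' and 'Predicted label' are illustrative names.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

labels = ['0', '1']
cm = pd.DataFrame([[454, 17], [8, 521]], index=labels, columns=labels)

# Draw on an explicit Axes so it can be labelled afterwards
fig, ax = plt.subplots()
sns.heatmap(cm, ax=ax, annot=True, fmt='', linecolor='grey', linewidths=1, square=True)

# Assumption: rows hold the actual class, columns the predicted class
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
```

`sns.heatmap` also returns the Axes it drew on, so the same labels can be set on its return value when no `ax` is passed.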

This is looking more professional, and it took only a few lines of code. While it might be tempting to copy those lines for every confusion matrix you make, in practice it’s worthwhile to write a function that does the repetition for you.

The function should require only the confusion matrix data, leaving every other option to keyword arguments.

# create confusion matrix function
def plot_confusion_matrix(cm, labels=['0','1'], cmap=None, title=None, show=True, path=None):

Here’s what each parameter does:

cm: confusion matrix data; the only required argument

labels: labels for the confusion matrix classes
default: '0' and '1', as seen in the image above

cmap: colormap (color scheme) of the heatmap
default: None, which falls back to Seaborn’s default colormap

title: title for the confusion matrix plot
default: None

show: whether to display the image; output can be suppressed by setting it to 0, False, or None
default: True

path: file path for saving the image; the image is saved only if a path is given
default: None

With the parameters defined, the function can be completed. The first few lines mirror the heatmap code above.

# create confusion matrix
cm = pd.DataFrame(cm, index=labels, columns=labels)

# plot confusion matrix
sns.heatmap(cm, cmap=cmap, linecolor='grey', linewidths=1, annot=True, fmt='', xticklabels=labels, yticklabels=labels, square=True)

if title:
    plt.title(title)

if path:
    plt.savefig(path, transparent=True)

if show:
    plt.show()

Here’s what this looks like all together:

# create confusion matrix function
def plot_confusion_matrix(cm, labels=['0','1'], cmap=None, title=None, show=True, path=None):
    """
    Takes raw confusion matrix data and plots it as an annotated heatmap.
    """
    # create confusion matrix
    cm = pd.DataFrame(cm, index=labels, columns=labels)

    # plot confusion matrix
    sns.heatmap(cm, cmap=cmap, linecolor='grey', linewidths=1, annot=True, fmt='', xticklabels=labels, yticklabels=labels, square=True)

    # add a title only if one was given
    if title:
        plt.title(title)

    # save before showing, since the figure may be cleared once it is shown
    if path:
        plt.savefig(path, transparent=True)

    if show:
        plt.show()
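One design choice worth revisiting: the function draws on whatever figure pyplot currently has open, so calling it several times with show=False can stack heatmaps (and colorbars) on the same Axes. The sketch below is a hypothetical variant, `plot_confusion_matrix_ax`, that creates a fresh figure unless an Axes is passed in, and returns that Axes. The `ax` parameter and the tuple default for `labels` are additions, not part of `plot_confusion_matrix` above.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_confusion_matrix_ax(cm, labels=('0', '1'), cmap=None, title=None,
                             show=True, path=None, ax=None):
    """Hypothetical variant that draws on a dedicated Axes and returns it."""
    labels = list(labels)
    cm = pd.DataFrame(cm, index=labels, columns=labels)

    # A new figure per call (unless an Axes is supplied) keeps plots separate
    if ax is None:
        _, ax = plt.subplots()

    # The DataFrame's index and columns already supply the tick labels
    sns.heatmap(cm, ax=ax, cmap=cmap, linecolor='grey', linewidths=1,
                annot=True, fmt='', square=True)

    if title:
        ax.set_title(title)
    # Save before showing, for the same reason as in the original function
    if path:
        ax.figure.savefig(path, transparent=True)
    if show:
        plt.show()
    return ax
```

In a loop that only saves files (show=False), calling `plt.close(ax.figure)` after each call frees the figure’s memory. The tuple default also avoids Python’s shared-mutable-default pitfall, although the original list is never modified.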

Here’s an example using some of the keyword arguments:

plot_confusion_matrix([(1, 2), (3, 4)], labels=['Class 0', 'Class 1'], title='Example Title', path='confused_matrix')
[Figure: Confusion matrix produced by the custom function]

This is where the function pays off: a single call now produces the visualization, with custom labels, title, save path, and more.

For many applications, this is good enough. However, if you’ve discovered a passion for pushing the limits of visualization, or need further tweaks to create the perfect visual for a presentation, you can always go further. Matplotlib ships with dozens of colormaps, and I encourage you to explore them, perhaps even with a custom function.
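Taking up that suggestion, here is one possible custom function: a hypothetical `compare_colormaps` helper that draws the same matrix side by side under several Matplotlib colormaps. The three default colormap names are only examples; `plt.colormaps()` returns every registered name if you want to try others.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def compare_colormaps(cm, labels=('0', '1'), cmaps=('Blues', 'viridis', 'coolwarm')):
    """Hypothetical helper: draw one confusion matrix once per colormap."""
    labels = list(labels)
    df = pd.DataFrame(cm, index=labels, columns=labels)

    # squeeze=False always returns a 2-D array of Axes, even for one colormap
    fig, axes = plt.subplots(1, len(cmaps), figsize=(4 * len(cmaps), 4), squeeze=False)
    for ax, cmap in zip(axes[0], cmaps):
        sns.heatmap(df, ax=ax, cmap=cmap, annot=True, fmt='',
                    linecolor='grey', linewidths=1, square=True)
        ax.set_title(cmap)  # label each panel with its colormap name

    fig.tight_layout()
    return fig


available = plt.colormaps()  # names of all registered colormaps
fig = compare_colormaps([[454, 17], [8, 521]])
```

Sequential colormaps such as 'Blues' or 'viridis' tend to suit count data, since darker or brighter simply means more; diverging ones like 'coolwarm' imply a meaningful midpoint.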
