Displaying a confusion matrix in TensorFlow

I recently started dabbling in TensorFlow by training an image classifier. My classifier has a few classes, and instead of judging it by its accuracy alone, I decided to look at the confusion matrix to see which classes it recognized well and which it didn't.

Multiclass classifier

In TensorFlow, there's a function that computes it for you: tf.math.confusion_matrix. All you need to do is pass the labels and the predictions, and it computes the matrix. See https://www.tensorflow.org/api_docs/python/tf/math/confusion_matrix.
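Here's a minimal sketch of how to call it, with made-up labels and predictions:

import tensorflow as tf

# Ground-truth classes and predicted classes, as integer class indices
labels = [0, 1, 2, 2]
predictions = [0, 1, 2, 0]

# Rows are labels, columns are predictions
tf.math.confusion_matrix(labels, predictions)
# <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
# array([[1, 0, 0],
#        [0, 1, 0],
#        [1, 0, 1]], dtype=int32)>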

Computing predictions

Next, I have to compute the predictions using my validation_generator. My validation_generator generates batches of images (x_batch) and labels (y_batch). The model generates an array of predictions of shape (batch_size, num_classes). However confusion_matrix expects predictions of shape (batch_size,), with each element specifying the class as an integer from 0 to num_classes - 1. So I use np.argmax to do the conversion. axis=1 means that for each row, it looks at every column and picks the index of the largest value. So if the row is [.1, .15, .15, .5, .05, .05], argmax will return 3.
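For instance, here's what that looks like on a hypothetical batch of two predictions:

import numpy as np

probabilities = np.array([[.1, .15, .15, .5, .05, .05],
                          [.7, .1, .1, .05, .025, .025]])
# For each row, return the index of the largest value
np.argmax(probabilities, axis=1)
# array([3, 0])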

Looping through the batches

Initially I used a pretty verbose loop. Here's the more concise way I ended up using: http://blog.wafrat.com/survey-of-keras-image-related-generators/#looping-through-the-data-once.
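In short, Keras generators behave like sequences: they support len() and indexing. Assuming a generator built with flow_from_directory, the pattern from that post boils down to:

for i in range(len(validation_generator)):
    x_batch, y_batch = validation_generator[i]
    # process the batch here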

Collecting results

Finally to collect all batched predictions and labels, I had initially written something like:

# Initialize an empty Python list
all_labels = []
for ...:
    ...
    # Manipulate the batch's labels np array
    labels = np.argmax(y_batch, axis=1)
    # Convert it to a Python list and concatenate with all_labels
    all_labels = all_labels + labels.tolist()

It seems a bit inefficient to convert to lists on every iteration, especially since confusion_matrix supports both types. To concatenate numpy arrays, I've found two functions: hstack (doc) and concatenate (doc). Here's the labels collection code that uses exclusively numpy arrays:

# Initialize an empty numpy array
all_labels = np.array([])
for ...:
    ...
    # Manipulate the batch's labels np array
    labels = np.argmax(y_batch, axis=1)
    # Concatenate the numpy arrays directly, no list conversion needed
    all_labels = np.concatenate([all_labels, labels])
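For 1-D arrays like these, the two functions behave the same; a quick sanity check:

import numpy as np

a = np.array([0., 1.])
b = np.array([2., 3.])
np.hstack([a, b])       # array([0., 1., 2., 3.])
np.concatenate([a, b])  # array([0., 1., 2., 3.])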

Putting it all together

Putting it all together, I got:

import numpy as np
import tensorflow as tf

def get_confusion_matrix(model, validation_generator):
    all_predictions = np.array([])
    all_labels = np.array([])
    for i in range(len(validation_generator)):
        x_batch, y_batch = validation_generator[i]
        # Predicted class index for each image in the batch
        predictions = model.predict(x_batch)
        predictions = np.argmax(predictions, axis=1)
        all_predictions = np.concatenate([all_predictions, predictions])
        # True class index for each image in the batch
        labels = np.argmax(y_batch, axis=1)
        all_labels = np.concatenate([all_labels, labels])

    # Note the argument order: labels first, then predictions
    return tf.math.confusion_matrix(all_labels, all_predictions)

get_confusion_matrix(model, validation_generator)
<tf.Tensor: shape=(6, 6), dtype=int32, numpy=
array([[77,  0,  7,  6,  5,  5],
       [ 0,  8,  0,  0,  0,  1],
       [ 1,  0,  2,  0,  2,  0],
       [ 0,  0,  0,  0,  2,  0],
       [ 5,  0,  0,  0,  0,  0],
       [ 1,  0,  0,  0,  0,  2]], dtype=int32)>

validation_generator.class_indices
{'conversation': 0,
 'credits': 1,
 'gun_fight': 2,
 'sword_fight': 3,
 'race': 4,
 'space': 5}

This is a very basic confusion matrix but it does the job. Judging from this matrix, I can see that my data is highly imbalanced: most of my images are of conversations. As a result, the model mostly predicts everything as a conversation. It is also great at recognizing credits scenes, but not much else.

The documentation also links to a few tutorials that use the function and build on top of it to display a much more elaborate matrix. Theirs is an actual image, where the axes denote which side is the label and which is the prediction, and each row and column is captioned with the class it represents. Finally, on top of the numerical values, they use a colormap to highlight large values. See https://www.tensorflow.org/tutorials/audio/simple_audio#display_a_confusion_matrix for example. To build it, they initialize a plot with the axis names, then use a library called seaborn: https://seaborn.pydata.org/generated/seaborn.heatmap.html. Neat. I might use that later.
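Here's a minimal sketch of that approach, reusing the get_confusion_matrix function from above (the figure size and annotation format are my own assumptions, not from the tutorial):

import matplotlib.pyplot as plt
import seaborn as sns

cm = get_confusion_matrix(model, validation_generator)
class_names = list(validation_generator.class_indices.keys())

plt.figure(figsize=(8, 6))
# annot=True writes the count in each cell, fmt='g' avoids scientific notation
sns.heatmap(cm.numpy(), annot=True, fmt='g',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Prediction')
plt.ylabel('Label')
plt.show()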

Elsewhere, I've seen people use scikit-learn's ConfusionMatrixDisplay. See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ConfusionMatrixDisplay.html#sklearn.metrics.ConfusionMatrixDisplay. It produces an output similar to seaborn's heatmap.
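A rough equivalent with scikit-learn, again assuming the cm and class_names variables from the previous snippet:

from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

disp = ConfusionMatrixDisplay(confusion_matrix=cm.numpy(),
                              display_labels=class_names)
disp.plot()
plt.show()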

Binary classifier

More recently I had to display the confusion matrix for a binary classifier. I was able to reuse most of the code with a few tweaks. The main difference is that the prediction is not an array containing a number for each class. Instead its shape is (batch_size, 1), i.e. the model returns a single number from 0 to 1 per image. The batch labels are also an array of 0's and 1's. So instead of using argmax, I have to convert each prediction to 0 or 1 depending on a chosen threshold. I arbitrarily picked .5 to start.

My initial attempt looked like this:

def to_one_or_zero(value):
    return 1 if value else 0

def get_confusion_matrix(model, validation_generator):
    ...
    predictions = list(map(to_one_or_zero, predictions > .5))
    ...

What the line does is:

  • convert the array of numbers from 0 to 1 into an array of booleans, depending on whether the number is higher than the threshold or not.
  • for each element, apply to_one_or_zero to output 0 or 1 depending on whether the element is False or True.
  • convert to a list.

It turns out, multiplying a Boolean by 1 conveniently turns it into 0 or 1, and multiplying a numpy array by a scalar does an element-wise multiplication. So np.array([True, False, False]) * 1 returns array([1, 0, 0]).

So the computation of predictions can be simplified to (predictions > .5) * 1. The final version of the code now looks like this:

import numpy as np
import tensorflow as tf

def get_confusion_matrix(model, validation_generator):
    is_binary = len(validation_generator.class_indices) == 2
    all_predictions = np.array([])
    all_labels = np.array([])
    for i in range(len(validation_generator)):
        x_batch, y_batch = validation_generator[i]
        predictions = model.predict(x_batch)
        if not is_binary:
            # Multiclass: pick the index of the most probable class
            predictions = np.argmax(predictions, axis=1)
        else:
            # Binary: threshold at .5, then flatten (batch_size, 1) to (batch_size,)
            predictions = ((predictions > .5) * 1).flatten()
        all_predictions = np.concatenate([all_predictions, predictions])

        if not is_binary:
            labels = np.argmax(y_batch, axis=1)
        else:
            labels = y_batch
        all_labels = np.concatenate([all_labels, labels])

    return tf.math.confusion_matrix(all_labels, all_predictions)

It produces a matrix that looks like this:

get_confusion_matrix(model, validation_generator)
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[252,  18],
       [ 58,  33]], dtype=int32)>