How to plot a confusion matrix using scikit-learn and matplotlib in Python.

by vigneshchennai74 · Updated: Feb 27, 2023


A confusion matrix is a table that outlines the performance of a classification algorithm by comparing the predicted class labels to the true class labels. It provides a quick and intuitive way to assess the accuracy of a classification model and identify any patterns in the errors made by the model. 


The scikit-learn library provides a function called confusion_matrix that can generate a confusion matrix from a set of true and predicted class labels. The plot_confusion_matrix function defined in the script takes this confusion matrix and plots it using the matplotlib library, making it easier to visualize and interpret. 
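As a minimal sketch of what `confusion_matrix` returns, the snippet below builds the matrix for a small set of labels. The label values are illustrative only, not real model output. Rows correspond to true classes and columns to predicted classes.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Illustrative labels (assumed for this sketch, not real model output).
y_true = np.array([0, 0, 1, 1, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 1, 1, 1, 0, 1])

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred)
print(cm)

# For a binary problem, ravel() unpacks the four cells in row order.
tn, fp, fn, tp = cm.ravel()
print(f"TN={tn} FP={fp} FN={fn} TP={tp}")
```

The diagonal cells count correct predictions; every off-diagonal cell counts one kind of error.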


This script can be helpful in many machine learning tasks, particularly in cases where we need to evaluate the performance of a classification model. It can also be useful in explaining the model's performance to non-technical stakeholders, as the visual representation of the confusion matrix is easy to understand. By plotting the confusion matrix, we can easily identify which classes the model is misclassifying and adjust the model or data accordingly. 
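One possible way to read off which classes are being confused is to compute per-class recall and locate the largest off-diagonal cell. The three-class matrix below is hypothetical and exists only to illustrate the arithmetic.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows: true, columns: predicted).
cm = np.array([[8, 1, 1],
               [2, 6, 2],
               [0, 3, 7]])

# Per-class recall: correct predictions divided by the row total.
recall = cm.diagonal() / cm.sum(axis=1)

# Zero the diagonal so that only errors remain, then find the largest one.
errors = cm.copy()
np.fill_diagonal(errors, 0)
true_cls, pred_cls = np.unravel_index(errors.argmax(), errors.shape)

print("Recall per class:", recall)
print(f"Most frequent error: true class {true_cls} predicted as {pred_cls}")
```

A low recall for one class, together with a large cell in its row, points to the specific pair of classes worth investigating.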


Here is an example of how to plot a confusion matrix using scikit-learn and matplotlib in Python. 

Code

This solution uses the scikit-learn (sklearn) library.

import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# Source: http://scikit-learn.org/stable/auto_examples/model_selection/
#         plot_confusion_matrix.html#confusion-matrix


y_test = np.array([1, 1, 0, 1])
y_train = np.array([0, 0, 1, 1])

y_test_pred = np.array([1, 1, 0, 1])  # from classifier_logistic.predict(x_test)
y_train_pred = np.array([0, 1, 0, 1]) # from classifier_logistic.predict(x_train)

y_true = np.concatenate((y_train, y_test))
y_pred = np.concatenate((y_train_pred, y_test_pred))

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Call tight_layout last so the axis labels are included in the layout.
    plt.tight_layout()

cm = confusion_matrix(y_true, y_pred)
np.set_printoptions(precision=2)

plt.figure()
plot_confusion_matrix(cm, classes=[0, 1],
                      title='Confusion matrix')
plt.show()  # display the figure when the file is run as a script

  1. Copy the code above and paste it into a Python file in your IDE.
  2. Run the file to get the output.
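The `normalize=True` branch of `plot_confusion_matrix` divides each row by its total, so every cell becomes the fraction of that true class assigned to each predicted class. The sketch below repeats that arithmetic on illustrative labels and compares it with the `normalize="true"` option of `confusion_matrix`, which should give the same result.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Illustrative labels (assumed for this sketch).
y_true = np.array([0, 0, 1, 1, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 1, 1, 1, 0, 1])

cm = confusion_matrix(y_true, y_pred)

# Same arithmetic as the normalize branch: divide each row by its total.
cm_manual = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

# scikit-learn can do the row normalization itself with normalize="true".
cm_builtin = confusion_matrix(y_true, y_pred, normalize="true")

print(np.round(cm_manual, 2))
```

Note that a class with no true samples has a row total of zero, so the manual division would produce NaN values for that row.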


I hope you found this useful. I have added links to the dependent libraries and version information in the following sections.


I found this code snippet by searching for "sklearn: Plot confusion matrix combined across training test sets" in kandi. You can try any such use case!


Note


The script must end with a call to plt.show(); without it, the confusion matrix is not displayed when the file is run as a script.

Environment Tested

I tested this solution in the following versions. Be mindful of changes when working with other versions.

  1. The solution was created and tested in VS Code 1.75.1.
  2. The solution was created with Python 3.7.15.
  3. The solution was tested with scikit-learn 1.0.2.


Using this solution, we can combine the training and test predictions into a single confusion matrix with the scikit-learn library in Python in a few simple steps. It also gives an easy, hassle-free way to build a working version of the code that produces this combined confusion matrix.

Dependent Library

scikit-learn by scikit-learn

Python · 54584 stars · Version: 1.2.2
License: Permissive (BSD-3-Clause)

scikit-learn: machine learning in Python

If you do not have scikit-learn, which is required to run this code, you can install it by copying the pip install command from the scikit-learn page in kandi.

You can search for any dependent library on kandi, such as scikit-learn.

Support

  1. For any support on kandi solution kits, please use the chat.
  2. For further learning resources, visit the Open Weaver Community learning page.

