#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:Draw_ConfusionMatrix.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/08/10
"""
import numpy as np
from matplotlib import pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # use SimHei so Chinese labels render correctly
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with this font


def draw_confusionMatrix(cm, classes, title, save_path, cmap=plt.cm.Blues,
                         normalize=False, thresh_ratio=0.8):
    """Render a confusion matrix as an annotated heatmap and save it to disk.

    Parameters
    ----------
    cm : array_like, 2-D
        Confusion matrix. Rows are drawn on the "True label" axis and columns
        on the "Predicted label" axis. Integer or float values are accepted.
    classes : sequence of str
        Tick labels used for both axes, one per row/column of ``cm``.
    title : str
        Plot title.
    save_path : str or path-like
        Destination passed to ``Figure.savefig``.
    cmap : matplotlib colormap, optional
        Colormap for the heatmap. Defaults to ``plt.cm.Blues``.
    normalize : bool, optional
        If True, divide each row by its sum so cells show per-true-class
        fractions. Rows that sum to zero stay all-zero. Defaults to False.
    thresh_ratio : float, optional
        Cells whose value exceeds ``thresh_ratio * cm.max()`` are annotated in
        white for contrast against the dark cells; the rest in black.
        Defaults to 0.8.

    Raises
    ------
    ValueError
        If ``cm`` is not two-dimensional.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2:
        raise ValueError("cm must be a 2-D array, got ndim=%d" % cm.ndim)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Leave rows with no samples at zero instead of producing NaN.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    # The 'd' format only accepts integers; float matrices need fixed-point.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # cm.max() raises on an empty array, so fall back to 0 in that case.
    thresh = cm.max() * thresh_ratio if cm.size else 0

    fig_cm, ax = plt.subplots(figsize=(8, 8), dpi=120)
    try:
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        ax.figure.colorbar(im, ax=ax)
        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=classes, yticklabels=classes,
               title=title,
               ylabel='True label',
               xlabel='Predicted label')
        # Explicitly pin the y-limits so the first and last rows are not cropped
        # (works around a clipping issue seen in some matplotlib versions).
        ax.set_ylim(cm.shape[0] - 0.5, -0.5)

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

        # Annotate every cell with its value; ndenumerate yields (row, col).
        for (i, j), value in np.ndenumerate(cm):
            ax.text(j, i, format(value, fmt),
                    ha="center", va="center",
                    color="white" if value > thresh else "black")

        fig_cm.tight_layout()
        fig_cm.savefig(save_path)
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # drawing or saving fails, so figures do not accumulate in memory.
        plt.close(fig_cm)


if __name__ == '__main__':
    pass