#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:Draw_ConfusionMatrix.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/08/10
"""
import numpy as np
from matplotlib import pyplot as plt

# Global Matplotlib settings for every figure in this module:
#   - 'font.sans-serif': use the SimHei font so Chinese labels/titles render.
#   - 'axes.unicode_minus': draw minus signs as an ASCII hyphen so negative
#     numbers display correctly with this font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})


def draw_confusionMatrix(cm, classes, title, save_path, cmap=plt.cm.Blues,
                         normalize=False):
    """Render a confusion matrix as an annotated heatmap and save it to disk.

    Rows are labelled 'True label' and columns 'Predicted label'. Every cell
    is annotated with its value; cells above 80 % of the matrix maximum get
    white text, all others black.

    Args:
        cm: 2-D array-like confusion matrix; ``cm[i, j]`` is drawn at row
            ``i`` / column ``j``. Converted with ``np.asarray``.
        classes: Tick labels used for both axes (presumably one per row and
            one per column of ``cm``).
        title: Figure title.
        save_path: Destination passed to ``Figure.savefig``.
        cmap: Matplotlib colormap for the heatmap.
        normalize: If True, divide each row by its sum before plotting so
            cells show per-true-class fractions; rows summing to 0 are drawn
            as 0 instead of NaN. Defaults to False (values plotted as given).

    Raises:
        ValueError: If ``cm`` is empty (``cm.max()`` has nothing to reduce).

    The figure is always closed, even when drawing or saving fails, so
    repeated calls do not accumulate open figures.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guarded division: all-zero rows stay 0 rather than 0/0 -> NaN.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    # Integer counts print as whole numbers; anything else (normalized
    # fractions, float input) uses two decimals. A fixed 'd' format would
    # raise ValueError on float cells.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    fig_cm, ax = plt.subplots(figsize=(8, 8), dpi=120)
    try:
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        fig_cm.colorbar(im, ax=ax)
        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=classes, yticklabels=classes,
               title=title,
               ylabel='True label',
               xlabel='Predicted label')
        # Pin the row limits to the cell edges (first row on top). This is
        # presumably a workaround for Matplotlib versions that cropped the
        # outer heatmap rows by half a cell.
        ax.set_ylim(cm.shape[0] - 0.5, -0.5)

        # Rotate the x tick labels and anchor them so long names don't overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

        thresh = cm.max() * 0.8
        for i, j in np.ndindex(cm.shape):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

        fig_cm.tight_layout()
        fig_cm.savefig(save_path)
    finally:
        # Close this specific figure (not merely the "current" one).
        plt.close(fig_cm)


if __name__ == '__main__':
    # This module only provides plotting helpers for import; running it
    # directly performs no work.
    ...