sleep_apnea_hybrid/exam/042/utils/Draw_ConfusionMatrix.py

47 lines
1.4 KiB
Python

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:Draw_ConfusionMatrix.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/08/10
"""
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams.update({
    # Use SimHei as the sans-serif font so Chinese label text displays correctly.
    'font.sans-serif': ['SimHei'],
    # Draw minus signs as an ASCII hyphen so negative numbers display correctly.
    'axes.unicode_minus': False,
})
def _prepare_matrix(cm, normalize):
    """Return ``(matrix, fmt)``: *cm* as an ndarray plus the cell-label format.

    When *normalize* is true each row is divided by its sum; rows summing to
    zero are left as zeros instead of producing NaN. The format is ``'d'``
    for integer matrices and ``'.2f'`` otherwise, because ``format(x, 'd')``
    raises ``ValueError`` for floating-point values.
    """
    matrix = np.asarray(cm)
    if normalize:
        counts = matrix.astype(float)
        row_sums = counts.sum(axis=1, keepdims=True)
        matrix = np.divide(counts, row_sums,
                           out=np.zeros_like(counts),
                           where=row_sums != 0)
    fmt = 'd' if np.issubdtype(matrix.dtype, np.integer) else '.2f'
    return matrix, fmt


def draw_confusionMatrix(cm, classes, title, save_path, cmap=plt.cm.Blues,
                         *, normalize=False):
    """Draw an annotated confusion-matrix heat map and save it to *save_path*.

    Args:
        cm: 2-D array-like; rows are plotted as "True label" and columns as
            "Predicted label". Integer or floating-point values are accepted.
        classes: tick labels, used for both the x and y axes.
        title: figure title.
        save_path: destination passed to ``Figure.savefig``.
        cmap: colormap used by ``imshow``.
        normalize: if True, divide each row by its sum before plotting
            (all-zero rows stay zero); cells are then labelled with two
            decimals.

    The figure is always closed, even if drawing or saving raises.
    """
    matrix, fmt = _prepare_matrix(cm, normalize)
    n_rows, n_cols = matrix.shape
    fig_cm, ax = plt.subplots(figsize=(8, 8), dpi=120)
    try:
        im = ax.imshow(matrix, interpolation='nearest', cmap=cmap)
        fig_cm.colorbar(im, ax=ax)
        ax.set(xticks=np.arange(n_cols),
               yticks=np.arange(n_rows),
               xticklabels=classes, yticklabels=classes,
               title=title,
               ylabel='True label',
               xlabel='Predicted label')
        # Pin the y-range to the outer cell edges with row 0 at the top.
        # Presumably a workaround for matplotlib versions that clipped the
        # first/last rows; sized from the matrix so it matches the ticks.
        ax.set_ylim(n_rows - 0.5, -0.5)
        # Rotate the x tick labels so long class names do not overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")
        # Cells above 80% of the maximum get white text for contrast.
        thresh = matrix.max() * 0.8
        for i in range(n_rows):
            for j in range(n_cols):
                value = matrix[i, j]
                ax.text(j, i, format(value, fmt),
                        ha="center", va="center",
                        color="white" if value > thresh else "black")
        fig_cm.tight_layout()
        fig_cm.savefig(save_path)
    finally:
        # Close this specific figure (not merely the "current" one) so
        # repeated calls do not accumulate open figures.
        plt.close(fig_cm)
if __name__ == '__main__':
    # Utility module: there is nothing to run when executed directly.
    ...