47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:Draw_ConfusionMatrix.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/08/10
"""
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
# Global matplotlib configuration applied at import time (affects every plot
# drawn in this process, not only the ones made by this module).
plt.rcParams.update({
    # Use the SimHei font so Chinese labels render correctly.
    'font.sans-serif': ['SimHei'],
    # Draw minus signs with an ASCII hyphen so they render correctly
    # (presumably because the chosen font lacks the Unicode minus glyph).
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
def draw_confusionMatrix(cm, classes, title, save_path, cmap=plt.cm.Blues,
                         *, normalize=False):
    """Render a confusion matrix as an annotated heat map and save it to disk.

    Row ``i`` of ``cm`` is drawn on the y axis (labelled "True label") and
    column ``j`` on the x axis (labelled "Predicted label"); every cell is
    annotated with its value.

    Args:
        cm: 2-D array-like matrix; converted with ``np.asarray`` so nested
            lists are accepted as well as NumPy arrays.
        classes: Tick labels used for both axes, in row/column order.
        title: Title of the plot.
        save_path: Destination passed to ``Figure.savefig``.
        cmap: Matplotlib colormap used for the heat map.
        normalize: If True, divide each row by its sum before plotting so the
            cells show per-row rates; rows that sum to zero are drawn as 0.

    Integer matrices are annotated as integers, any other dtype with two
    decimals. The figure is always closed afterwards, even if saving fails.
    """
    cm = np.asarray(cm)
    if normalize:
        cm = cm.astype(float)
        row_sums = cm.sum(axis=1, keepdims=True)
        # Rows with no samples would divide by zero; leave those cells at 0.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    n_rows, n_cols = cm.shape
    # Formatting a float with 'd' raises ValueError, so pick by dtype.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    fig_cm, ax = plt.subplots(figsize=(8, 8), dpi=120)
    try:
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        fig_cm.colorbar(im, ax=ax)
        ax.set(xticks=np.arange(n_cols),
               yticks=np.arange(n_rows),
               xticklabels=classes, yticklabels=classes,
               title=title,
               ylabel='True label',
               xlabel='Predicted label')
        # Span exactly the cell edges with row 0 at the top.
        # NOTE(review): likely a workaround for matplotlib 3.1.1 clipping the
        # first/last heat-map rows.
        ax.set_ylim(n_rows - 0.5, -0.5)

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

        # Use white text on the darkest cells so annotations stay readable.
        thresh = cm.max() * 0.8
        for i in range(n_rows):
            for j in range(n_cols):
                ax.text(j, i, format(cm[i, j], fmt),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")

        fig_cm.tight_layout()
        fig_cm.savefig(save_path)
    finally:
        # Close this specific figure (not whatever is "current") so repeated
        # calls do not accumulate open figures, even when savefig fails.
        plt.close(fig_cm)
|
|
if __name__ == '__main__':
    # Intentionally empty: this module only provides draw_confusionMatrix
    # for import; running it directly does nothing.
    ...
|