DataPrepare/draw_tools/draw_statics.py

175 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import seaborn as sns
import numpy as np
def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Draw an amplitude x duration statistics heatmap with row, column and grand totals.

    Every heatmap cell is annotated as ``[segment count]duration`` over
    ``(percent%)`` and coloured by that percentage with the ``YlGnBu`` colormap
    and ``PowerNorm(gamma=0.5, vmin=0, vmax=100)``.  Per-row totals are written
    to the right of the heatmap, per-column totals below the last row and the
    grand total in the bottom-right corner.  The x axis is moved to the top.

    Parameters
    ----------
    ax : Axes
        Axes to draw into.
    confusion_matrix : DataFrame-like
        Duration per (amplitude bin, duration bin) cell, read positionally via
        ``.iloc[i, j]``; must also have a ``'总计'`` (total) column, which is
        used as the per-row totals.
    segment_count_matrix : array-like
        Segment count per cell, indexed numpy-style as ``[i, j]``.
    confusion_matrix_percent : DataFrame-like
        Percent value per cell (``.iloc[i, j]``) plus a ``'总计'`` column
        holding the per-row total percentages.
    valid_signal_length : number
        Total valid duration, shown (as an int) in the grand-total cell.
    total_duration : number
        Denominator for the column and grand-total percentages. A zero value
        yields 0% instead of raising.
    time_labels, amp_labels : sequence of str
        X (duration bin) and Y (amplitude bin) tick labels; their lengths set
        how many columns / rows of the matrices are drawn.
    signal_type : str, optional
        Unused; kept for interface compatibility.
    """
    n_amp, n_time = len(amp_labels), len(time_labels)

    # Per-cell annotation text and the float matrix that drives the colours.
    text_matrix = np.empty((n_amp, n_time), dtype=object)
    percent_matrix = np.zeros((n_amp, n_time))
    for i in range(n_amp):
        for j in range(n_time):
            duration = confusion_matrix.iloc[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(duration)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent

    sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='',
                              xticklabels=time_labels, yticklabels=amp_labels,
                              cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
                              norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
                              cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15})

    # Factor converting a duration into a percentage of total_duration.
    # Guarded so an empty recording renders 0% instead of crashing / printing inf.
    pct_scale = 100.0 / total_duration if total_duration else 0.0

    # Row totals, written in the column just right of the heatmap.
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(n_time + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(n_time + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
                ha='center', va='center')

    # Column totals, written in the row just below the heatmap.
    # NOTE(review): confusion_matrix.sum(axis=0) also sums the '总计' column;
    # zip() drops it only if segment_count_matrix has exactly n_time columns — confirm.
    col_sums = segment_count_matrix.sum(axis=0)
    col_percents = confusion_matrix.sum(axis=0) * pct_scale
    ax.text(-1, n_amp + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, n_amp + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
                ha='center', va='center', rotation=0)

    # Move the x axis (ticks and label) to the top.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')
    # Keep tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    cbar = sns_heatmap.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Key explaining the in-cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Grand total: total segment count, valid duration (int, like every other
    # duration cell) and its share of total_duration.
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length * pct_scale
    ax.text(n_time + 0.5, n_amp + 0.5,
            f"[{int(total_segments)}]{int(valid_signal_length)}\n({total_percent:.2f}%)",
            ha='center', va='center')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    """Plot one raw signal with movement / low-amplitude shading, AML lines and a MAV trace.

    Parameters
    ----------
    ax : Axes
        Axes to draw into.
    signal_name : str
        Used in the title as ``"<signal_name> Signal"``.
    original_times : sequence
        X coordinates of ``origin_signal``.
    origin_signal : sequence
        Raw samples, drawn as a thin black line.
    no_movement_signal : object
        Unused; kept for interface compatibility.
    mav_values : sequence
        MAV trace (legend label ``'2sMAV'``), plotted against
        ``np.linspace(0, len(mav_values), len(mav_values))``; presumably one
        value per second — TODO confirm.
    movement_position_list, low_amp_position_list : iterable of (start, end)
        Spans shaded with ``axvspan`` in x-data coordinates (red / blue).
    signal_second_length : number
        Right end of the AML horizontal reference lines (which start at 0).
    aml_list : sequence of number, optional
        AML thresholds drawn as dashed horizontal lines; the first three use
        red/orange/green, any further ones reuse green.

    The y range is fixed to (-2000, 2000).
    """
    # One style per shaded region, shared by the spans and their legend
    # swatches so the legend always matches what is drawn.
    movement_style = dict(color='red', alpha=0.3)
    low_amp_style = dict(color='blue', alpha=0.2)

    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    n_samples = len(original_times)
    # NOTE(review): start/end are bounded by the *sample count* but drawn as
    # x-data coordinates (the same axis as original_times). If positions are in
    # seconds this check never filters anything — confirm units with the caller.
    for start, end in movement_position_list:
        if start < n_samples and end < n_samples:
            ax.axvspan(start, end, **movement_style)
    for start, end in low_amp_position_list:
        if start < n_samples and end < n_samples:
            ax.axvspan(start, end, **low_amp_style)

    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed',
                      linewidth=2, alpha=0.5, label=f'{aml} aml')

    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values,
            color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    ax.set_ylim((-2000, 2000))

    # Legend: patches for the shaded regions first, then the labelled lines.
    movement_patch = mpatches.Patch(label='Movement Area', **movement_style)
    low_amp_patch = mpatches.Patch(label='Low Amplitude Area', **low_amp_style)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=[movement_patch, low_amp_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right',
              bbox_to_anchor=(1, 1.4),
              framealpha=0.5)

    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')
    ax.grid(True, linestyle='--', alpha=0.7)
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):
    """Render the BCG / respiration metrics overview figure.

    Layout (2 x 2 GridSpec, column widths 4:2):
        top-left      BCG signal (``draw_ax_amp``, AML lines 200/500/1000)
        top-right     BCG amplitude-duration matrix (``draw_ax_confusion_matrix``)
        bottom-left   respiration signal (AML lines 100/300/500), x axis shared with BCG
        bottom-right  respiration amplitude-duration matrix

    Parameters
    ----------
    bcg_origin_signal, resp_origin_signal : sequence
        Raw signals; each one's time axis is built from its own sampling rate.
    bcg_no_movement_signal, resp_no_movement_signal : object
        Forwarded to ``draw_ax_amp`` (which does not use them).
    bcg_sampling_rate, resp_sampling_rate : number
        Samples per second of the corresponding raw signal.
    *_movement_position_list, *_low_amp_position_list, *_mav_values
        Forwarded to ``draw_ax_amp``.
    bcg_statistic_info, resp_statistic_info : sequence
        Unpacked positionally into ``draw_ax_confusion_matrix`` as
        (confusion_matrix, segment_count_matrix, confusion_matrix_percent,
        valid_signal_length, total_duration, time_labels, amp_labels);
        any further items are ignored.
    signal_info : object
        Formatted into the figure title ``"<signal_info> Signal Metrics"``.
    show : bool
        Call ``plt.show()`` after drawing.
    save_path : str or path-like, optional
        If given, the figure is saved there at 300 dpi.

    The figure is always closed before returning; nothing is returned.
    """
    fig = plt.figure(figsize=(18, 10))
    gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

    # Whole-second span of the BCG recording; used as the extent of the AML
    # reference lines on both (x-shared) signal plots.
    signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
    # Exact sample times in seconds, each derived from its own sampling rate
    # (not from the BCG's floor-truncated duration), so the respiration trace
    # lines up correctly when its rate or length differs from the BCG's.
    bcg_origin_times = np.arange(len(bcg_origin_signal)) / bcg_sampling_rate
    resp_origin_times = np.arange(len(resp_origin_signal)) / resp_sampling_rate

    stat_param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                        'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']

    # Row 1: BCG
    ax1 = fig.add_subplot(gs[0])
    draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[200, 500, 1000])
    ax2 = fig.add_subplot(gs[1])
    draw_ax_confusion_matrix(ax=ax2, **dict(zip(stat_param_names, bcg_statistic_info)))

    # Row 2: respiration
    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    draw_ax_amp(ax=ax3, signal_name='RESP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[100, 300, 500])
    ax4 = fig.add_subplot(gs[3])
    draw_ax_confusion_matrix(ax=ax4, **dict(zip(stat_param_names, resp_statistic_info)))

    fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

    if save_path is not None:
        fig.savefig(save_path, dpi=300)
    if show:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)