DataPrepare/draw_tools/draw_statics.py

300 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Draw an amplitude-duration statistics heatmap with row/column totals on ``ax``.

    Each grid cell is annotated as ``[segment count]duration`` over ``(percent%)`` and
    coloured by its percentage using ``PowerNorm(gamma=0.5)`` on a fixed 0-100 range.
    Row totals are written to the right of the grid, column totals below it, and the
    grand total (segments, valid length, valid percentage) in the bottom-right corner.

    Args:
        ax: Target matplotlib Axes.
        confusion_matrix: DataFrame of durations; ``.iloc[i, j]`` supplies the grid
            cells and its ``'总计'`` column supplies the row totals.
        segment_count_matrix: 2-D array indexed ``[i, j]`` with the segment count per cell.
        confusion_matrix_percent: DataFrame of percentages laid out like
            ``confusion_matrix`` (including a ``'总计'`` column).
        valid_signal_length: Valid duration shown in the grand-total cell.
        total_duration: Denominator for the column and grand-total percentages.
            When it is 0 those percentages are shown as 0 instead of failing.
        time_labels: Column (duration interval) labels.
        amp_labels: Row (amplitude interval) labels.
        signal_type: Not used by this function; kept for call compatibility.
    """
    n_amp = len(amp_labels)
    n_time = len(time_labels)

    # Per-cell annotation text and the numeric matrix that drives the colours.
    text_matrix = np.empty((n_amp, n_time), dtype=object)
    percent_matrix = np.zeros((n_amp, n_time))
    for i in range(n_amp):
        for j in range(n_time):
            val = confusion_matrix.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent

    # Square-root colour scale keeps small percentages distinguishable; the colour bar
    # is pushed right (pad=0.15) to leave room for the row totals.
    sns_heatmap = sns.heatmap(
        percent_matrix, annot=text_matrix, fmt='',
        xticklabels=time_labels, yticklabels=amp_labels,
        cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
        norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
        cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15},
    )

    # Row totals (right of the grid), read from the '总计' column.
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(n_time + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(n_time + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
                ha='center', va='center')

    # Column totals (below the grid). A zero total_duration would otherwise raise
    # ZeroDivisionError for the grand total and print inf% for the columns.
    col_sums = segment_count_matrix.sum(axis=0)
    col_durations = confusion_matrix.sum(axis=0)
    if total_duration:
        col_percents = col_durations / total_duration * 100
        total_percent = valid_signal_length / total_duration * 100
    else:
        col_percents = col_durations * 0.0
        total_percent = 0.0
    ax.text(-1, n_amp + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    # zip stops at the shorter sequence, so any extra trailing column in
    # col_durations (e.g. '总计') is not drawn.
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, n_amp + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
                ha='center', va='center', rotation=0)

    # Put the duration axis (x) at the top of the heatmap.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')
    # Keep tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    # Colour-bar label on the right-hand side of the bar.
    cbar = sns_heatmap.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Legend box explaining the in-cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Grand total: total segment count, valid duration, and its share of total duration.
    # valid_signal_length is int-formatted like every other count in the figure.
    total_segments = segment_count_matrix.sum()
    ax.text(n_time + 0.5, n_amp + 0.5,
            f"[{int(total_segments)}]{int(valid_signal_length)}\n({total_percent:.2f}%)",
            ha='center', va='center')
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    """Plot a signal with shaded movement/low-amplitude spans, AML levels and its MAV curve.

    Args:
        ax: Target matplotlib Axes.
        signal_name: Used in the title as ``'<signal_name> Signal'``.
        original_times: x coordinates for ``origin_signal``.
        origin_signal: Signal samples, drawn as a thin black line.
        no_movement_signal: Not used by this function; kept for call compatibility.
        mav_values: MAV series, drawn against ``np.linspace(0, len, len)``.
        movement_position_list: Iterable of ``(start, end)`` pairs shaded red.
        low_amp_position_list: Iterable of ``(start, end)`` pairs shaded blue.
        signal_second_length: Right end of the horizontal AML reference lines.
        aml_list: Optional amplitude levels drawn as dashed horizontal lines
            (red, orange, then green for the third and any later level).
    """
    # Span transparencies, shared with the legend patches so the swatches match the plot.
    movement_alpha = 0.3
    low_amp_alpha = 0.2

    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # NOTE(review): span bounds are checked against the number of time samples but are
    # drawn in the axis' x units (callers pass times in seconds) — confirm the units of
    # the position lists.
    n_times = len(original_times)
    for start, end in movement_position_list:
        if start < n_times and end < n_times:
            ax.axvspan(start, end, color='red', alpha=movement_alpha)
    for start, end in low_amp_position_list:
        if start < n_times and end < n_times:
            ax.axvspan(start, end, color='blue', alpha=low_amp_alpha)

    # Optional amplitude reference levels.
    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            color = color_map[min(i, len(color_map) - 1)]
            ax.hlines(aml, 0, signal_second_length, color=color, linestyle='dashed',
                      linewidth=2, alpha=0.5, label=f'{aml} aml')

    # NOTE(review): assumes one MAV value per x unit; linspace(0, n, n) spaces points by
    # n/(n-1) rather than exactly 1 — confirm intended alignment.
    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values,
            color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    ax.set_ylim((-2000, 2000))

    # Legend: proxy patches for the shaded spans, followed by the labelled artists above.
    red_patch = mpatches.Patch(color='red', alpha=movement_alpha, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=low_amp_alpha, label='Low Amplitude Area')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right',
              bbox_to_anchor=(1, 1.4),
              framealpha=0.5)

    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')
    ax.grid(True, linestyle='--', alpha=0.7)
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):
    """Build a 2x2 metrics figure for the BCG and respiration signals.

    Layout: top row is the BCG amplitude plot (left) and its statistics matrix (right).
    The bottom row is the respiration plot, which shares x with the BCG plot, and its
    matrix. The figure is optionally saved (dpi=300) and/or shown, and is always closed
    before returning, even if drawing or saving raises.

    Args:
        *_statistic_info: Positional sequence ordered as (confusion_matrix,
            segment_count_matrix, confusion_matrix_percent, valid_signal_length,
            total_duration, time_labels, amp_labels), forwarded to
            ``draw_ax_confusion_matrix``.
        signal_info: Prefix of the figure title.
        show: Call ``plt.show()`` when True.
        save_path: If not None, path the figure is saved to.
    """
    param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                   'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']

    def _draw_statistics(ax, statistic_info):
        # Map the positional statistic tuple onto draw_ax_confusion_matrix's keywords.
        params = dict(zip(param_names, statistic_info))
        params['ax'] = ax
        draw_ax_confusion_matrix(**params)

    fig = plt.figure(figsize=(18, 10))
    try:
        gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)
        # Both time axes span the BCG duration in whole seconds (floor division).
        signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
        bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
        resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))

        ax1 = fig.add_subplot(gs[0])
        draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                    no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                    movement_position_list=bcg_movement_position_list,
                    low_amp_position_list=bcg_low_amp_position_list,
                    signal_second_length=signal_second_length, aml_list=[200, 500, 1000])

        ax2 = fig.add_subplot(gs[1])
        _draw_statistics(ax2, bcg_statistic_info)

        ax3 = fig.add_subplot(gs[2], sharex=ax1)
        draw_ax_amp(ax=ax3, signal_name='RSEP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                    no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                    movement_position_list=resp_movement_position_list,
                    low_amp_position_list=resp_low_amp_position_list,
                    signal_second_length=signal_second_length, aml_list=[100, 300, 500])

        ax4 = fig.add_subplot(gs[3])
        _draw_statistics(ax4, resp_statistic_info)

        fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

        if save_path is not None:
            # Save this figure explicitly rather than whatever pyplot considers "current".
            fig.savefig(save_path, dpi=300)
        if show:
            plt.show()
    finally:
        # Release the figure even on error so repeated calls don't accumulate open figures.
        plt.close(fig)
def _none_to_nan_mask(mask, ref):
    """Return a float mask whose 0 entries are NaN, so they are left blank when plotted.

    Args:
        mask: Array-like mask, or None.
        ref: Array whose shape is used when ``mask`` is None.

    Returns:
        ``mask`` as an ndarray with zeros replaced by NaN, or an all-NaN float array
        shaped like ``ref`` when ``mask`` is None.
    """
    if mask is None:
        # dtype=float is required: full_like keeps ref's dtype by default, and an integer
        # ref cannot hold NaN (it would be filled with a huge integer sentinel instead).
        return np.full_like(ref, np.nan, dtype=float)
    # asarray so a list mask is compared element-wise instead of as a whole object.
    mask = np.asarray(mask)
    return np.where(mask == 0, np.nan, mask)


def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs,
                          signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask,
                          resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask
                          ):
    """Show a three-row figure of one sample's signals with their masks on twin y-axes.

    Row 1: raw signal (left axis) and the disable mask (right axis).
    Row 2: respiration component with low-amplitude, movement, amplitude-change and SA masks.
    Row 3: BCG component with low-amplitude, movement and amplitude-change masks.

    Signals are drawn against ``np.linspace(0, len(x) // fs, len(x))``. Masks are drawn
    against ``np.linspace(0, len(mask), len(mask))`` — presumably one mask element per
    second; TODO confirm. Mask zeros are hidden and a None mask draws nothing. Masks
    sharing a twin axis are scaled by -1/-2/-3 so their lines do not overlap. Each twin
    axis' y-limits are re-imposed whenever they change (e.g. on zoom/pan). The figure is
    shown with ``plt.show()`` and is not closed here.
    """
    def _seconds_axis(data, fs):
        # Signal x coordinates in (whole) seconds.
        return np.linspace(0, len(data) // fs, len(data))

    def _index_axis(mask):
        # Mask x coordinates: one point per mask element.
        return np.linspace(0, len(mask), len(mask))

    signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
    resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data)
    resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data)
    resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data)
    resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data)
    bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data)
    bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data)
    bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data)

    fig = plt.figure(figsize=(18, 10))

    # Row 1: raw signal with the disable mask.
    ax0 = fig.add_subplot(3, 1, 1)
    ax0.plot(_seconds_axis(signal_data, signal_fs), signal_data, color='blue')
    ax0.set_title(f'Sample {samp_id} - Raw Signal Data')
    ax0.set_ylabel('Amplitude')
    ax0_twin = ax0.twinx()
    ax0_twin.plot(_index_axis(signal_disable_mask), signal_disable_mask,
                  color='red', alpha=0.5)
    ax0_twin.autoscale(enable=False, axis='y', tight=True)
    ax0_twin.set_ylim((-2, 2))
    ax0_twin.set_ylabel('Disable Mask')
    ax0_twin.set_yticks([0, 1])
    ax0_twin.set_yticklabels(['Enabled', 'Disabled'])
    ax0_twin.grid(False)
    ax0_twin.legend(['Disable Mask'], loc='upper right')

    # Row 2: respiration component with its masks (offset by -1/-2/-3; SA mask unscaled).
    ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
    ax1.plot(_seconds_axis(resp_data, resp_fs), resp_data, color='orange')
    ax1.set_ylabel('Amplitude')
    ax1_twin = ax1.twinx()
    ax1_twin.plot(_index_axis(resp_low_amp_mask), resp_low_amp_mask * -1,
                  color='blue', alpha=0.5, label='Low Amplitude Mask')
    ax1_twin.plot(_index_axis(resp_movement_mask), resp_movement_mask * -2,
                  color='red', alpha=0.5, label='Movement Mask')
    ax1_twin.plot(_index_axis(resp_change_mask), resp_change_mask * -3,
                  color='green', alpha=0.5, label='Amplitude Change Mask')
    ax1_twin.plot(_index_axis(resp_sa_mask), resp_sa_mask,
                  color='purple', alpha=0.5, label='SA Mask')
    ax1_twin.autoscale(enable=False, axis='y', tight=True)
    ax1_twin.set_ylim((-4, 5))
    ax1_twin.grid(False)
    ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
    ax1.set_title(f'Sample {samp_id} - Respiration Component')

    # Row 3: BCG component with its masks (offset by -1/-2/-3).
    ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
    ax2.plot(_seconds_axis(bcg_data, bcg_fs), bcg_data, color='green')
    ax2.set_ylabel('Amplitude')
    ax2.set_xlabel('Time (s)')
    ax2_twin = ax2.twinx()
    ax2_twin.plot(_index_axis(bcg_low_amp_mask), bcg_low_amp_mask * -1,
                  color='blue', alpha=0.5, label='Low Amplitude Mask')
    ax2_twin.plot(_index_axis(bcg_movement_mask), bcg_movement_mask * -2,
                  color='red', alpha=0.5, label='Movement Mask')
    ax2_twin.plot(_index_axis(bcg_change_mask), bcg_change_mask * -3,
                  color='green', alpha=0.5, label='Amplitude Change Mask')
    ax2_twin.autoscale(enable=False, axis='y', tight=True)
    ax2_twin.set_ylim((-4, 2))
    ax2_twin.set_ylabel('BCG Masks')
    ax2_twin.grid(False)
    ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right')

    # Fixed y-limits for each mask axis, re-imposed whenever they change.
    locked_ylims = {ax0_twin: (-2, 2), ax1_twin: (-4, 5), ax2_twin: (-4, 2)}
    for twin in locked_ylims:
        twin._lim_lock = False

    def on_lims_change(event_ax):
        # _lim_lock guards against recursion: set_ylim below fires ylim_changed again.
        if getattr(event_ax, '_lim_lock', False):
            return
        limits = locked_ylims.get(event_ax)
        if limits is None:
            return
        try:
            event_ax._lim_lock = True
            event_ax.set_ylim(*limits)
        finally:
            event_ax._lim_lock = False

    for twin in locked_ylims:
        twin.callbacks.connect('ylim_changed', on_lims_change)

    plt.tight_layout()
    plt.show()