300 lines
13 KiB
Python
300 lines
13 KiB
Python
from matplotlib.axes import Axes
|
||
from matplotlib.gridspec import GridSpec
|
||
from matplotlib.colors import PowerNorm
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.patches as mpatches
|
||
import seaborn as sns
|
||
import numpy as np
|
||
|
||
|
||
def draw_ax_confusion_matrix(ax: Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Render an amplitude-by-duration statistics matrix as an annotated heatmap.

    Cell (row, col) is coloured by ``confusion_matrix_percent.iloc[row, col]``
    on a 0-100 ``PowerNorm(gamma=0.5)`` scale. It is labelled with the segment
    count in brackets, the duration, and the percentage on a second line.
    Row totals are written to the right of the matrix, column totals below
    it, and the grand total in the bottom-right corner.

    Args:
        ax: Axes to draw on.
        confusion_matrix: Table read with ``.iloc[row, col]`` for cell
            durations. Its ``'总计'`` (total) column supplies the row totals.
            Its column sums divided by ``total_duration`` give the column
            percentages.
        segment_count_matrix: Array of per-cell segment counts indexed
            ``[row, col]``. It is summed per column and overall for the totals.
        confusion_matrix_percent: Table of per-cell percentages with the same
            layout, including a ``'总计'`` column for the row-total percentages.
        valid_signal_length: Printed as the grand-total duration and divided
            by ``total_duration`` for the grand-total percentage.
        total_duration: Denominator for the column and grand-total percentages.
        time_labels: Duration-bin labels (x axis, shown at the top).
        amp_labels: Amplitude-bin labels (y axis).
        signal_type: Accepted but not used.
    """
    n_amp = len(amp_labels)
    n_time = len(time_labels)
    # The summary column/row sit one cell past the matrix; text is centred in that cell.
    total_col_x = n_time + 0.5
    total_row_y = n_amp + 0.5

    # Per-cell annotation strings and colour values, filled in row-major order.
    cell_text = np.empty((n_amp, n_time), dtype=object)
    cell_percent = np.zeros((n_amp, n_time))
    for row, col in np.ndindex(n_amp, n_time):
        duration = confusion_matrix.iloc[row, col]
        n_segments = segment_count_matrix[row, col]
        pct = confusion_matrix_percent.iloc[row, col]
        cell_text[row, col] = f"[{int(n_segments)}]{int(duration)}\n({pct:.2f}%)"
        cell_percent[row, col] = pct

    # Fixed 0-100 colour scale; gamma=0.5 spreads out the low-percentage end.
    heatmap_ax = sns.heatmap(
        cell_percent,
        annot=cell_text,
        fmt='',
        xticklabels=time_labels,
        yticklabels=amp_labels,
        cmap='YlGnBu',
        ax=ax,
        vmin=0,
        vmax=100,
        norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
        cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15},
    )

    # Right-hand summary column: per-amplitude totals from the '总计' column.
    row_durations = confusion_matrix['总计']
    row_pcts = confusion_matrix_percent['总计']
    ax.text(n_time + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for row, (duration, pct) in enumerate(zip(row_durations, row_pcts)):
        ax.text(total_col_x, row + 0.5, f"{int(duration)}\n({pct:.2f}%)",
                ha='center', va='center')

    # Bottom summary row: per-duration segment counts and share of total_duration.
    col_counts = segment_count_matrix.sum(axis=0)
    col_pcts = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, total_row_y, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for col, (count, pct) in enumerate(zip(col_counts, col_pcts)):
        ax.text(col + 0.5, total_row_y, f"[{int(count)}]\n({pct:.2f}%)",
                ha='center', va='center', rotation=0)

    # Duration axis (ticks and label) goes on top; the column totals occupy the bottom.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep both sets of tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    heatmap_ax.collections[0].colorbar.ax.yaxis.set_label_position('right')

    # Boxed note explaining the in-cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Bottom-right corner: total segment count, valid duration and its share of total_duration.
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length / total_duration * 100
    ax.text(total_col_x, total_row_y,
            f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)",
            ha='center', va='center')
|
||
|
||
|
||
|
||
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None,
                ylim=(-2000, 2000)):
    """Plot one signal with shaded movement/low-amplitude spans, AML levels and its MAV curve.

    Args:
        ax: Axes to draw on.
        signal_name: Used in the title as ``'<signal_name> Signal'``.
        original_times: x coordinates (seconds) for ``origin_signal``.
        origin_signal: Signal samples, drawn as a thin black line.
        no_movement_signal: Not used; kept for interface compatibility.
        mav_values: MAV curve, drawn in blue on ``np.linspace(0, len, len)``.
        movement_position_list: Iterable of ``(start, end)`` pairs in x-axis
            units (seconds), shaded red.
        low_amp_position_list: Same format, shaded blue.
        signal_second_length: Plotted signal length in seconds. Spans starting
            at or after it are skipped, and span ends are clipped to it. AML
            lines run from 0 to it.
        aml_list: Optional amplitude levels drawn as dashed horizontal lines.
            The 1st, 2nd and 3rd+ levels are red, orange and green.
        ylim: y-axis limits. The default is the previously hard-coded range.
    """
    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # Shaded spans. Positions are in the same unit as the x axis (seconds), so the
    # bound is the signal length in seconds, not the number of samples.
    span_styles = (
        (movement_position_list, 'red', 0.3),
        (low_amp_position_list, 'blue', 0.2),
    )
    for positions, color, alpha in span_styles:
        for start, end in positions:
            if start >= signal_second_length:
                continue
            ax.axvspan(start, min(end, signal_second_length), color=color, alpha=alpha)

    if aml_list is not None:
        aml_colors = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            # Levels beyond the third reuse the last colour.
            ax.hlines(aml, 0, signal_second_length, color=aml_colors[min(i, len(aml_colors) - 1)],
                      linestyle='dashed', linewidth=2, alpha=0.5, label=f'{aml} aml')

    # NOTE(review): linspace(0, N, N) spaces points N/(N-1) apart, not 1. If MAV holds
    # one value per second, np.arange(N) may be what was intended — confirm.
    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values,
            color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    ax.set_ylim(ylim)

    # Legend proxies for the shaded spans; alphas match the axvspan calls above.
    red_patch = mpatches.Patch(color='red', alpha=0.3, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area')

    # Prepend the span proxies to the labelled artists already on the axes.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right',
              bbox_to_anchor=(1, 1.4),
              framealpha=0.5)

    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')
    ax.grid(True, linestyle='--', alpha=0.7)
|
||
|
||
|
||
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):
    """Build a 2x2 summary figure: BCG and RESP traces (left) with their statistics matrices (right).

    Args:
        bcg_origin_signal, resp_origin_signal: Signals to plot.
        bcg_no_movement_signal, resp_no_movement_signal: Forwarded to ``draw_ax_amp``.
        bcg_sampling_rate, resp_sampling_rate: Sampling rates used to turn each
            signal's length into its duration in seconds.
        *_movement_position_list, *_low_amp_position_list: ``(start, end)``
            spans forwarded to ``draw_ax_amp``.
        bcg_mav_values, resp_mav_values: MAV curves forwarded to ``draw_ax_amp``.
        bcg_statistic_info, resp_statistic_info: Sequences whose first seven
            items map positionally onto ``draw_ax_confusion_matrix``'s
            ``confusion_matrix`` ... ``amp_labels`` parameters. Extra trailing
            items are dropped by ``zip``.
        signal_info: Prefix of the figure title.
        show: If true, call ``plt.show()`` before closing.
        save_path: If given, save the figure there at 300 dpi.

    The figure is always closed before returning, including when drawing fails.
    """
    # Positional layout of each *_statistic_info sequence.
    stat_param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                        'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']

    fig = plt.figure(figsize=(18, 10))
    try:
        # Left column: signal traces; right column: statistics matrices.
        gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

        # Each time axis comes from that signal's own length and sampling rate.
        bcg_second_length = len(bcg_origin_signal) // bcg_sampling_rate
        resp_second_length = len(resp_origin_signal) // resp_sampling_rate
        bcg_origin_times = np.linspace(0, bcg_second_length, len(bcg_origin_signal))
        resp_origin_times = np.linspace(0, resp_second_length, len(resp_origin_signal))

        # Top row: BCG.
        ax1 = fig.add_subplot(gs[0])
        draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                    no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                    movement_position_list=bcg_movement_position_list,
                    low_amp_position_list=bcg_low_amp_position_list,
                    signal_second_length=bcg_second_length, aml_list=[200, 500, 1000])

        ax2 = fig.add_subplot(gs[1])
        draw_ax_confusion_matrix(ax=ax2, **dict(zip(stat_param_names, bcg_statistic_info)))

        # Bottom row: respiration; its time axis is shared with the BCG trace.
        ax3 = fig.add_subplot(gs[2], sharex=ax1)
        draw_ax_amp(ax=ax3, signal_name='RESP', original_times=resp_origin_times,
                    origin_signal=resp_origin_signal,
                    no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                    movement_position_list=resp_movement_position_list,
                    low_amp_position_list=resp_low_amp_position_list,
                    signal_second_length=resp_second_length, aml_list=[100, 300, 500])

        ax4 = fig.add_subplot(gs[3])
        draw_ax_confusion_matrix(ax=ax4, **dict(zip(stat_param_names, resp_statistic_info)))

        fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

        if save_path is not None:
            # Save this figure explicitly rather than pyplot's "current" figure.
            fig.savefig(save_path, dpi=300)

        if show:
            plt.show()
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
|
||
|
||
|
||
def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs,
                          signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask,
                          resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask
                          ):
    """Show the raw, respiration and BCG signals stacked, each with its masks overlaid.

    The three rows share the x axis:
        1. Raw signal (the original notes describe it as power-line-noise
           removed). The right axis shows the disable mask.
        2. Respiration component. The right axis shows the low-amplitude,
           movement, amplitude-change and SA masks.
        3. BCG component. The right axis shows the low-amplitude, movement
           and amplitude-change masks.

    Each signal is plotted against ``np.linspace(0, len(x) // fs, len(x))``
    (seconds). Each mask is plotted against ``np.linspace(0, len(mask), len(mask))``,
    i.e. treated as roughly one value per second. Presumably the masks are
    per-second; confirm with their producer.

    Zero entries of a mask are hidden (set to NaN), and a ``None`` mask draws
    nothing. Masks on one row are multiplied by distinct factors (1, -1, -2,
    -3) so they appear as separate levels. Each twin y axis is pinned to a
    fixed range, which is restored whenever its limits change (e.g. on
    interactive zoom/pan). Calls ``plt.show()``; returns None.
    """

    def _none_to_nan_mask(mask, ref):
        """Return a float copy of ``mask`` with zeros turned into NaN so they are not drawn.

        ``None`` becomes an all-NaN float array shaped like ``ref``. The result is
        always float: ``np.full_like(ref, np.nan)`` on an integer ``ref`` cannot hold
        NaN and yields a huge negative sentinel instead. ``asarray`` also makes
        the zero test element-wise for list inputs.
        """
        if mask is None:
            return np.full(np.shape(ref), np.nan)
        mask = np.asarray(mask, dtype=float)
        return np.where(mask == 0, np.nan, mask)

    def _draw_mask_axis(host_ax, specs, ylim):
        """Overlay ``(mask, scale, color, label)`` specs on a twin y axis of ``host_ax`` pinned to ``ylim``."""
        twin_ax = host_ax.twinx()
        for mask, scale, color, label in specs:
            twin_ax.plot(np.linspace(0, len(mask), len(mask)), mask * scale,
                         color=color, alpha=0.5, label=label)
        twin_ax.autoscale(enable=False, axis='y', tight=True)
        twin_ax.set_ylim(ylim)
        twin_ax.grid(False)
        twin_ax.legend(loc='upper right')
        # The same ylim is used for the initial limits and for the restore, so they
        # cannot drift apart. emit=False keeps the restoring set_ylim from re-firing
        # 'ylim_changed', so no re-entrancy guard is needed.
        twin_ax.callbacks.connect(
            'ylim_changed', lambda changed_ax, lim=ylim: changed_ax.set_ylim(lim, emit=False))
        return twin_ax

    signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
    resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data)
    resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data)
    resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data)
    resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data)
    bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data)
    bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data)
    bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data)

    fig = plt.figure(figsize=(18, 10))

    # Row 1: raw signal + disable mask.
    ax0 = fig.add_subplot(3, 1, 1)
    ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
    ax0.set_title(f'Sample {samp_id} - Raw Signal Data')
    ax0.set_ylabel('Amplitude')
    ax0_twin = _draw_mask_axis(ax0, [(signal_disable_mask, 1, 'red', 'Disable Mask')], (-2, 2))
    ax0_twin.set_ylabel('Disable Mask')
    ax0_twin.set_yticks([0, 1])
    ax0_twin.set_yticklabels(['Enabled', 'Disabled'])

    # Row 2: respiration component + its masks.
    ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
    ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange')
    ax1.set_ylabel('Amplitude')
    ax1.set_title(f'Sample {samp_id} - Respiration Component')
    _draw_mask_axis(ax1, [
        (resp_low_amp_mask, -1, 'blue', 'Low Amplitude Mask'),
        (resp_movement_mask, -2, 'red', 'Movement Mask'),
        (resp_change_mask, -3, 'green', 'Amplitude Change Mask'),
        (resp_sa_mask, 1, 'purple', 'SA Mask'),
    ], (-4, 5))

    # Row 3: BCG component + its masks.
    ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
    ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
    ax2.set_ylabel('Amplitude')
    ax2.set_xlabel('Time (s)')
    ax2_twin = _draw_mask_axis(ax2, [
        (bcg_low_amp_mask, -1, 'blue', 'Low Amplitude Mask'),
        (bcg_movement_mask, -2, 'red', 'Movement Mask'),
        (bcg_change_mask, -3, 'green', 'Amplitude Change Mask'),
    ], (-4, 2))
    ax2_twin.set_ylabel('BCG Masks')

    plt.tight_layout()
    plt.show()
|
||
|
||
|
||
|
||
|