from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np


def draw_ax_confusion_matrix(ax: Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Draw an amplitude-by-duration statistics heatmap, with row/column totals, on ``ax``.

    Each cell is coloured by its percentage (PowerNorm, gamma=0.5, over 0-100) and annotated
    with ``[segment count]duration`` above ``(percent%)``. Row totals are written to the right
    of the heatmap, column totals below it, and the grand total in the bottom-right corner.
    The x-axis is moved to the top of the heatmap.

    Args:
        ax: Axes to draw on.
        confusion_matrix: Table read with ``.iloc[i, j]`` for cell durations and with
            ``['总计']`` for the per-row totals (presumably a pandas DataFrame — confirm).
        segment_count_matrix: Array indexed ``[i, j]`` with the number of segments per cell.
        confusion_matrix_percent: Same layout as ``confusion_matrix``, holding percentages.
        valid_signal_length: Total valid duration, shown in the grand-total cell.
        total_duration: Denominator used for the column and grand-total percentages.
        time_labels: Duration-bin labels (heatmap columns).
        amp_labels: Amplitude-bin labels (heatmap rows).
        signal_type: Not used by this function.
    """
    n_amp = len(amp_labels)
    n_time = len(time_labels)

    # Per-cell annotation text and the percentage matrix that drives the colour map.
    text_matrix = np.empty((n_amp, n_time), dtype=object)
    percent_matrix = np.zeros((n_amp, n_time))
    for i in range(n_amp):
        for j in range(n_time):
            val = confusion_matrix.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent

    # gamma=0.5 maps values through a square root, stretching the low end of the colour scale.
    sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='',
                              xticklabels=time_labels, yticklabels=amp_labels,
                              cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
                              norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
                              cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15})

    # Row totals (right of the heatmap), taken from the '总计' (total) column.
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(n_time + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(n_time + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)", ha='center', va='center')

    # Column totals (below the heatmap): segment counts and duration share of total_duration.
    col_sums = segment_count_matrix.sum(axis=0)
    col_percents = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, n_amp + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, n_amp + 0.5, f"[{int(val)}]\n({perc:.2f}%)", ha='center', va='center', rotation=0)

    # Put the x-axis ticks and label on top.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep tick labels horizontal on both axes.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    # Colour-bar label on the right-hand side of the colour bar.
    cbar = sns_heatmap.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Key explaining the cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)", ha='left', va='top',
            bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Grand total: total segment count, valid duration and its share of the total duration.
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length / total_duration * 100
    # int() matches how every other duration value on this chart is printed.
    ax.text(n_time + 0.5, n_amp + 0.5,
            f"[{int(total_segments)}]{int(valid_signal_length)}\n({total_percent:.2f}%)",
            ha='center', va='center')


def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    """Plot a signal with shaded movement / low-amplitude regions, optional AML lines and a MAV curve.

    Args:
        ax: Axes to draw on.
        signal_name: Used in the title as ``'<signal_name> Signal'``.
        original_times: x coordinates for ``origin_signal``.
        origin_signal: Signal samples, drawn as a thin black line.
        no_movement_signal: Not used by this function.
        mav_values: Values drawn as the blue ``'2sMAV'`` curve.
        movement_position_list: ``(start, end)`` pairs shaded red.
        low_amp_position_list: ``(start, end)`` pairs shaded blue.
        signal_second_length: Right end of the dashed AML lines.
        aml_list: Optional amplitude levels drawn as dashed horizontal lines; the first three use
            red, orange and green, and any further levels reuse green.
    """
    # Shared between the shaded spans and their legend swatches so the two always match.
    movement_alpha = 0.3
    low_amp_alpha = 0.2

    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # NOTE(review): spans are drawn in data x-units but bounds-checked against the number of
    # samples in original_times; confirm which unit the position lists are in.
    n_samples = len(original_times)
    for start, end in movement_position_list:
        if start < n_samples and end < n_samples:
            ax.axvspan(start, end, color='red', alpha=movement_alpha)
    for start, end in low_amp_position_list:
        if start < n_samples and end < n_samples:
            ax.axvspan(start, end, color='blue', alpha=low_amp_alpha)

    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed',
                      linewidth=2, alpha=0.5, label=f'{aml} aml')

    # NOTE(review): linspace(0, N, N) spaces points N/(N-1) apart rather than 1; if mav_values
    # holds one value per second, np.arange(N) may be what was intended — confirm.
    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values,
            color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    ax.set_ylim((-2000, 2000))

    # Proxy patches for the shaded regions, prepended to the existing line legend entries.
    red_patch = mpatches.Patch(color='red', alpha=movement_alpha, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=low_amp_alpha, label='Low Amplitude Area')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right', bbox_to_anchor=(1, 1.4), framealpha=0.5)

    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')
    ax.grid(True, linestyle='--', alpha=0.7)


def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list,
                        bcg_low_amp_position_list, resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info, signal_info,
                        show=False, save_path=None):
    """Draw a 2x2 figure: BCG and RESP signal plots (left) and their statistics heatmaps (right).

    ``bcg_statistic_info`` / ``resp_statistic_info`` are sequences unpacked positionally into
    ``draw_ax_confusion_matrix``'s parameters (confusion_matrix, segment_count_matrix,
    confusion_matrix_percent, valid_signal_length, total_duration, time_labels, amp_labels);
    extra trailing items are ignored.

    The figure is saved to ``save_path`` (dpi=300) when given, shown when ``show`` is true,
    and always closed before returning. ``resp_sampling_rate`` is not used: both time axes
    span the BCG duration in whole seconds.
    """
    fig = plt.figure(figsize=(18, 10))
    gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

    # Integer number of seconds, derived from the BCG signal and reused for both signals.
    signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
    bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
    resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))

    param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                   'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']

    # Top row: BCG signal and its statistics.
    ax1 = fig.add_subplot(gs[0])
    draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                movement_position_list=bcg_movement_position_list,
                low_amp_position_list=bcg_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[200, 500, 1000])
    ax2 = fig.add_subplot(gs[1])
    draw_ax_confusion_matrix(ax=ax2, **dict(zip(param_names, bcg_statistic_info)))

    # Bottom row: RESP signal (x-axis shared with BCG) and its statistics.
    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    draw_ax_amp(ax=ax3, signal_name='RSEP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                movement_position_list=resp_movement_position_list,
                low_amp_position_list=resp_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[100, 300, 500])
    ax4 = fig.add_subplot(gs[3])
    draw_ax_confusion_matrix(ax=ax4, **dict(zip(param_names, resp_statistic_info)))

    fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

    # Operate on this figure explicitly rather than on pyplot's "current" figure.
    if save_path is not None:
        fig.savefig(save_path, dpi=300)
    if show:
        plt.show()
    plt.close(fig)


def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs,
                          signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask,
                          resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask):
    """Show the raw signal, respiration and BCG components with their masks in three stacked rows.

    Row 1: raw signal (left axis) with the disable mask (right axis).
    Row 2: respiration component with low-amplitude, movement, amplitude-change and SA masks.
    Row 3: BCG component with low-amplitude, movement and amplitude-change masks.

    Signals are plotted against ``np.linspace(0, len(data) // fs, len(data))`` (seconds), while
    masks are plotted against ``np.linspace(0, len(mask), len(mask))``, one x unit per element.
    NOTE(review): the two only line up if masks hold one value per second — confirm.

    Zeros in a mask become NaN so only the "on" regions are drawn; a ``None`` mask draws nothing.
    Row 2/3 masks are drawn at distinct levels (multiplied by -1, -2, -3, and +1 for SA) on a
    right-hand axis whose y-limits are pinned and re-applied whenever they change (e.g. via the
    interactive toolbar). All rows share the x-axis. The figure is shown with ``plt.show()`` and
    is not closed here.
    """
    # Pinned y-ranges of the right-hand mask axes.
    signal_ylim = (-2, 2)
    resp_ylim = (-4, 5)
    bcg_ylim = (-4, 2)

    def _none_to_nan_mask(mask, ref):
        # Float dtype is required: NaN cannot be stored in an integer array, and
        # np.full_like(int_ref, np.nan) would silently fill it with garbage integers.
        if mask is None:
            return np.full_like(ref, np.nan, dtype=float)
        # asarray so list inputs are compared element-wise (a plain `list == 0` is just False).
        mask = np.asarray(mask, dtype=float)
        return np.where(mask == 0, np.nan, mask)

    def _mask_x(mask):
        # One x unit per mask element.
        return np.linspace(0, len(mask), len(mask))

    def _plot_masks(twin_ax, mask_specs):
        # mask_specs: (mask, level, color, label) tuples; each mask is multiplied by its level.
        for mask, level, color, label in mask_specs:
            twin_ax.plot(_mask_x(mask), mask * level, color=color, alpha=0.5, label=label)

    signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
    resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data)
    resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data)
    resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data)
    resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data)
    bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data)
    bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data)
    bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data)

    fig = plt.figure(figsize=(18, 10))

    # Row 1: raw signal + disable mask.
    ax0 = fig.add_subplot(3, 1, 1)
    ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
    ax0.set_title(f'Sample {samp_id} - Raw Signal Data')
    ax0.set_ylabel('Amplitude')
    ax0_twin = ax0.twinx()
    ax0_twin.plot(_mask_x(signal_disable_mask), signal_disable_mask, color='red', alpha=0.5)
    ax0_twin.autoscale(enable=False, axis='y', tight=True)
    ax0_twin.set_ylim(signal_ylim)
    ax0_twin.set_ylabel('Disable Mask')
    ax0_twin.set_yticks([0, 1])
    ax0_twin.set_yticklabels(['Enabled', 'Disabled'])
    ax0_twin.grid(False)
    ax0_twin.legend(['Disable Mask'], loc='upper right')

    # Row 2: respiration component + its masks.
    ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
    ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange')
    ax1.set_ylabel('Amplitude')
    ax1_twin = ax1.twinx()
    _plot_masks(ax1_twin, [
        (resp_low_amp_mask, -1, 'blue', 'Low Amplitude Mask'),
        (resp_movement_mask, -2, 'red', 'Movement Mask'),
        (resp_change_mask, -3, 'green', 'Amplitude Change Mask'),
        (resp_sa_mask, 1, 'purple', 'SA Mask'),
    ])
    ax1_twin.autoscale(enable=False, axis='y', tight=True)
    ax1_twin.set_ylim(resp_ylim)
    ax1_twin.grid(False)
    ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'],
                    loc='upper right')
    ax1.set_title(f'Sample {samp_id} - Respiration Component')

    # Row 3: BCG component + its masks.
    ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
    ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
    ax2.set_ylabel('Amplitude')
    ax2.set_xlabel('Time (s)')
    ax2_twin = ax2.twinx()
    _plot_masks(ax2_twin, [
        (bcg_low_amp_mask, -1, 'blue', 'Low Amplitude Mask'),
        (bcg_movement_mask, -2, 'red', 'Movement Mask'),
        (bcg_change_mask, -3, 'green', 'Amplitude Change Mask'),
    ])
    ax2_twin.autoscale(enable=False, axis='y', tight=True)
    ax2_twin.set_ylim(bcg_ylim)
    ax2_twin.set_ylabel('BCG Masks')
    ax2_twin.grid(False)
    ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right')

    # Re-apply each mask axis's pinned y-range whenever its limits change. The _lim_lock flag
    # stops the set_ylim call inside the callback from re-entering the callback.
    pinned_ylims = {ax0_twin: signal_ylim, ax1_twin: resp_ylim, ax2_twin: bcg_ylim}
    for twin_ax in pinned_ylims:
        twin_ax._lim_lock = False

    def on_lims_change(event_ax):
        if getattr(event_ax, '_lim_lock', False):
            return
        try:
            event_ax._lim_lock = True
            if event_ax in pinned_ylims:
                event_ax.set_ylim(*pinned_ylims[event_ax])
        finally:
            event_ax._lim_lock = False

    for twin_ax in pinned_ylims:
        twin_ax.callbacks.connect('ylim_changed', on_lims_change)

    plt.tight_layout()
    plt.show()