"""Plotting helpers for BCG / resp signal metric figures.

Each figure shows, for each signal, an amplitude trace (movement and
low-amplitude regions shaded) next to an amplitude-by-duration statistics
heatmap.
"""
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.colors import PowerNorm
from matplotlib.gridspec import GridSpec


def _build_cell_annotations(confusion_matrix, segment_count_matrix,
                            confusion_matrix_percent, n_rows, n_cols):
    """Return ``(text_matrix, percent_matrix)`` for the heatmap body cells.

    Each annotation reads ``"[segment count]duration\\n(percent%)"``; the
    percent matrix drives the heatmap colours.
    """
    text_matrix = np.empty((n_rows, n_cols), dtype=object)
    percent_matrix = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        for j in range(n_cols):
            val = confusion_matrix.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent
    return text_matrix, percent_matrix


def draw_ax_confusion_matrix(ax: Axes, confusion_matrix, segment_count_matrix,
                             confusion_matrix_percent, valid_signal_length,
                             total_duration, time_labels, amp_labels,
                             signal_type=''):
    """Draw the amplitude x duration statistics heatmap on ``ax``.

    Args:
        ax: Target axes.
        confusion_matrix: DataFrame of durations. The first
            ``len(time_labels)`` columns are read positionally for each
            amplitude row, and a ``'总计'`` column holds the per-row totals.
        segment_count_matrix: 2-D array of segment counts indexed
            ``[amp_row, time_col]``.
        confusion_matrix_percent: DataFrame laid out like ``confusion_matrix``
            holding percentages (also with a ``'总计'`` column).
        valid_signal_length: Total valid duration shown in the corner cell.
        total_duration: Denominator for the column and corner percentages.
        time_labels: Duration-bin labels (x axis, drawn on top).
        amp_labels: Amplitude-bin labels (y axis).
        signal_type: Currently unused; kept for call compatibility.
    """
    n_rows, n_cols = len(amp_labels), len(time_labels)
    text_matrix, percent_matrix = _build_cell_annotations(
        confusion_matrix, segment_count_matrix, confusion_matrix_percent,
        n_rows, n_cols)

    # Square-root (gamma=0.5) colour scaling stretches the low end so small
    # percentages are still visually distinguishable.
    heatmap_ax = sns.heatmap(
        percent_matrix,
        annot=text_matrix,
        fmt='',
        xticklabels=time_labels,
        yticklabels=amp_labels,
        cmap='YlGnBu',
        ax=ax,
        vmin=0,
        vmax=100,
        norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
        cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15},
    )

    # Row statistics, drawn as an extra column to the right of the heatmap.
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(n_cols + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(n_cols + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
                ha='center', va='center')

    # Column statistics, drawn as an extra row below the heatmap.
    col_sums = segment_count_matrix.sum(axis=0)
    # NOTE(review): this sum also includes the '总计' column; zip() below stops
    # at the shorter of the two sequences, so presumably that entry is simply
    # never drawn — confirm segment_count_matrix has exactly n_cols columns.
    col_percents = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, n_rows + 0.5, "[各时长片段数]\n(有效区间百分比%)",
            ha='center', va='center', rotation=0)
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, n_rows + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
                ha='center', va='center', rotation=0)

    # Move the x axis (ticks and label) to the top.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    cbar = heatmap_ax.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Key explaining the in-cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)", ha='left', va='top',
            bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Grand total in the bottom-right corner: total segments, valid duration
    # and its share of the total duration. Duration is int-formatted like
    # every other duration cell in the table.
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length / total_duration * 100
    ax.text(n_cols + 0.5, n_rows + 0.5,
            f"[{int(total_segments)}]{int(valid_signal_length)}\n({total_percent:.2f}%)",
            ha='center', va='center')


def draw_ax_amp(ax, signal_name, original_times, origin_signal,
                no_movement_signal, mav_values, movement_position_list,
                low_amp_position_list, signal_second_length, aml_list=None,
                ylim=(-2000, 2000)):
    """Plot a raw signal with shaded movement / low-amplitude regions.

    Args:
        ax: Target axes.
        signal_name: Used in the title as ``'<name> Signal'``.
        original_times: x values for ``origin_signal``.
        origin_signal: Raw samples to plot.
        no_movement_signal: Currently unused; kept for call compatibility.
        mav_values: MAV curve, plotted against ``linspace(0, len, len)``.
        movement_position_list: ``(start, end)`` pairs shaded red.
        low_amp_position_list: ``(start, end)`` pairs shaded blue.
        signal_second_length: Right end of the AML reference lines.
        aml_list: Optional amplitude levels drawn as dashed horizontal lines.
        ylim: y-axis limits; ``None`` leaves matplotlib's autoscaling.
    """
    # Shared between the shaded spans and their legend swatches so the legend
    # colour matches what is drawn.
    movement_alpha = 0.3
    low_amp_alpha = 0.2

    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # NOTE(review): spans are drawn in x-data units but bounded by the sample
    # count len(original_times); confirm which unit the position lists use.
    n_samples = len(original_times)
    for start, end in movement_position_list:
        if start < n_samples and end < n_samples:
            ax.axvspan(start, end, color='red', alpha=movement_alpha)
    for start, end in low_amp_position_list:
        if start < n_samples and end < n_samples:
            ax.axvspan(start, end, color='blue', alpha=low_amp_alpha)

    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            # Levels beyond the third reuse the last colour.
            ax.hlines(aml, 0, signal_second_length,
                      color=color_map[min(i, len(color_map) - 1)],
                      linestyle='dashed', linewidth=2, alpha=0.5,
                      label=f'{aml} aml')

    # NOTE(review): linspace(0, N, N) spaces points N/(N-1) apart; if
    # mav_values holds one value per second, np.arange(N) may be intended.
    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values,
            color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    if ylim is not None:
        ax.set_ylim(ylim)

    # Prepend swatches for the shaded regions to the auto-collected entries.
    red_patch = mpatches.Patch(color='red', alpha=movement_alpha, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=low_amp_alpha, label='Low Amplitude Area')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right', bbox_to_anchor=(1, 1.4), framealpha=0.5)

    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')
    ax.grid(True, linestyle='--', alpha=0.7)


def draw_signal_metrics(bcg_origin_signal, resp_origin_signal,
                        bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate,
                        bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values,
                        bcg_statistic_info, resp_statistic_info, signal_info,
                        show=False, save_path=None):
    """Build the 2x2 BCG/resp metrics figure, then optionally save and/or show it.

    The top row shows BCG, the bottom row resp. The left column shows the
    amplitude trace and the right column the statistics heatmap.
    ``*_statistic_info`` are sequences unpacked positionally into
    ``draw_ax_confusion_matrix`` (confusion_matrix, segment_count_matrix,
    confusion_matrix_percent, valid_signal_length, total_duration,
    time_labels, amp_labels).

    The figure is always closed before returning, even if drawing fails, so
    repeated calls do not accumulate open figures.
    """
    fig = plt.figure(figsize=(18, 10))
    try:
        gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

        # NOTE(review): both time axes use the BCG-derived duration, and
        # resp_sampling_rate is unused — assumes both signals span the same time.
        signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
        bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
        resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))

        param_names = ['confusion_matrix', 'segment_count_matrix',
                       'confusion_matrix_percent', 'valid_signal_length',
                       'total_duration', 'time_labels', 'amp_labels']

        ax1 = fig.add_subplot(gs[0, 0])
        draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times,
                    origin_signal=bcg_origin_signal,
                    no_movement_signal=bcg_no_movement_signal,
                    mav_values=bcg_mav_values,
                    movement_position_list=bcg_movement_position_list,
                    low_amp_position_list=bcg_low_amp_position_list,
                    signal_second_length=signal_second_length,
                    aml_list=[200, 500, 1000])

        ax2 = fig.add_subplot(gs[0, 1])
        draw_ax_confusion_matrix(ax=ax2, **dict(zip(param_names, bcg_statistic_info)))

        ax3 = fig.add_subplot(gs[1, 0], sharex=ax1)
        draw_ax_amp(ax=ax3, signal_name='RSEP', original_times=resp_origin_times,
                    origin_signal=resp_origin_signal,
                    no_movement_signal=resp_no_movement_signal,
                    mav_values=resp_mav_values,
                    movement_position_list=resp_movement_position_list,
                    low_amp_position_list=resp_low_amp_position_list,
                    signal_second_length=signal_second_length,
                    aml_list=[100, 300, 500])

        ax4 = fig.add_subplot(gs[1, 1])
        draw_ax_confusion_matrix(ax=ax4, **dict(zip(param_names, resp_statistic_info)))

        fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

        if save_path is not None:
            fig.savefig(save_path, dpi=300)
        if show:
            plt.show()
    finally:
        # Close this specific figure (not just "the current one") even on error.
        plt.close(fig)