175 lines
7.8 KiB
Python
175 lines
7.8 KiB
Python
from matplotlib.axes import Axes
|
||
from matplotlib.gridspec import GridSpec
|
||
from matplotlib.colors import PowerNorm
|
||
import seaborn as sns
|
||
import numpy as np
|
||
|
||
|
||
def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Draw an annotated amplitude-by-duration heatmap with marginal totals on ``ax``.

    Each cell is coloured by its percentage and annotated as
    ``[segment count]value`` on the first line and ``(percent%)`` on the second.
    Row totals are written to the right of the grid, column totals below it,
    and the grand total in the bottom-right corner.

    Args:
        ax: Matplotlib axes the heatmap is drawn on.
        confusion_matrix: Table read with ``.iloc[i, j]``, ``['总计']`` and
            ``.sum(axis=0)`` (presumably a pandas DataFrame -- confirm with caller).
        segment_count_matrix: 2-D array read with ``[i, j]`` and ``.sum(...)``.
        confusion_matrix_percent: Table with the same layout as
            ``confusion_matrix`` holding the per-cell percentages.
        valid_signal_length: Value printed in the grand-total cell and divided
            by ``total_duration`` to obtain the overall percentage.
        total_duration: Denominator for the column and overall percentages.
        time_labels: Tick labels for the columns (x axis).
        amp_labels: Tick labels for the rows (y axis).
        signal_type: Not used by this function.
    """
    n_rows, n_cols = len(amp_labels), len(time_labels)

    # Per-cell annotation strings and the numeric values that drive the colours.
    cell_text = np.empty((n_rows, n_cols), dtype=object)
    cell_percent = np.zeros((n_rows, n_cols))
    for row, col in np.ndindex(n_rows, n_cols):
        cell_value = confusion_matrix.iloc[row, col]
        cell_segments = segment_count_matrix[row, col]
        cell_share = confusion_matrix_percent.iloc[row, col]
        cell_text[row, col] = f"[{int(cell_segments)}]{int(cell_value)}\n({cell_share:.2f}%)"
        cell_percent[row, col] = cell_share

    # Square-root colour scale (gamma=0.5) so that small percentages remain
    # distinguishable from zero.
    # NOTE(review): vmin/vmax are passed together with `norm`; confirm that the
    # pinned seaborn/matplotlib versions accept this combination.
    color_norm = PowerNorm(gamma=0.5, vmin=0, vmax=100)
    colorbar_options = {'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15}
    heatmap_ax = sns.heatmap(
        cell_percent,
        annot=cell_text,
        fmt='',
        xticklabels=time_labels,
        yticklabels=amp_labels,
        cmap='YlGnBu',
        ax=ax,
        vmin=0,
        vmax=100,
        norm=color_norm,
        cbar_kws=colorbar_options,
    )

    # Row totals, written one column to the right of the grid.
    row_totals = confusion_matrix['总计']
    row_shares = confusion_matrix_percent['总计']
    ax.text(n_cols + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for row, (row_total, row_share) in enumerate(zip(row_totals, row_shares)):
        ax.text(n_cols + 0.5, row + 0.5, f"{int(row_total)}\n({row_share:.2f}%)",
                ha='center', va='center')

    # Column totals, written one row below the grid.
    column_segments = segment_count_matrix.sum(axis=0)
    column_shares = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, n_rows + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for col, (segments, column_share) in enumerate(zip(column_segments, column_shares)):
        ax.text(col + 0.5, n_rows + 0.5, f"[{int(segments)}]\n({column_share:.2f}%)",
                ha='center', va='center', rotation=0)

    # Put the x axis (ticks and label) on top of the heatmap.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    # Title and axis labels.
    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    # Colour-bar label goes on the right-hand side of the bar.
    heatmap_ax.collections[0].colorbar.ax.yaxis.set_label_position('right')

    # Boxed key explaining the layout of each cell annotation.
    key_box = dict(facecolor='none', edgecolor='black', alpha=0.5)
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=key_box)

    # Grand total: overall segment count and the valid share of the total duration.
    segments_overall = segment_count_matrix.sum()
    share_overall = valid_signal_length / total_duration * 100
    ax.text(n_cols + 0.5, n_rows + 0.5,
            f"[{int(segments_overall)}]{valid_signal_length}\n({share_overall:.2f}%)",
            ha='center', va='center')
|
||
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.patches as mpatches
|
||
|
||
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    """Plot a signal trace with shaded intervals, reference levels and its MAV curve.

    Args:
        ax: Matplotlib axes to draw on.
        signal_name: Name used in the axes title (``'<signal_name> Signal'``).
        original_times: X coordinates for ``origin_signal``.
        origin_signal: Samples drawn as a thin black line.
        no_movement_signal: Not used by this function.
        mav_values: Values drawn as a blue curve labelled ``'2sMAV'``; their x
            coordinates are ``len(mav_values)`` evenly spaced points from 0 to
            ``len(mav_values)``.
        movement_position_list: Iterable of ``(start, end)`` pairs shaded red.
        low_amp_position_list: Iterable of ``(start, end)`` pairs shaded blue.
        signal_second_length: Right-hand end of the horizontal reference lines.
        aml_list: Optional y levels for dashed horizontal lines. The first is
            red, the second orange, and the third and any later ones green.
    """
    # Raw signal trace.
    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # Shaded intervals; a pair is drawn only when both ends are below the
    # number of time points.
    # NOTE(review): the bound is a sample count while the spans are drawn at
    # start/end as x coordinates -- confirm the units of the position lists.
    point_count = len(original_times)
    span_styles = (
        (movement_position_list, 'red', 0.3),
        (low_amp_position_list, 'blue', 0.2),
    )
    for intervals, span_color, span_alpha in span_styles:
        for begin, finish in intervals:
            if begin < point_count and finish < point_count:
                ax.axvspan(begin, finish, color=span_color, alpha=span_alpha)

    # Dashed horizontal reference levels, if any were given.
    if aml_list is not None:
        palette = ['red', 'orange', 'green']
        for index, level in enumerate(aml_list):
            level_color = palette[min(index, 2)]
            ax.hlines(level, 0, signal_second_length, color=level_color,
                      linestyle='dashed', linewidth=2, alpha=0.5, label=f'{level} aml')

    # MAV curve.
    mav_count = len(mav_values)
    mav_x = np.linspace(0, mav_count, mav_count)
    ax.plot(mav_x, mav_values, color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    # Fixed y range.
    ax.set_ylim((-2000, 2000))

    # Legend: proxy patches for the shaded intervals, followed by the entries
    # matplotlib already collected from the labelled artists above.
    # NOTE(review): the movement swatch uses alpha 0.2 while the spans use 0.3.
    movement_swatch = mpatches.Patch(color='red', alpha=0.2, label='Movement Area')
    low_amp_swatch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area')
    existing_handles, existing_labels = ax.get_legend_handles_labels()
    legend_handles = [movement_swatch, low_amp_swatch] + existing_handles
    legend_labels = ['Movement Area', 'Low Amplitude Area'] + existing_labels
    ax.legend(handles=legend_handles, labels=legend_labels, loc='upper right',
              bbox_to_anchor=(1, 1.4), framealpha=0.5)

    # Title and axis labels.
    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')

    # Dashed background grid.
    ax.grid(True, linestyle='--', alpha=0.7)
|
||
|
||
|
||
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):
    """Build a 2x2 figure of signal traces and amplitude-duration matrices.

    The left column holds the BCG and RESP signal traces, which share an x
    axis. The right column holds the matching statistics heatmaps.

    Args:
        bcg_origin_signal: BCG samples. Their count divided by
            ``bcg_sampling_rate`` (floor division) sets the time-axis length
            used for both signals.
        resp_origin_signal: RESP samples, plotted over the same time span.
        bcg_no_movement_signal: Forwarded to ``draw_ax_amp``.
        resp_no_movement_signal: Forwarded to ``draw_ax_amp``.
        bcg_sampling_rate: Divisor used to derive the time-axis length.
        resp_sampling_rate: Not used by this function.
        bcg_movement_position_list: ``(start, end)`` pairs shaded red on the BCG trace.
        bcg_low_amp_position_list: ``(start, end)`` pairs shaded blue on the BCG trace.
        resp_movement_position_list: ``(start, end)`` pairs shaded red on the RESP trace.
        resp_low_amp_position_list: ``(start, end)`` pairs shaded blue on the RESP trace.
        bcg_mav_values: MAV curve for the BCG trace.
        resp_mav_values: MAV curve for the RESP trace.
        bcg_statistic_info: Sequence whose items are, in order,
            ``confusion_matrix``, ``segment_count_matrix``,
            ``confusion_matrix_percent``, ``valid_signal_length``,
            ``total_duration``, ``time_labels`` and ``amp_labels``.
        resp_statistic_info: Same layout as ``bcg_statistic_info``.
        signal_info: Text placed in front of ``' Signal Metrics'`` in the figure title.
        show: When true, display the figure with ``plt.show()``.
        save_path: When not ``None``, save the figure there at 300 dpi.
    """
    fig = plt.figure(figsize=(18, 10))
    try:
        gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

        # Both traces are laid out over the duration derived from the BCG signal.
        signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
        bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
        resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))

        # Keyword names matching the positional layout of the *_statistic_info sequences.
        param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                       'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']

        # Top-left: BCG trace.
        ax1 = fig.add_subplot(gs[0])
        draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                    no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                    movement_position_list=bcg_movement_position_list,
                    low_amp_position_list=bcg_low_amp_position_list,
                    signal_second_length=signal_second_length, aml_list=[200, 500, 1000])

        # Top-right: BCG statistics heatmap.
        ax2 = fig.add_subplot(gs[1])
        params = dict(zip(param_names, bcg_statistic_info))
        params['ax'] = ax2
        draw_ax_confusion_matrix(**params)

        # Bottom-left: RESP trace, sharing the x axis with the BCG trace.
        # The title string previously read 'RSEP', a misspelling of 'RESP'.
        ax3 = fig.add_subplot(gs[2], sharex=ax1)
        draw_ax_amp(ax=ax3, signal_name='RESP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                    no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                    movement_position_list=resp_movement_position_list,
                    low_amp_position_list=resp_low_amp_position_list,
                    signal_second_length=signal_second_length, aml_list=[100, 300, 500])

        # Bottom-right: RESP statistics heatmap.
        ax4 = fig.add_subplot(gs[3])
        params = dict(zip(param_names, resp_statistic_info))
        params['ax'] = ax4
        draw_ax_confusion_matrix(**params)

        # Figure-level title, shifted left so it sits over the trace column.
        fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

        if save_path is not None:
            # Save this figure explicitly rather than whichever one is "current".
            fig.savefig(save_path, dpi=300)

        if show:
            plt.show()
    finally:
        # Always release the figure, even if a drawing step raised, so that
        # repeated calls do not accumulate open figures in pyplot.
        plt.close(fig)