DataPrepare/draw_tools/draw_label.py

231 lines
9.7 KiB
Python

from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
import utils
# Add with_prediction parameter (translated from the original note).
# Row index of each channel in the 9-row figure built by create_psg_bcg_figure().
psg_chn_name2ax = {
    channel: row
    for row, channel in enumerate(
        ("SpO2", "Flow T", "Flow P", "Effort Tho", "Effort Abd", "HR", "resp", "bcg", "Stage")
    )
}
# Row index of each channel in the 2-row figure built by create_resp_figure().
resp_chn_name2ax = {channel: row for row, channel in enumerate(("resp", "bcg"))}
def create_psg_bcg_figure():
    """Create the 9-row PSG + BCG overview figure.

    Rows, top to bottom, follow ``psg_chn_name2ax``: SpO2, Flow T, Flow P,
    Effort Tho, Effort Abd, HR, resp, bcg, Stage. The resp and bcg rows are
    taller (height ratios 3 and 2) and rows are stacked with no spacing.
    Every row shows a grid. Every row except the bottom one draws its x ticks
    and tick labels in white, so only the bottom row's time axis is visible
    on a white background. The SpO2 row starts with a 85-100 y range.

    Returns:
        (fig, axes): the Figure and a list of its 9 Axes, top to bottom.
    """
    fig = plt.figure(figsize=(12, 8), dpi=100)
    layout = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(layout[row]) for row in range(9)]

    last_row = len(axes) - 1
    for row, ax in enumerate(axes):
        ax.grid(True)
        if row != last_row:
            # Rows share the bottom row's time axis visually; hide their own x labels.
            ax.tick_params(axis='x', colors="white")
    axes[0].set_ylim((85, 100))
    return fig, axes
def create_resp_figure():
    """Create the 2-row resp / BCG figure.

    Row 0 (height ratio 3) is the resp channel and row 1 (height ratio 2) is
    the BCG channel, matching ``resp_chn_name2ax``. Rows are stacked with no
    spacing and both show a grid. The top row draws its x ticks and tick
    labels in white so that only the bottom row's time axis is visible. This
    is the same convention as ``create_psg_bcg_figure``.

    Returns:
        (fig, axes): the Figure and a list of its 2 Axes, top to bottom.
    """
    fig = plt.figure(figsize=(12, 6), dpi=100)
    gs = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(gs[i]) for i in range(2)]
    for ax in axes:
        ax.grid(True)
    # Only the upper row hides its x labels. The bottom row keeps a readable
    # time axis; previously both rows were whitened, leaving no time axis at all.
    axes[0].tick_params(axis='x', colors="white")
    return fig, axes
# (mask-name infix, y level, colour, legend label) for the signal-quality masks
# drawn on the secondary y axis by plt_signal_label_on_ax.
_QUALITY_MASK_STYLES = (
    ("LowAmp", -1, 'blue', 'Low Amplitude Mask'),
    ("Movement", -2, 'orange', 'Movement Mask'),
    ("AmpChange", -3, 'green', 'Amplitude Change Mask'),
)


def _per_second_to_samples(mask, segment_start, segment_end, fs, n_samples):
    """Slice a per-second mask to the segment and repeat it to per-sample resolution."""
    # assumes mask holds one value per second and fs is an integer rate — TODO confirm
    per_second = np.asarray(mask[segment_start:segment_end])
    # Trim so the result lines up with the (possibly end-truncated) signal slice.
    return per_second.repeat(int(round(fs)))[:n_samples]


def _plot_quality_masks(ax2, time_axis, event_mask, prefix, per_sample):
    """Plot the ``{prefix}_LowAmp/Movement/AmpChange_Label`` masks at levels -1/-2/-3."""
    for kind, level, color, label in _QUALITY_MASK_STYLES:
        ax2.plot(time_axis, per_sample(event_mask[f"{prefix}_{kind}_Label"]) * level,
                 color=color, alpha=0.8, label=label)


def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None):
    """Plot one signal channel over a segment and optionally overlay event labels.

    Args:
        ax: Axes to draw on.
        signal_data: dict with "fs" (sampling rate), "data" (samples) and "name"
            (legend label).
        segment_start, segment_end: segment bounds in seconds. The signal is sliced
            as ``data[segment_start * fs:segment_end * fs]``.
        event_mask: nothing is overlaid when this is None. When ``multi_labels`` is
            None it is a per-second label array. Otherwise it is a dict of
            per-second masks keyed by name.
        event_codes: label codes to highlight, each drawn in
            ``utils.ColorCycle[code]``. No overlay is drawn when this is None.
        multi_labels: None, "resp" or "bcg". "resp" and "bcg" draw the
            Resp_* / BCG_* quality masks on a secondary y axis. "resp" also draws
            the "SA_Label" events scaled by "SA_Score_Alpha".
    """
    signal_fs = signal_data["fs"]
    # Integer slice bounds even when fs is stored as a float (e.g. 100.0).
    start_idx = int(round(segment_start * signal_fs))
    end_idx = int(round(segment_end * signal_fs))
    segment_signal = np.asarray(signal_data["data"][start_idx:end_idx])
    n_samples = len(segment_signal)
    # Sample k of the slice sits at (start_idx + k) / fs seconds.
    time_axis = (start_idx + np.arange(n_samples)) / signal_fs
    ax.plot(time_axis, segment_signal, color='black', label=signal_data["name"])

    def per_sample(mask):
        return _per_second_to_samples(mask, segment_start, segment_end, signal_fs, n_samples)

    if event_mask is not None and event_codes is not None:
        if multi_labels is None:
            labels = per_sample(event_mask)
            for event_code in event_codes:
                # Keep the signal only where this event is labelled; NaN breaks the
                # line elsewhere. np.where keeps genuine zero-valued samples visible.
                highlighted = np.where(labels == event_code, segment_signal, np.nan)
                ax.plot(time_axis, highlighted, color=utils.ColorCycle[event_code])
        elif multi_labels == "resp":
            ax.set_ylim(-6, 6)
            ax2 = ax.twinx()  # secondary y axis for the mask traces
            _plot_quality_masks(ax2, time_axis, event_mask, "Resp", per_sample)
            score_alpha = per_sample(event_mask["SA_Score_Alpha"])
            sa_labels = per_sample(event_mask["SA_Label"])
            for event_code in event_codes:
                y = ((sa_labels == event_code) * score_alpha).astype(float)
                # Zero means "other event" or "alpha 0": hide those samples.
                np.place(y, y == 0, np.nan)
                ax2.plot(time_axis, y, color=utils.ColorCycle[event_code])
            ax2.set_ylim(-4, 5)
        elif multi_labels == "bcg":
            ax2 = ax.twinx()  # secondary y axis for the mask traces
            _plot_quality_masks(ax2, time_axis, event_mask, "BCG", per_sample)
            ax2.set_ylim(-4, 4)
    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage trace for a segment.

    Args:
        ax: Axes to draw on.
        stage_data: dict with "data" (stage values), "fs" (sampling rate) and
            "name" (legend label).
        segment_start, segment_end: bounds applied directly as sample indices to
            ``stage_data["data"]``. NOTE(review): callers pass seconds, and
            plt_signal_label_on_ax treats them as seconds. The two agree only when
            fs == 1, so confirm the stage rate.

    The y axis is fixed to 0-6, with ticks 1-5 labelled N3, N2, N1, REM, Awake.
    """
    stage_fs = stage_data["fs"]
    segment = stage_data["data"][segment_start:segment_end]
    # Sample k of the slice sits at (segment_start + k) / fs. Using len(segment)
    # also copes with a slice that runs past the end of the recording.
    time_axis = (segment_start + np.arange(len(segment))) / stage_fs
    ax.plot(time_axis, segment, color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end, floor=85, margin=5):
    """Plot the SpO2 trace for a segment with an adaptive lower y limit.

    Args:
        ax: Axes to draw on.
        spo2_data: dict with "data" (samples), "fs" (sampling rate) and "name"
            (legend label).
        segment_start, segment_end: bounds applied directly as sample indices to
            ``spo2_data["data"]``. NOTE(review): callers pass seconds. The two
            agree only when fs == 1, so confirm.
        floor: default lower y limit in percent (upper limit is always 100).
        margin: if the segment dips below ``floor``, the lower limit becomes
            ``segment minimum - margin`` instead.
    """
    spo2_fs = spo2_data["fs"]
    segment = np.asarray(spo2_data["data"][segment_start:segment_end])
    # Sample k of the slice sits at (segment_start + k) / fs.
    time_axis = (segment_start + np.arange(len(segment))) / spo2_fs
    ax.plot(time_axis, segment, color='black', label=spo2_data["name"])
    # Choose the y range from finite samples only. A NaN gap would make .min()
    # return NaN, and an empty slice would make it raise.
    finite = segment[np.isfinite(segment)]
    lowest = finite.min() if finite.size else floor
    if lowest < floor:
        ax.set_ylim((lowest - margin, 100))
    else:
        ax.set_ylim((floor, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)
def score_mask2alpha(score_mask):
    """Map event scores to plot transparency values.

    Score 1 -> 0.9, 2 -> 0.6, 3 -> 0.1. Every other value (including scores
    <= 0) maps to 0. The result is a float array with the same shape as
    ``score_mask``.
    """
    alpha = np.zeros_like(score_mask, dtype=float)
    alpha[score_mask <= 0] = 0
    for score, transparency in ((1, 0.9), (2, 0.6), (3, 0.1)):
        alpha[score_mask == score] = transparency
    return alpha
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
    """Show one PSG + BCG overview figure per segment, blocking on each plt.show().

    Args:
        psg_data: dict of PSG signal dicts keyed "SpO2", "5_class", "Flow T",
            "Flow P", "Effort Tho", "Effort Abd" and "HR".
        psg_label: per-second label array; codes 1-4 are highlighted on the
            Flow/Effort channels.
        bcg_data: dict with "resp_signal" and "bcg_signal" signal dicts.
        event_mask: dict of per-second masks. It is modified in place: "Resp_*"
            and "BCG_*" entries are passed through utils.none_to_nan_mask, and an
            "SA_Score_Alpha" entry is added from "SA_Score".
        segment_list: iterable of (segment_start, segment_end) pairs in seconds.
    """
    for mask_name in list(event_mask):
        # startswith for both prefixes: the BCG masks are named "BCG_..._Label".
        if mask_name.startswith(("Resp_", "BCG_")):
            event_mask[mask_name] = utils.none_to_nan_mask(event_mask[mask_name], 0)
    event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
    event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)

    for segment_start, segment_end in segment_list:
        print(f"Drawing segment: {segment_start} to {segment_end} seconds")
        # A fresh figure per segment has three advantages over reusing one:
        # - Axes.cla() would wipe the grid/tick setup done by the factory.
        # - twinx() axes would pile up across segments.
        # - pyplot stops managing a figure once its window is closed.
        fig, axes = create_psg_bcg_figure()
        plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        for channel in ("Flow T", "Flow P", "Effort Tho", "Effort Abd"):
            plt_signal_label_on_ax(axes[psg_chn_name2ax[channel]], psg_data[channel],
                                   segment_start, segment_end, psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4])
        plt.show()
        plt.close(fig)  # release the figure once it has been shown
        print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
def draw_resp_label(resp_data, resp_label, segment_list):
    """Show one resp / BCG figure per segment, blocking on each plt.show().

    Args:
        resp_data: dict with "resp_signal" and "bcg_signal" signal dicts.
        resp_label: dict of per-second masks. It is modified in place: "Resp_*"
            entries are passed through utils.none_to_nan_mask, and a
            "Resp_Score_Alpha" entry is added from "Resp_Score".
        segment_list: iterable of (segment_start, segment_end) pairs in seconds.

    NOTE(review): the mask keys may not match what the plotting helper reads.
    plt_signal_label_on_ax(multi_labels="resp") reads "SA_Label" and
    "SA_Score_Alpha", and multi_labels="bcg" reads "BCG_*" masks. None of these
    are produced here, so confirm resp_label already carries them.
    """
    for mask_name in list(resp_label):
        if mask_name.startswith("Resp_"):
            resp_label[mask_name] = utils.none_to_nan_mask(resp_label[mask_name], 0)
    resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
    # Post-process the alpha mask just computed. The previous key,
    # "Resp_Label_Alpha", is never created here.
    resp_label["Resp_Score_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Score_Alpha"], 0)

    for segment_start, segment_end in segment_list:
        # A fresh figure per segment avoids two problems:
        # - Axes.cla() would reset the grid/ticks and leave old twinx() axes behind.
        # - a closed window is no longer managed by pyplot.
        fig, axes = create_resp_figure()
        plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
                               resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
                               resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])
        plt.show()
        plt.close(fig)  # release the figure once it has been shown