231 lines
9.7 KiB
Python
231 lines
9.7 KiB
Python
from matplotlib.axes import Axes
|
|
from matplotlib.gridspec import GridSpec
|
|
from matplotlib.colors import PowerNorm
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as mpatches
|
|
import seaborn as sns
|
|
import numpy as np
|
|
|
|
import utils
|
|
|
|
# NOTE(review): an earlier note here (translated) read "add a with_prediction parameter";
# no such parameter exists in the visible code — confirm whether it is still planned.

# Row index of each channel in the 9-row figure returned by create_psg_bcg_figure(),
# top to bottom.
psg_chn_name2ax = {
    channel: row
    for row, channel in enumerate(
        ("SpO2", "Flow T", "Flow P", "Effort Tho", "Effort Abd", "HR", "resp", "bcg", "Stage")
    )
}

# Row index of each channel in the 2-row figure returned by create_resp_figure().
resp_chn_name2ax = {channel: row for row, channel in enumerate(("resp", "bcg"))}
|
|
|
|
|
|
def create_psg_bcg_figure():
    """Build the empty 9-row figure used to display PSG and BCG channels together.

    Rows are stacked top to bottom in the order of ``psg_chn_name2ax``. Row 6 is
    3 units tall, row 7 is 2 units tall and all other rows are 1 unit. The rows
    touch (``hspace=0``). Every row gets a grid. Every row except the bottom one
    has its x ticks painted white, so only the bottom row shows visible x labels.
    Row 0 starts with y-limits (85, 100).

    Returns:
        (fig, axes): the Figure and a list of its 9 Axes, top to bottom.
    """
    fig = plt.figure(figsize=(12, 8), dpi=100)
    layout = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)

    axes = [fig.add_subplot(layout[row]) for row in range(9)]

    bottom_row = len(axes) - 1
    for row, ax in enumerate(axes):
        ax.grid(True)
        if row == 0:
            ax.set_ylim((85, 100))
        if row != bottom_row:
            # Rows share a visual time axis, so only the bottom row keeps visible labels.
            ax.tick_params(axis='x', colors="white")

    return fig, axes
|
|
|
|
|
|
def create_resp_figure():
    """Build the empty 2-row figure used to display the resp and BCG channels.

    Rows follow ``resp_chn_name2ax`` (row 0 is 3 units tall, row 1 is 2 units).
    The rows touch (``hspace=0``) and both get a grid. As in
    ``create_psg_bcg_figure``, only the upper row has its x ticks painted white.
    The bottom row keeps a visible time axis; previously both rows were
    whitened, which left the figure without readable x labels.

    Returns:
        (fig, axes): the Figure and a list of its 2 Axes, top to bottom.
    """
    fig = plt.figure(figsize=(12, 6), dpi=100)
    layout = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)

    axes = [fig.add_subplot(layout[row]) for row in range(2)]

    for ax in axes:
        ax.grid(True)
    # Hide the x tick labels of every row except the last one.
    for ax in axes[:-1]:
        ax.tick_params(axis='x', colors="white")

    return fig, axes
|
|
|
|
|
|
def plt_signal_label_on_ax(ax: "Axes", signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None):
    """Plot one signal channel over a time window and overlay its event labels.

    Args:
        ax: Axes to draw on.
        signal_data: dict with "data" (sample sequence), "fs" (samples per second;
            used as a slice multiplier and repeat count, so it must be an integer)
            and "name" (legend label).
        segment_start, segment_end: window bounds in seconds. Samples
            ``[segment_start * fs, segment_end * fs)`` are plotted.
        event_mask: label source, or None for no overlay.
            - ``multi_labels=None``: a per-second label sequence. Each signal sample
              whose label equals one of ``event_codes`` is redrawn in
              ``utils.ColorCycle[code]``.
            - ``multi_labels="resp"``: a dict of per-second arrays. Draws
              Resp_LowAmp/Movement/AmpChange_Label on a twin axis, plus SA_Label
              events with height SA_Score_Alpha.
            - ``multi_labels="bcg"``: a dict of per-second arrays. Draws
              BCG_LowAmp/Movement/AmpChange_Label on a twin axis.
        event_codes: label codes to overlay. No overlay is drawn when this is None.
        multi_labels: overlay mode, one of None, "resp" or "bcg".
    """
    signal_fs = signal_data["fs"]
    segment = np.asarray(signal_data["data"][segment_start * signal_fs:segment_end * signal_fs])
    # Exact timestamp (seconds) of each plotted sample. Deriving it from the slice
    # length keeps x and y the same size when the window runs past the recording end.
    time_axis = segment_start + np.arange(len(segment)) / signal_fs

    def per_sample(labels):
        # Labels hold one value per second: take the window's seconds, repeat each
        # value fs times, and trim to the number of plotted samples.
        return np.asarray(labels[segment_start:segment_end]).repeat(signal_fs)[:len(segment)]

    def plot_masks(twin, prefix):
        # Draw the LowAmp / Movement / AmpChange masks at fixed levels -1, -2, -3
        # so they stack below zero on the twin axis.
        for suffix, color, text, level in (("LowAmp", 'blue', 'Low Amplitude Mask', -1),
                                           ("Movement", 'orange', 'Movement Mask', -2),
                                           ("AmpChange", 'green', 'Amplitude Change Mask', -3)):
            twin.plot(time_axis, per_sample(event_mask[f"{prefix}_{suffix}_Label"]) * level,
                      color=color, alpha=0.8, label=text)

    ax.plot(time_axis, segment, color='black', label=signal_data["name"])

    if event_mask is not None and event_codes is not None:
        if multi_labels is None:
            for event_code in event_codes:
                in_event = per_sample(event_mask) == event_code
                # np.where (not multiply-then-blank-zeros) keeps samples whose true
                # value is 0 visible inside the event.
                highlighted = np.where(in_event, segment, np.nan)
                ax.plot(time_axis, highlighted, color=utils.ColorCycle[event_code])
        elif multi_labels == "resp":
            ax.set_ylim(-6, 6)
            twin = ax.twinx()  # secondary y-axis for the label traces
            plot_masks(twin, "Resp")
            for event_code in event_codes:
                sa_labels = per_sample(event_mask["SA_Label"])
                sa_alpha = per_sample(event_mask["SA_Score_Alpha"]).astype(float)
                # Trace height is the score alpha. Samples outside the event, or with
                # alpha 0, become NaN so nothing is drawn there.
                y = np.where((sa_labels == event_code) & (sa_alpha != 0), sa_alpha, np.nan)
                twin.plot(time_axis, y, color=utils.ColorCycle[event_code])
            twin.set_ylim(-4, 5)
        elif multi_labels == "bcg":
            twin = ax.twinx()  # secondary y-axis for the label traces
            plot_masks(twin, "BCG")
            twin.set_ylim(-4, 4)

    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)
|
|
|
|
|
|
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage trace for a time window.

    Args:
        ax: Axes to draw on.
        stage_data: dict with "data" (stage values), "fs" (samples per second; may
            be fractional) and "name" (legend label).
        segment_start, segment_end: window bounds in seconds, the same unit used by
            plt_signal_label_on_ax for the other channels. Previously the bounds were
            used directly as sample indices, which showed the wrong window whenever
            fs != 1.

    The y-axis maps values 1..5 to N3, N2, N1, REM, Awake.
    """
    stage_fs = stage_data["fs"]
    first = int(round(segment_start * stage_fs))
    last = int(round(segment_end * stage_fs))
    stages = np.asarray(stage_data["data"][first:last])
    # Timestamp (seconds) of each stage sample actually returned by the slice.
    time_axis = (first + np.arange(len(stages))) / stage_fs

    ax.plot(time_axis, stages, color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)
|
|
|
|
|
|
def plt_spo2_on_ax(ax: "Axes", spo2_data, segment_start, segment_end):
    """Plot the SpO2 trace for a time window.

    Args:
        ax: Axes to draw on.
        spo2_data: dict with "data" (SpO2 values in %), "fs" (samples per second)
            and "name" (legend label).
        segment_start, segment_end: window bounds in seconds, the same unit used by
            plt_signal_label_on_ax for the other channels.

    The y-range is 85-100 %. It widens to (min - 5, 100) when the finite values in
    the window dip below 85. NaN samples are ignored when taking the minimum; with
    a plain ``.min()``, a single NaN hid any desaturation.
    """
    spo2_fs = spo2_data["fs"]
    first = int(round(segment_start * spo2_fs))
    last = int(round(segment_end * spo2_fs))
    spo2 = np.asarray(spo2_data["data"][first:last])
    # Timestamp (seconds) of each SpO2 sample actually returned by the slice.
    time_axis = (first + np.arange(len(spo2))) / spo2_fs

    ax.plot(time_axis, spo2, color='black', label=spo2_data["name"])

    finite = spo2[np.isfinite(spo2)]
    # An empty or all-NaN window falls back to the default range.
    lowest = finite.min() if finite.size else 85
    if lowest < 85:
        ax.set_ylim((lowest - 5, 100))
    else:
        ax.set_ylim((85, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)
|
|
|
|
|
|
def score_mask2alpha(score_mask):
    """Map integer score values to plotting alphas.

    Score 1 -> 0.9, 2 -> 0.6, 3 -> 0.1. Every other value (<= 0, > 3, non-integer,
    NaN) maps to 0. The result is a float array shaped like ``score_mask``.
    """
    opacity_by_score = {1: 0.9, 2: 0.6, 3: 0.1}
    alphas = np.zeros_like(score_mask, dtype=float)
    for score, opacity in opacity_by_score.items():
        alphas[score_mask == score] = opacity
    return alphas
|
|
|
|
|
|
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
    """Show PSG channels and BCG-derived signals with their labels, one figure per segment.

    Args:
        psg_data: dict of channel dicts keyed "SpO2", "5_class", "Flow T", "Flow P",
            "Effort Tho", "Effort Abd" and "HR", each accepted by the plt_*_on_ax helpers.
        psg_label: per-second event labels overlaid on the flow/effort channels (codes 1-4).
        bcg_data: dict with "resp_signal" and "bcg_signal" channel dicts.
        event_mask: dict of per-second mask arrays read by plt_signal_label_on_ax
            (Resp_*, BCG_*, SA_Label, SA_Score). It is modified in place: every
            Resp_* and BCG_* entry is passed through
            ``utils.none_to_nan_mask(..., 0)`` and "SA_Score_Alpha" is added.
        segment_list: iterable of (segment_start, segment_end) pairs in seconds.
    """
    for mask in event_mask.keys():
        # Prefix test for both families. The BCG keys are named "BCG_..._Label", so
        # the former endswith("BCG_") never matched and BCG masks were left uncleaned.
        if mask.startswith(("Resp_", "BCG_")):
            event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)

    event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
    event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)

    for segment_start, segment_end in segment_list:
        print(f"Drawing segment: {segment_start} to {segment_end} seconds")
        # Build a fresh figure for every segment. Reusing one figure with ax.cla()
        # leaves the twinx() axes from earlier segments behind and resets the grid/tick
        # styling from create_psg_bcg_figure. With many backends, a figure whose window
        # was closed during plt.show() is not displayed again by later show() calls.
        fig, axes = create_psg_bcg_figure()

        plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        for channel in ("Flow T", "Flow P", "Effort Tho", "Effort Abd"):
            plt_signal_label_on_ax(axes[psg_chn_name2ax[channel]], psg_data[channel], segment_start, segment_end,
                                   psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4])

        plt.show()
        plt.close(fig)  # release the figure so non-interactive backends don't accumulate them
        print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
|
|
|
|
|
|
def draw_resp_label(resp_data, resp_label, segment_list):
    """Show the resp and BCG channels with their label overlays, one figure per segment.

    Args:
        resp_data: dict with "resp_signal" and "bcg_signal" channel dicts.
        resp_label: dict of per-second mask arrays. It is modified in place: every
            Resp_* and BCG_* entry is passed through
            ``utils.none_to_nan_mask(..., 0)`` and "Resp_Score_Alpha" is added from
            "Resp_Score".
        segment_list: iterable of (segment_start, segment_end) pairs in seconds.
    """
    for mask in resp_label.keys():
        # BCG_* masks are cleaned too: the "bcg" overlay below reads them, matching
        # the cleanup done in draw_psg_bcg_label.
        if mask.startswith(("Resp_", "BCG_")):
            resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0)

    resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
    # Clean the alpha just computed. The former "Resp_Label_Alpha" key was a typo.
    resp_label["Resp_Score_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Score_Alpha"], 0)
    # NOTE(review): the "resp" overlay in plt_signal_label_on_ax reads "SA_Label" and
    # "SA_Score_Alpha", not "Resp_Score_Alpha" — confirm resp_label provides those keys.

    for segment_start, segment_end in segment_list:
        # Build a fresh figure per segment, so twinx() axes from earlier segments
        # don't linger and a window closed by plt.show() doesn't hide later segments.
        fig, axes = create_resp_figure()

        plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
                               resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
                               resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])

        plt.show()
        plt.close(fig)  # release the figure so non-interactive backends don't accumulate them
|