DataPrepare/draw_tools/draw_label.py

338 lines
15 KiB
Python

from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
from tqdm.rich import tqdm
import utils
import gc
# Original note (translated): "add a with_prediction parameter".

# Channel name -> index into the axes list returned by create_psg_bcg_figure():
# the nine GridSpec rows from top to bottom, followed by the twin y-axes that
# are appended for the "resp" and "bcg" rows.
psg_bcg_chn_name2ax = {
    channel: position
    for position, channel in enumerate((
        "SpO2",
        "Flow T",
        "Flow P",
        "Effort Tho",
        "Effort Abd",
        "HR",
        "resp",
        "bcg",
        "Stage",
        "resp_twinx",
        "bcg_twinx",
    ))
}

# Channel name -> index into the nine row axes returned by create_psg_figure().
psg_chn_name2ax = {
    channel: position
    for position, channel in enumerate((
        "SpO2",
        "Flow T",
        "Flow P",
        "Effort Tho",
        "Effort Abd",
        "Effort",
        "HR",
        "RRI",
        "Stage",
    ))
}

# Channel name -> row index for a two-row respiration/BCG figure
# (presumably the one built by create_resp_figure(); confirm with callers).
resp_chn_name2ax = {
    channel: position
    for position, channel in enumerate(("resp", "bcg"))
}
def create_psg_bcg_figure():
    """Build the combined PSG + BCG figure.

    Nine rows are stacked in one column (height ratios 1,1,1,1,1,1,3,2,1)
    with no vertical spacing between them. Every row gets a grid. Every row
    except "Stage" has its x ticks drawn in white, so only the bottom row
    shows a readable time axis. SpO2 starts with a 85-100 y range. Twin
    y-axes are attached to the "resp" row and then to the "bcg" row.

    Returns:
        (fig, axes): ``axes`` is a list indexed by ``psg_bcg_chn_name2ax``.
        It holds the 9 row axes followed by the resp twin and the bcg twin.
    """
    fig = plt.figure(figsize=(12, 8), dpi=200)
    layout = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(layout[row]) for row in range(9)]

    # Order matters: the twin axes must be appended resp first, then bcg,
    # so that they land at the "resp_twinx"/"bcg_twinx" indices.
    hidden_tick_channels = ("SpO2", "Flow T", "Flow P", "Effort Tho", "Effort Abd", "HR", "resp", "bcg")
    for channel in hidden_tick_channels:
        row_ax = axes[psg_bcg_chn_name2ax[channel]]
        row_ax.grid(True)
        if channel == "SpO2":
            row_ax.set_ylim((85, 100))
        row_ax.tick_params(axis='x', colors="white")
        if channel in ("resp", "bcg"):
            axes.append(row_ax.twinx())

    axes[psg_bcg_chn_name2ax["Stage"]].grid(True)
    return fig, axes
def create_psg_figure():
    """Build the PSG-only figure: nine equal-height rows in one column.

    Every row gets a grid. Every row except "Stage" has its x ticks drawn in
    white, so only the bottom row shows a readable time axis. SpO2 starts
    with a 85-100 y range.

    Returns:
        (fig, axes): ``axes`` is a list of 9 axes indexed by ``psg_chn_name2ax``.
    """
    fig = plt.figure(figsize=(12, 8), dpi=200)
    layout = GridSpec(9, 1, height_ratios=[1] * 9)
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(layout[row]) for row in range(9)]

    hidden_tick_channels = ("SpO2", "Flow T", "Flow P", "Effort Tho", "Effort Abd", "Effort", "HR", "RRI")
    for channel in hidden_tick_channels:
        row_ax = axes[psg_chn_name2ax[channel]]
        row_ax.grid(True)
        if channel == "SpO2":
            row_ax.set_ylim((85, 100))
        row_ax.tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Stage"]].grid(True)
    return fig, axes
def create_resp_figure():
    """Build a two-row figure (height ratio 3:2) with no vertical spacing.

    Both rows get a grid and have their x ticks drawn in white.

    Returns:
        (fig, axes): ``axes`` is a list of the two row axes, top first.
    """
    fig = plt.figure(figsize=(12, 6), dpi=100)
    layout = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(layout[row]) for row in range(2)]
    for row_ax in axes:
        row_ax.grid(True)
        row_ax.tick_params(axis='x', colors="white")
    return fig, axes
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
    """Plot one channel over a segment and highlight labelled events on it.

    Args:
        ax: Axes that receives the signal trace.
        signal_data: Dict with "data" (sample array), "fs" (samples per
            second; it is used in slice indices, so it must be an integer)
            and "name" (legend label).
        segment_start: Segment start in seconds.
        segment_end: Segment end in seconds.
        event_mask: Event labels with one value per second. When
            ``multi_labels`` is None this is a single array. For "resp" it is
            a dict with keys "Resp_LowAmp_Label", "Resp_Movement_Label",
            "Resp_AmpChange_Label", "SA_Label" and "SA_Score". For "bcg" it
            is a dict with keys "BCG_LowAmp_Label", "BCG_Movement_Label" and
            "BCG_AmpChange_Label".
        event_codes: Event codes to highlight. Each one is drawn in
            ``utils.ColorCycle[code]``. Nothing is highlighted when None.
        multi_labels: None, "resp" or "bcg". Selects how ``event_mask`` is read.
        ax2: Twin axes that is cleared and redrawn with the quality masks.
            It is required when ``multi_labels`` is "resp" or "bcg".

    Raises:
        ValueError: ``multi_labels`` is "resp"/"bcg" (with a mask and event
            codes) but ``ax2`` is None.
    """
    signal_fs = signal_data["fs"]
    first_sample = segment_start * signal_fs
    last_sample = segment_end * signal_fs
    segment = signal_data["data"][first_sample:last_sample]
    # Sample k belongs at segment_start + k / fs. endpoint=False keeps that
    # exact spacing; an inclusive endpoint stretches the axis by one sample.
    time_axis = np.linspace(segment_start, segment_end, last_sample - first_sample, endpoint=False)
    ax.plot(time_axis, segment, color='black', label=signal_data["name"])

    def per_sample(mask_per_second):
        # Masks hold one value per second; repeat them up to the signal rate.
        return mask_per_second[segment_start:segment_end].repeat(signal_fs)

    def highlight_events(labels):
        for event_code in event_codes:
            # NaN outside the event hides the overlay there. Unlike
            # "multiply by mask, then turn zeros into NaN", a genuine 0-valued
            # sample inside the event is still drawn.
            y = np.where(labels == event_code, segment, np.nan)
            ax.plot(time_axis, y, color=utils.ColorCycle[event_code])

    def plot_quality_masks(prefix):
        if ax2 is None:
            raise ValueError(f"multi_labels={multi_labels!r} requires a secondary axes via ax2")
        ax2.cla()
        # Each quality mask is drawn at its own negative level so they do not overlap.
        for mask_kind, level, colour, legend in (
            ("LowAmp", 1, 'blue', 'Low Amplitude Mask'),
            ("Movement", 2, 'orange', 'Movement Mask'),
            ("AmpChange", 3, 'green', 'Amplitude Change Mask'),
        ):
            ax2.plot(time_axis, per_sample(event_mask[f"{prefix}_{mask_kind}_Label"]) * -level,
                     color=colour, alpha=0.8, label=legend)

    if event_mask is not None and event_codes is not None:
        if multi_labels is None:
            highlight_events(per_sample(event_mask))
        elif multi_labels == "resp":
            ax.set_ylim(-6, 6)
            plot_quality_masks("Resp")
            highlight_events(per_sample(event_mask["SA_Label"]))
            # The score does not depend on the event code, so plot it once.
            ax2.plot(time_axis, per_sample(event_mask["SA_Score"]), color="orange")
            ax2.set_ylim(-4, 5)
        elif multi_labels == "bcg":
            plot_quality_masks("BCG")
            ax2.set_ylim(-4, 4)

    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Draw the sleep-stage trace for one segment.

    Args:
        ax: Target axes.
        stage_data: Dict with "data" (stage codes), "fs" (samples per second,
            used in slice indices, so it must be an integer) and "name"
            (legend label).
        segment_start: Segment start in seconds.
        segment_end: Segment end in seconds.

    The y axis shows codes 1..5 labelled N3, N2, N1, REM and Awake.
    """
    stage_fs = stage_data["fs"]
    first_sample = segment_start * stage_fs
    last_sample = segment_end * stage_fs
    # endpoint=False places sample k at segment_start + k / fs (true 1/fs spacing).
    time_axis = np.linspace(segment_start, segment_end, last_sample - first_sample, endpoint=False)
    ax.plot(time_axis, stage_data["data"][first_sample:last_sample], color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
    """Draw the SpO2 trace for one segment.

    Args:
        ax: Target axes.
        spo2_data: Dict with "data" (SpO2 samples), "fs" (samples per second,
            used in slice indices, so it must be an integer) and "name"
            (legend label).
        segment_start: Segment start in seconds.
        segment_end: Segment end in seconds.

    The y range is 85-100. If any finite reading in the segment is below 85,
    the lower limit drops to (lowest finite reading - 5). NaN/inf samples are
    ignored when choosing the range, and an empty segment keeps 85-100.
    """
    spo2_fs = spo2_data["fs"]
    first_sample = segment_start * spo2_fs
    last_sample = segment_end * spo2_fs
    segment = spo2_data["data"][first_sample:last_sample]
    # endpoint=False places sample k at segment_start + k / fs (true 1/fs spacing).
    time_axis = np.linspace(segment_start, segment_end, last_sample - first_sample, endpoint=False)
    ax.plot(time_axis, segment, color='black', label=spo2_data["name"])

    # A plain .min() returns NaN when any sample is NaN. That made the
    # "< 85" test always False and hid real desaturations. It also raised on
    # an empty segment.
    finite = segment[np.isfinite(segment)]
    lowest = finite.min() if finite.size else None
    if lowest is not None and lowest < 85:
        ax.set_ylim((lowest - 5, 100))
    else:
        ax.set_ylim((85, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)
def score_mask2alpha(score_mask):
    """Map a score mask to per-element alpha (opacity) values.

    Score 1 -> 0.9, 2 -> 0.6, 3 -> 0.1. Any other value, including scores
    <= 0 and NaN, maps to 0.0.

    Returns:
        A float array with the same shape as ``score_mask``.
    """
    alpha_for_score = ((1, 0.9), (2, 0.6), (3, 0.1))
    alphas = np.zeros_like(score_mask, dtype=float)
    alphas[score_mask <= 0] = 0
    for score, alpha in alpha_for_score:
        alphas[score_mask == score] = alpha
    return alphas
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
                       multi_p=None, multi_task_id=None):
    """Render one PNG per segment with the PSG channels, PSG labels and BCG signals.

    Args:
        psg_data: Dict of channel dicts (keys "data", "fs", "name") keyed by
            channel name. Reads "SpO2", "5_class", "Flow T", "Flow P",
            "Effort Tho", "Effort Abd" and "HR".
        psg_label: Event-label array, one value per second. Codes 1-4 are
            highlighted on the flow and effort channels.
        bcg_data: Dict with "resp_signal" and "bcg_signal" channel dicts.
        event_mask: Dict of BCG-derived masks. Keys starting with "Resp_" or
            "BCG_" and the "SA_Score" key are passed through
            ``utils.none_to_nan_mask(..., 0)`` before plotting. The caller's
            dict is not modified; a shallow copy is converted instead.
        segment_list: Sequence of (segment_start, segment_end) pairs in seconds.
        save_path: Output directory, a path-like object supporting ``mkdir``,
            ``/`` and ``.name``. It is created if missing. Nothing is saved
            when None.
        verbose: Currently unused; kept for interface compatibility.
        multi_p: Optional shared mapping for progress reporting.
            ``multi_p[multi_task_id]`` is overwritten after every segment.
        multi_task_id: Key under which progress is written to ``multi_p``.
    """
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    # Convert a shallow copy so the caller's event_mask is not rewritten in place.
    event_mask = dict(event_mask)
    for key in list(event_mask):
        if key.startswith(("Resp_", "BCG_")):
            event_mask[key] = utils.none_to_nan_mask(event_mask[key], 0)
    event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0)

    # save_path may be None; the original save_path.name crashed when progress was reported.
    target_name = save_path.name if save_path is not None else None
    progress_desc = f"task_id:{multi_task_id} drawing {target_name}"
    labelled_channels = ("Flow T", "Flow P", "Effort Tho", "Effort Abd")
    twin_axes_names = ("resp_twinx", "bcg_twinx")

    fig, axes = create_psg_bcg_figure()
    try:
        for i, (segment_start, segment_end) in enumerate(segment_list):
            for ax in axes:
                ax.cla()
            # Axes.cla() also resets grid lines and tick parameters. Without
            # this step the styling applied when the figure was created never
            # reaches the saved images.
            for name, index in psg_bcg_chn_name2ax.items():
                if name in twin_axes_names:
                    continue
                axes[index].grid(True)
                if name != "Stage":
                    axes[index].tick_params(axis='x', colors="white")

            plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
            plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
            for channel in labelled_channels:
                plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax[channel]], psg_data[channel], segment_start,
                                       segment_end, psg_label, event_codes=[1, 2, 3, 4])
            plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
            plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start,
                                   segment_end, event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
                                   ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]])
            plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start,
                                   segment_end, event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
                                   ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]])

            if save_path is not None:
                fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
            if multi_p is not None:
                multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": progress_desc}
    finally:
        # Release the figure even if a segment fails, so repeated calls
        # (e.g. in worker processes) do not accumulate open figures.
        plt.close(fig)
        plt.close('all')
        gc.collect()
def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True,
                   multi_p=None, multi_task_id=None):
    """Render one PNG per segment with the PSG channels and PSG event labels.

    Args:
        psg_data: Dict of channel dicts (keys "data", "fs", "name") keyed by
            channel name. Reads "SpO2", "5_class", "Flow T", "Flow P",
            "Effort Tho", "Effort Abd", "Effort", "HR" and "RRI".
        psg_label: Event-label array, one value per second. Codes 1-4 are
            highlighted on the flow and effort channels.
        segment_list: Sequence of (segment_start, segment_end) pairs in seconds.
        save_path: Output directory, a path-like object supporting ``mkdir``,
            ``/`` and ``.name``. It is created if missing. Nothing is saved
            when None.
        verbose: When true and ``multi_p`` is None, print each channel's
            length and sampling rate, plus the label length, before drawing.
        multi_p: Optional shared mapping for progress reporting.
            ``multi_p[multi_task_id]`` is overwritten after every segment.
        multi_task_id: Key under which progress is written to ``multi_p``.
    """
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    if verbose and multi_p is None:
        # Quick sanity check: every channel's length and sampling rate, then the label length.
        for chn_name, channel in psg_data.items():
            print(f"{chn_name} data length: {len(channel['data'])}, fs: {channel['fs']}")
        print(f"psg_label length: {len(psg_label)}")

    # save_path may be None; the original save_path.name crashed when progress was reported.
    target_name = save_path.name if save_path is not None else None
    progress_desc = f"task_id:{multi_task_id} drawing {target_name}"
    labelled_channels = ("Flow T", "Flow P", "Effort Tho", "Effort Abd", "Effort")

    fig, axes = create_psg_figure()
    try:
        for i, (segment_start, segment_end) in enumerate(segment_list):
            for ax in axes:
                ax.cla()
            # Axes.cla() also resets grid lines and tick parameters, so
            # re-apply the styling the figure was created with.
            for name, index in psg_chn_name2ax.items():
                axes[index].grid(True)
                if name != "Stage":
                    axes[index].tick_params(axis='x', colors="white")

            plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
            plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
            for channel in labelled_channels:
                plt_signal_label_on_ax(axes[psg_chn_name2ax[channel]], psg_data[channel], segment_start,
                                       segment_end, psg_label, event_codes=[1, 2, 3, 4])
            for channel in ("HR", "RRI"):
                plt_signal_label_on_ax(axes[psg_chn_name2ax[channel]], psg_data[channel], segment_start,
                                       segment_end)

            if save_path is not None:
                fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
            if multi_p is not None:
                multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": progress_desc}
    finally:
        # Release the figure even if a segment fails.
        plt.close(fig)
        plt.close('all')
        gc.collect()