优化绘图功能,增加双y轴支持并调整图像保存路径
This commit is contained in:
parent
ed4205f5b8
commit
19d476d489
@ -130,7 +130,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
|||||||
psg_label=psg_event_mask,
|
psg_label=psg_event_mask,
|
||||||
bcg_data=bcg_data,
|
bcg_data=bcg_data,
|
||||||
event_mask=event_mask,
|
event_mask=event_mask,
|
||||||
segment_list=segment_list)
|
segment_list=segment_list,
|
||||||
|
save_path=visual_path / f"{samp_id}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -141,8 +142,11 @@ if __name__ == '__main__':
|
|||||||
root_path = Path(conf["root_path"])
|
root_path = Path(conf["root_path"])
|
||||||
mask_path = Path(conf["mask_save_path"])
|
mask_path = Path(conf["mask_save_path"])
|
||||||
save_path = Path(conf["dataset_config"]["dataset_save_path"])
|
save_path = Path(conf["dataset_config"]["dataset_save_path"])
|
||||||
|
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
|
||||||
dataset_config = conf["dataset_config"]
|
dataset_config = conf["dataset_config"]
|
||||||
|
|
||||||
|
visual_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
save_processed_signal_path = save_path / "Signals"
|
save_processed_signal_path = save_path / "Signals"
|
||||||
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
|
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@ -155,12 +159,13 @@ if __name__ == '__main__':
|
|||||||
print(f"select_ids: {select_ids}")
|
print(f"select_ids: {select_ids}")
|
||||||
print(f"root_path: {root_path}")
|
print(f"root_path: {root_path}")
|
||||||
print(f"save_path: {save_path}")
|
print(f"save_path: {save_path}")
|
||||||
|
print(f"visual_path: {visual_path}")
|
||||||
|
|
||||||
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||||
psg_signal_root_path = root_path / "PSG_Aligned"
|
psg_signal_root_path = root_path / "PSG_Aligned"
|
||||||
|
|
||||||
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
# build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
||||||
#
|
|
||||||
# for samp_id in select_ids:
|
for samp_id in select_ids:
|
||||||
# print(f"Processing sample ID: {samp_id}")
|
print(f"Processing sample ID: {samp_id}")
|
||||||
# build_HYS_dataset_segment(samp_id, show=False)
|
build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||||
@ -83,3 +83,4 @@ dataset_config:
|
|||||||
window_sec: 180
|
window_sec: 180
|
||||||
stride_sec: 60
|
stride_sec: 60
|
||||||
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
|
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
|
||||||
|
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
|
|||||||
import matplotlib.patches as mpatches
|
import matplotlib.patches as mpatches
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm.rich import tqdm
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
# 添加with_prediction参数
|
# 添加with_prediction参数
|
||||||
@ -19,7 +19,9 @@ psg_chn_name2ax = {
|
|||||||
"HR": 5,
|
"HR": 5,
|
||||||
"resp": 6,
|
"resp": 6,
|
||||||
"bcg": 7,
|
"bcg": 7,
|
||||||
"Stage": 8
|
"Stage": 8,
|
||||||
|
"resp_twinx": 9,
|
||||||
|
"bcg_twinx": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp_chn_name2ax = {
|
resp_chn_name2ax = {
|
||||||
@ -29,7 +31,7 @@ resp_chn_name2ax = {
|
|||||||
|
|
||||||
|
|
||||||
def create_psg_bcg_figure():
|
def create_psg_bcg_figure():
|
||||||
fig = plt.figure(figsize=(12, 8), dpi=100)
|
fig = plt.figure(figsize=(12, 8), dpi=200)
|
||||||
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
|
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
|
||||||
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
|
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
|
||||||
axes = []
|
axes = []
|
||||||
@ -37,39 +39,41 @@ def create_psg_bcg_figure():
|
|||||||
ax = fig.add_subplot(gs[i])
|
ax = fig.add_subplot(gs[i])
|
||||||
axes.append(ax)
|
axes.append(ax)
|
||||||
|
|
||||||
axes[0].grid(True)
|
axes[psg_chn_name2ax["SpO2"]].grid(True)
|
||||||
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
axes[0].set_ylim((85, 100))
|
axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
|
||||||
axes[0].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
|
||||||
|
|
||||||
axes[1].grid(True)
|
axes[psg_chn_name2ax["Flow T"]].grid(True)
|
||||||
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
axes[1].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
|
||||||
|
|
||||||
axes[2].grid(True)
|
axes[psg_chn_name2ax["Flow P"]].grid(True)
|
||||||
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
axes[2].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
|
||||||
|
|
||||||
axes[3].grid(True)
|
axes[psg_chn_name2ax["Effort Tho"]].grid(True)
|
||||||
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
axes[3].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
|
||||||
|
|
||||||
axes[4].grid(True)
|
axes[psg_chn_name2ax["Effort Abd"]].grid(True)
|
||||||
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
axes[4].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
|
||||||
|
|
||||||
axes[5].grid(True)
|
axes[psg_chn_name2ax["HR"]].grid(True)
|
||||||
axes[5].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
|
||||||
|
|
||||||
axes[6].grid(True)
|
axes[psg_chn_name2ax["resp"]].grid(True)
|
||||||
|
axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
|
||||||
|
axes.append(axes[psg_chn_name2ax["resp"]].twinx())
|
||||||
|
|
||||||
|
axes[psg_chn_name2ax["bcg"]].grid(True)
|
||||||
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
axes[6].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
|
||||||
|
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
|
||||||
|
|
||||||
axes[7].grid(True)
|
|
||||||
# axes[6].xaxis.set_major_formatter(Params.FORMATTER)
|
|
||||||
axes[7].tick_params(axis='x', colors="white")
|
|
||||||
|
|
||||||
axes[8].grid(True)
|
axes[psg_chn_name2ax["Stage"]].grid(True)
|
||||||
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
|
|
||||||
return fig, axes
|
return fig, axes
|
||||||
@ -96,7 +100,7 @@ def create_resp_figure():
|
|||||||
|
|
||||||
|
|
||||||
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
|
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
|
||||||
event_codes: list[int] = None, multi_labels=None):
|
event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
|
||||||
signal_fs = signal_data["fs"]
|
signal_fs = signal_data["fs"]
|
||||||
chn_signal = signal_data["data"]
|
chn_signal = signal_data["data"]
|
||||||
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
|
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
|
||||||
@ -112,7 +116,7 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
|
|||||||
elif multi_labels == "resp" and event_codes is not None:
|
elif multi_labels == "resp" and event_codes is not None:
|
||||||
ax.set_ylim(-6, 6)
|
ax.set_ylim(-6, 6)
|
||||||
# 建立第二个y轴坐标
|
# 建立第二个y轴坐标
|
||||||
ax2 = ax.twinx()
|
ax2.cla()
|
||||||
ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
|
ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
|
||||||
color='blue', alpha=0.8, label='Low Amplitude Mask')
|
color='blue', alpha=0.8, label='Low Amplitude Mask')
|
||||||
ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
|
ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
|
||||||
@ -122,13 +126,15 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
|
|||||||
for event_code in event_codes:
|
for event_code in event_codes:
|
||||||
sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
|
sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
|
||||||
score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs)
|
score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs)
|
||||||
y = (sa_mask * score_mask).astype(float)
|
# y = (sa_mask * score_mask).astype(float)
|
||||||
|
y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float)
|
||||||
np.place(y, y == 0, np.nan)
|
np.place(y, y == 0, np.nan)
|
||||||
ax2.plot(time_axis, y, color=utils.ColorCycle[event_code])
|
ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
|
||||||
|
ax2.plot(time_axis, score_mask, color="orange")
|
||||||
ax2.set_ylim(-4, 5)
|
ax2.set_ylim(-4, 5)
|
||||||
elif multi_labels == "bcg" and event_codes is not None:
|
elif multi_labels == "bcg" and event_codes is not None:
|
||||||
# 建立第二个y轴坐标
|
# 建立第二个y轴坐标
|
||||||
ax2 = ax.twinx()
|
ax2.cla()
|
||||||
ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
|
ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
|
||||||
color='blue', alpha=0.8, label='Low Amplitude Mask')
|
color='blue', alpha=0.8, label='Low Amplitude Mask')
|
||||||
ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
|
ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
|
||||||
@ -177,17 +183,19 @@ def score_mask2alpha(score_mask):
|
|||||||
return alpha_mask
|
return alpha_mask
|
||||||
|
|
||||||
|
|
||||||
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
|
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
|
||||||
|
if save_path is not None:
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
for mask in event_mask.keys():
|
for mask in event_mask.keys():
|
||||||
if mask.startswith("Resp_") or mask.endswith("BCG_"):
|
if mask.startswith("Resp_") or mask.startswith("BCG_"):
|
||||||
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
|
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
|
||||||
|
|
||||||
event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
|
event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
|
||||||
event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
|
event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
|
||||||
|
|
||||||
fig, axes = create_psg_bcg_figure()
|
fig, axes = create_psg_bcg_figure()
|
||||||
for segment_start, segment_end in segment_list:
|
for segment_start, segment_end in tqdm(segment_list):
|
||||||
print(f"Drawing segment: {segment_start} to {segment_end} seconds")
|
|
||||||
for ax in axes:
|
for ax in axes:
|
||||||
ax.cla()
|
ax.cla()
|
||||||
|
|
||||||
@ -203,13 +211,17 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
|
|||||||
psg_label, event_codes=[1, 2, 3, 4])
|
psg_label, event_codes=[1, 2, 3, 4])
|
||||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
|
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
|
||||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
|
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
|
||||||
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4])
|
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
|
||||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
|
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
|
||||||
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4])
|
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])
|
||||||
plt.show()
|
|
||||||
print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if save_path is not None:
|
||||||
|
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
|
||||||
|
tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
||||||
|
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
||||||
|
|
||||||
def draw_resp_label(resp_data, resp_label, segment_list):
|
def draw_resp_label(resp_data, resp_label, segment_list):
|
||||||
for mask in resp_label.keys():
|
for mask in resp_label.keys():
|
||||||
if mask.startswith("Resp_"):
|
if mask.startswith("Resp_"):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user