优化绘图功能,增加双y轴支持并调整图像保存路径

This commit is contained in:
marques 2025-11-17 08:05:42 +08:00
parent ed4205f5b8
commit 19d476d489
3 changed files with 59 additions and 41 deletions

View File

@ -130,7 +130,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
psg_label=psg_event_mask,
bcg_data=bcg_data,
event_mask=event_mask,
segment_list=segment_list)
segment_list=segment_list,
save_path=visual_path / f"{samp_id}")
if __name__ == '__main__':
@ -141,8 +142,11 @@ if __name__ == '__main__':
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
@ -155,12 +159,13 @@ if __name__ == '__main__':
print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
#
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False)
# build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
for samp_id in select_ids:
print(f"Processing sample ID: {samp_id}")
build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)

View File

@ -83,3 +83,4 @@ dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization

View File

@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
from tqdm.rich import tqdm
import utils
# 添加with_prediction参数
@ -19,7 +19,9 @@ psg_chn_name2ax = {
"HR": 5,
"resp": 6,
"bcg": 7,
"Stage": 8
"Stage": 8,
"resp_twinx": 9,
"bcg_twinx": 10,
}
resp_chn_name2ax = {
@ -29,7 +31,7 @@ resp_chn_name2ax = {
def create_psg_bcg_figure():
fig = plt.figure(figsize=(12, 8), dpi=100)
fig = plt.figure(figsize=(12, 8), dpi=200)
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = []
@ -37,39 +39,41 @@ def create_psg_bcg_figure():
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[0].grid(True)
axes[psg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[0].set_ylim((85, 100))
axes[0].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
axes[1].grid(True)
axes[psg_chn_name2ax["Flow T"]].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[1].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
axes[2].grid(True)
axes[psg_chn_name2ax["Flow P"]].grid(True)
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
axes[2].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
axes[3].grid(True)
axes[psg_chn_name2ax["Effort Tho"]].grid(True)
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
axes[3].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
axes[4].grid(True)
axes[psg_chn_name2ax["Effort Abd"]].grid(True)
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[4].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[5].grid(True)
axes[5].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["HR"]].grid(True)
axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[6].grid(True)
axes[psg_chn_name2ax["resp"]].grid(True)
axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_chn_name2ax["resp"]].twinx())
axes[psg_chn_name2ax["bcg"]].grid(True)
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
axes[6].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
axes[7].grid(True)
# axes[6].xaxis.set_major_formatter(Params.FORMATTER)
axes[7].tick_params(axis='x', colors="white")
axes[8].grid(True)
axes[psg_chn_name2ax["Stage"]].grid(True)
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes
@ -96,7 +100,7 @@ def create_resp_figure():
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
event_codes: list[int] = None, multi_labels=None):
event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
signal_fs = signal_data["fs"]
chn_signal = signal_data["data"]
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
@ -112,7 +116,7 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
elif multi_labels == "resp" and event_codes is not None:
ax.set_ylim(-6, 6)
# 建立第二个y轴坐标
ax2 = ax.twinx()
ax2.cla()
ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
color='blue', alpha=0.8, label='Low Amplitude Mask')
ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
@ -122,13 +126,15 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
for event_code in event_codes:
sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs)
y = (sa_mask * score_mask).astype(float)
# y = (sa_mask * score_mask).astype(float)
y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float)
np.place(y, y == 0, np.nan)
ax2.plot(time_axis, y, color=utils.ColorCycle[event_code])
ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
ax2.plot(time_axis, score_mask, color="orange")
ax2.set_ylim(-4, 5)
elif multi_labels == "bcg" and event_codes is not None:
# 建立第二个y轴坐标
ax2 = ax.twinx()
ax2.cla()
ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
color='blue', alpha=0.8, label='Low Amplitude Mask')
ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
@ -177,17 +183,19 @@ def score_mask2alpha(score_mask):
return alpha_mask
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
if save_path is not None:
save_path.mkdir(parents=True, exist_ok=True)
for mask in event_mask.keys():
if mask.startswith("Resp_") or mask.endswith("BCG_"):
if mask.startswith("Resp_") or mask.startswith("BCG_"):
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
fig, axes = create_psg_bcg_figure()
for segment_start, segment_end in segment_list:
print(f"Drawing segment: {segment_start} to {segment_end} seconds")
for segment_start, segment_end in tqdm(segment_list):
for ax in axes:
ax.cla()
@ -203,13 +211,17 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4])
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4])
plt.show()
print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
def draw_resp_label(resp_data, resp_label, segment_list):
for mask in resp_label.keys():
if mask.startswith("Resp_"):