优化绘图功能,增加双y轴支持并调整图像保存路径

This commit is contained in:
marques 2025-11-17 08:05:42 +08:00
parent ed4205f5b8
commit 19d476d489
3 changed files with 59 additions and 41 deletions

View File

@ -130,7 +130,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
psg_label=psg_event_mask, psg_label=psg_event_mask,
bcg_data=bcg_data, bcg_data=bcg_data,
event_mask=event_mask, event_mask=event_mask,
segment_list=segment_list) segment_list=segment_list,
save_path=visual_path / f"{samp_id}")
if __name__ == '__main__': if __name__ == '__main__':
@ -141,8 +142,11 @@ if __name__ == '__main__':
root_path = Path(conf["root_path"]) root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"]) mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"]) save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"] dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals" save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True) save_processed_signal_path.mkdir(parents=True, exist_ok=True)
@ -155,12 +159,13 @@ if __name__ == '__main__':
print(f"select_ids: {select_ids}") print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}") print(f"root_path: {root_path}")
print(f"save_path: {save_path}") print(f"save_path: {save_path}")
print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned" org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned"
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) # build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
#
# for samp_id in select_ids: for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}") print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False) build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)

View File

@ -83,3 +83,4 @@ dataset_config:
window_sec: 180 window_sec: 180
stride_sec: 60 stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization

View File

@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
import matplotlib.patches as mpatches import matplotlib.patches as mpatches
import seaborn as sns import seaborn as sns
import numpy as np import numpy as np
from tqdm.rich import tqdm
import utils import utils
# 添加with_prediction参数 # 添加with_prediction参数
@ -19,7 +19,9 @@ psg_chn_name2ax = {
"HR": 5, "HR": 5,
"resp": 6, "resp": 6,
"bcg": 7, "bcg": 7,
"Stage": 8 "Stage": 8,
"resp_twinx": 9,
"bcg_twinx": 10,
} }
resp_chn_name2ax = { resp_chn_name2ax = {
@ -29,7 +31,7 @@ resp_chn_name2ax = {
def create_psg_bcg_figure(): def create_psg_bcg_figure():
fig = plt.figure(figsize=(12, 8), dpi=100) fig = plt.figure(figsize=(12, 8), dpi=200)
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1]) gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = [] axes = []
@ -37,39 +39,41 @@ def create_psg_bcg_figure():
ax = fig.add_subplot(gs[i]) ax = fig.add_subplot(gs[i])
axes.append(ax) axes.append(ax)
axes[0].grid(True) axes[psg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER) # axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[0].set_ylim((85, 100)) axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
axes[0].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
axes[1].grid(True) axes[psg_chn_name2ax["Flow T"]].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER) # axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[1].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
axes[2].grid(True) axes[psg_chn_name2ax["Flow P"]].grid(True)
# axes[2].xaxis.set_major_formatter(Params.FORMATTER) # axes[2].xaxis.set_major_formatter(Params.FORMATTER)
axes[2].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
axes[3].grid(True) axes[psg_chn_name2ax["Effort Tho"]].grid(True)
# axes[3].xaxis.set_major_formatter(Params.FORMATTER) # axes[3].xaxis.set_major_formatter(Params.FORMATTER)
axes[3].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
axes[4].grid(True) axes[psg_chn_name2ax["Effort Abd"]].grid(True)
# axes[4].xaxis.set_major_formatter(Params.FORMATTER) # axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[4].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[5].grid(True) axes[psg_chn_name2ax["HR"]].grid(True)
axes[5].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[6].grid(True) axes[psg_chn_name2ax["resp"]].grid(True)
axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_chn_name2ax["resp"]].twinx())
axes[psg_chn_name2ax["bcg"]].grid(True)
# axes[5].xaxis.set_major_formatter(Params.FORMATTER) # axes[5].xaxis.set_major_formatter(Params.FORMATTER)
axes[6].tick_params(axis='x', colors="white") axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
axes[7].grid(True)
# axes[6].xaxis.set_major_formatter(Params.FORMATTER)
axes[7].tick_params(axis='x', colors="white")
axes[8].grid(True) axes[psg_chn_name2ax["Stage"]].grid(True)
# axes[7].xaxis.set_major_formatter(Params.FORMATTER) # axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes return fig, axes
@ -96,7 +100,7 @@ def create_resp_figure():
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None, def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
event_codes: list[int] = None, multi_labels=None): event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
signal_fs = signal_data["fs"] signal_fs = signal_data["fs"]
chn_signal = signal_data["data"] chn_signal = signal_data["data"]
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs) time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
@ -112,7 +116,7 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
elif multi_labels == "resp" and event_codes is not None: elif multi_labels == "resp" and event_codes is not None:
ax.set_ylim(-6, 6) ax.set_ylim(-6, 6)
# 建立第二个y轴坐标 # 建立第二个y轴坐标
ax2 = ax.twinx() ax2.cla()
ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
color='blue', alpha=0.8, label='Low Amplitude Mask') color='blue', alpha=0.8, label='Low Amplitude Mask')
ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
@ -122,13 +126,15 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
for event_code in event_codes: for event_code in event_codes:
sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs) score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs)
y = (sa_mask * score_mask).astype(float) # y = (sa_mask * score_mask).astype(float)
y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float)
np.place(y, y == 0, np.nan) np.place(y, y == 0, np.nan)
ax2.plot(time_axis, y, color=utils.ColorCycle[event_code]) ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
ax2.plot(time_axis, score_mask, color="orange")
ax2.set_ylim(-4, 5) ax2.set_ylim(-4, 5)
elif multi_labels == "bcg" and event_codes is not None: elif multi_labels == "bcg" and event_codes is not None:
# 建立第二个y轴坐标 # 建立第二个y轴坐标
ax2 = ax.twinx() ax2.cla()
ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
color='blue', alpha=0.8, label='Low Amplitude Mask') color='blue', alpha=0.8, label='Low Amplitude Mask')
ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
@ -177,17 +183,19 @@ def score_mask2alpha(score_mask):
return alpha_mask return alpha_mask
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
if save_path is not None:
save_path.mkdir(parents=True, exist_ok=True)
for mask in event_mask.keys(): for mask in event_mask.keys():
if mask.startswith("Resp_") or mask.endswith("BCG_"): if mask.startswith("Resp_") or mask.startswith("BCG_"):
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
fig, axes = create_psg_bcg_figure() fig, axes = create_psg_bcg_figure()
for segment_start, segment_end in segment_list: for segment_start, segment_end in tqdm(segment_list):
print(f"Drawing segment: {segment_start} to {segment_end} seconds")
for ax in axes: for ax in axes:
ax.cla() ax.cla()
@ -203,13 +211,17 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
psg_label, event_codes=[1, 2, 3, 4]) psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4]) event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4]) event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])
plt.show()
print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
def draw_resp_label(resp_data, resp_label, segment_list): def draw_resp_label(resp_data, resp_label, segment_list):
for mask in resp_label.keys(): for mask in resp_label.keys():
if mask.startswith("Resp_"): if mask.startswith("Resp_"):