diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index 50cca9a..4442936 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -130,7 +130,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): psg_label=psg_event_mask, bcg_data=bcg_data, event_mask=event_mask, - segment_list=segment_list) + segment_list=segment_list, + save_path=visual_path / f"{samp_id}") if __name__ == '__main__': @@ -141,8 +142,11 @@ if __name__ == '__main__': root_path = Path(conf["root_path"]) mask_path = Path(conf["mask_save_path"]) save_path = Path(conf["dataset_config"]["dataset_save_path"]) + visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) dataset_config = conf["dataset_config"] + visual_path.mkdir(parents=True, exist_ok=True) + save_processed_signal_path = save_path / "Signals" save_processed_signal_path.mkdir(parents=True, exist_ok=True) @@ -155,12 +159,13 @@ if __name__ == '__main__': print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") print(f"save_path: {save_path}") + print(f"visual_path: {visual_path}") org_signal_root_path = root_path / "OrgBCG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned" - build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) - # - # for samp_id in select_ids: - # print(f"Processing sample ID: {samp_id}") - # build_HYS_dataset_segment(samp_id, show=False) \ No newline at end of file + # build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) + + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index a15e03d..80b8165 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -83,3 +83,4 @@ dataset_config: window_sec: 180 stride_sec: 60 dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py index 605de63..134202e 100644 --- a/draw_tools/draw_label.py +++ b/draw_tools/draw_label.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import matplotlib.patches as mpatches import seaborn as sns import numpy as np - +from tqdm.rich import tqdm import utils # 添加with_prediction参数 @@ -19,7 +19,9 @@ psg_chn_name2ax = { "HR": 5, "resp": 6, "bcg": 7, - "Stage": 8 + "Stage": 8, + "resp_twinx": 9, + "bcg_twinx": 10, } resp_chn_name2ax = { @@ -29,7 +31,7 @@ resp_chn_name2ax = { def create_psg_bcg_figure(): - fig = plt.figure(figsize=(12, 8), dpi=100) + fig = plt.figure(figsize=(12, 8), dpi=200) gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1]) fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) axes = [] @@ -37,39 +39,41 @@ def create_psg_bcg_figure(): ax = fig.add_subplot(gs[i]) axes.append(ax) - axes[0].grid(True) + axes[psg_chn_name2ax["SpO2"]].grid(True) # axes[0].xaxis.set_major_formatter(Params.FORMATTER) - axes[0].set_ylim((85, 100)) - axes[0].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100)) + axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white") - axes[1].grid(True) + axes[psg_chn_name2ax["Flow T"]].grid(True) # axes[1].xaxis.set_major_formatter(Params.FORMATTER) - axes[1].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white") - axes[2].grid(True) + axes[psg_chn_name2ax["Flow P"]].grid(True) # axes[2].xaxis.set_major_formatter(Params.FORMATTER) - axes[2].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white") - axes[3].grid(True) + axes[psg_chn_name2ax["Effort Tho"]].grid(True) # axes[3].xaxis.set_major_formatter(Params.FORMATTER) - axes[3].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white") - axes[4].grid(True) + axes[psg_chn_name2ax["Effort Abd"]].grid(True) # axes[4].xaxis.set_major_formatter(Params.FORMATTER) - axes[4].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white") - axes[5].grid(True) - axes[5].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["HR"]].grid(True) + axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white") - axes[6].grid(True) + axes[psg_chn_name2ax["resp"]].grid(True) + axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_chn_name2ax["resp"]].twinx()) + + axes[psg_chn_name2ax["bcg"]].grid(True) # axes[5].xaxis.set_major_formatter(Params.FORMATTER) - axes[6].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_chn_name2ax["bcg"]].twinx()) - axes[7].grid(True) - # axes[6].xaxis.set_major_formatter(Params.FORMATTER) - axes[7].tick_params(axis='x', colors="white") - axes[8].grid(True) + axes[psg_chn_name2ax["Stage"]].grid(True) # axes[7].xaxis.set_major_formatter(Params.FORMATTER) return fig, axes @@ -96,7 +100,7 @@ def create_resp_figure(): def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None, - event_codes: list[int] = None, multi_labels=None): + event_codes: list[int] = None, multi_labels=None, ax2: Axes = None): signal_fs = signal_data["fs"] chn_signal = signal_data["data"] time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs) @@ -112,7 +116,7 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev elif multi_labels == "resp" and event_codes is not None: ax.set_ylim(-6, 6) # 建立第二个y轴坐标 - ax2 = ax.twinx() + ax2.cla() ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, color='blue', alpha=0.8, label='Low Amplitude Mask') ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, @@ -122,13 +126,15 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev for event_code in event_codes: sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs) - y = (sa_mask * score_mask).astype(float) + # y = (sa_mask * score_mask).astype(float) + y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float) np.place(y, y == 0, np.nan) - ax2.plot(time_axis, y, color=utils.ColorCycle[event_code]) + ax.plot(time_axis, y, color=utils.ColorCycle[event_code]) + ax2.plot(time_axis, score_mask, color="orange") ax2.set_ylim(-4, 5) elif multi_labels == "bcg" and event_codes is not None: # 建立第二个y轴坐标 - ax2 = ax.twinx() + ax2.cla() ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, color='blue', alpha=0.8, label='Low Amplitude Mask') ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, @@ -177,17 +183,19 @@ def score_mask2alpha(score_mask): return alpha_mask -def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): +def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None): + if save_path is not None: + save_path.mkdir(parents=True, exist_ok=True) + for mask in event_mask.keys(): - if mask.startswith("Resp_") or mask.endswith("BCG_"): + if mask.startswith("Resp_") or mask.startswith("BCG_"): event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) fig, axes = create_psg_bcg_figure() - for segment_start, segment_end in segment_list: - print(f"Drawing segment: {segment_start} to {segment_end} seconds") + for segment_start, segment_end in tqdm(segment_list): for ax in axes: ax.cla() @@ -203,13 +211,17 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): psg_label, event_codes=[1, 2, 3, 4]) plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, - event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4]) + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]]) plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, - event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4]) - plt.show() - print(f"Finished drawing segment: {segment_start} to {segment_end} seconds") + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]]) + + if save_path is not None: + fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png") + tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + def draw_resp_label(resp_data, resp_label, segment_list): for mask in resp_label.keys(): if mask.startswith("Resp_"):