From ed4205f5b8a12c40490b87244571d794721192e2 Mon Sep 17 00:00:00 2001 From: marques Date: Fri, 14 Nov 2025 18:39:50 +0800 Subject: [PATCH] Update data processing module: add signal normalization and plotting, refactor some functions for readability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + dataset_builder/HYS_dataset.py | 166 ++++++++++++ dataset_config/HYS_config.yaml | 27 +- draw_tools/__init__.py | 3 +- draw_tools/draw_label.py | 230 ++++++++++++++++ event_mask_process/HYS_process.py | 114 +++----- event_mask_process/SHHS_process.py | 0 event_mask_process/ZD5Y_process.py | 277 -------------------- signal_method/__init__.py | 4 +- signal_method/normalize_method.py | 36 +++ signal_method/signal_process.py | 62 +++++ utils/HYS_FileReader.py | 134 +++++++++- utils/__init__.py | 9 +- utils/event_map.py | 37 ++- utils/{signal_process.py => filter_func.py} | 0 utils/operation_tools.py | 28 +- utils/split_method.py | 27 ++ 17 files changed, 774 insertions(+), 382 deletions(-) create mode 100644 draw_tools/draw_label.py delete mode 100644 event_mask_process/SHHS_process.py delete mode 100644 event_mask_process/ZD5Y_process.py create mode 100644 signal_method/normalize_method.py create mode 100644 signal_method/signal_process.py rename utils/{signal_process.py => filter_func.py} (100%) create mode 100644 utils/split_method.py diff --git a/.gitignore b/.gitignore index 2429834..1119c69 100644 --- a/.gitignore +++ b/.gitignore @@ -253,3 +253,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ +output/* +!output/ diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index e69de29..50cca9a 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -0,0 +1,166 @@ +import sys +from pathlib import Path + +import os + +import numpy as np + +os.environ['DISPLAY'] = "localhost:10.0" + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import utils +import signal_method +import draw_tools +import shutil + + +def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") + print(f"mask_excel_path: {mask_excel_path}") + + event_mask, event_list = utils.read_mask_execl(mask_excel_path) + + bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float) + + bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs) + normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"]) + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100) + bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs, + target_fs=100) + signal_fs = 100 + + if bcg_fs == 1000: + bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100) + bcg_fs = 100 + + # draw_tools.draw_signal_with_mask(samp_id=samp_id, + # signal_data=resp_signal, + # signal_fs=resp_fs, + # resp_data=normalized_resp_signal, + # resp_fs=resp_fs, + # bcg_data=bcg_signal, + # bcg_fs=bcg_fs, + # signal_disable_mask=event_mask["Disable_Label"], + # resp_low_amp_mask=event_mask["Resp_LowAmp_Label"], + # resp_movement_mask=event_mask["Resp_Movement_Label"], + # resp_change_mask=event_mask["Resp_AmpChange_Label"], + # resp_sa_mask=event_mask["SA_Label"], + # bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"], + # bcg_movement_mask=event_mask["BCG_Movement_Label"], + # bcg_change_mask=event_mask["BCG_AmpChange_Label"], + # show=show, + # save_path=None) + + segment_list = utils.resp_split(dataset_config, event_mask, event_list) + print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") + + + # 复制mask到processed_Labels文件夹 + save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv" + shutil.copyfile(mask_excel_path, save_mask_excel_path) + + # 复制SA Label_corrected.csv到processed_Labels文件夹 + sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv") + if sa_label_corrected_path.exists(): + save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" + shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) + else: + print(f"Warning: {sa_label_corrected_path} does not exist.") + + # 保存处理后的信号和截取的片段列表 + save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" + save_segment_path = save_segment_list_path / 
f"{samp_id}_Segment_List.npz" + + bcg_data = { + "bcg_signal_notch": { + "name": "BCG_Signal_Notch", + "data": bcg_signal_notch, + "fs": signal_fs, + "length": len(bcg_signal_notch), + "second": len(bcg_signal_notch) // signal_fs + }, + "bcg_signal":{ + "name": "BCG_Signal_Raw", + "data": bcg_signal, + "fs": bcg_fs, + "length": len(bcg_signal), + "second": len(bcg_signal) // bcg_fs + }, + "resp_signal": { + "name": "Resp_Signal", + "data": normalized_resp_signal, + "fs": resp_fs, + "length": len(normalized_resp_signal), + "second": len(normalized_resp_signal) // resp_fs + } + } + + np.savez_compressed(save_signal_path, **bcg_data) + np.savez_compressed(save_segment_path, + segment_list=segment_list) + print(f"Saved processed signals to: {save_signal_path}") + print(f"Saved segment list to: {save_segment_path}") + + if draw_segment: + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8]) + psg_data["HR"] = { + "name": "HR", + "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]), + "fs": psg_data["ECG_Sync"]["fs"], + "length": psg_data["ECG_Sync"]["length"], + "second": psg_data["ECG_Sync"]["second"] + } + + + psg_label = utils.read_psg_label(sa_label_corrected_path) + psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False) + draw_tools.draw_psg_bcg_label(psg_data=psg_data, + psg_label=psg_event_mask, + bcg_data=bcg_data, + event_mask=event_mask, + segment_list=segment_list) + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_config.yaml" + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + mask_path = Path(conf["mask_save_path"]) + save_path = Path(conf["dataset_config"]["dataset_save_path"]) + dataset_config = conf["dataset_config"] + + save_processed_signal_path = save_path / "Signals" + save_processed_signal_path.mkdir(parents=True, exist_ok=True) + + save_segment_list_path = save_path / "Segments_List" + save_segment_list_path.mkdir(parents=True, exist_ok=True) + + save_processed_label_path = save_path / "Labels" + save_processed_label_path.mkdir(parents=True, exist_ok=True) + + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + psg_signal_root_path = root_path / "PSG_Aligned" + + build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) + # + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # build_HYS_dataset_segment(samp_id, show=False) \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 0f5be51..a15e03d 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -11,7 +11,7 @@ select_ids: - 960 root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS -save_path: /mnt/disk_code/marques/dataprepare/output/HYS +mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS resp: downsample_fs_1: 100 @@ -43,11 +43,11 @@ resp_movement: min_duration_sec: 1 resp_movement_revise: - up_interval_multiplier: 3 - down_interval_multiplier: 2 - compare_intervals_sec: 30 - merge_gap_sec: 10 - min_duration_sec: 1 + up_interval_multiplier: 3 + down_interval_multiplier: 2 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 resp_amp_change: mav_calc_window_sec: 4 @@ -56,7 +56,7 @@ resp_amp_change: 
bcg: - downsample_fs: 100 + downsample_fs: 100 bcg_filter: filter_type: bandpass @@ -73,8 +73,13 @@ bcg_low_amp: bcg_movement: - window_size_sec: 2 - stride_sec: - merge_gap_sec: 20 - min_duration_sec: 4 + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py index 281cc34..3386b90 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -1 +1,2 @@ -from .draw_statics import draw_signal_with_mask \ No newline at end of file +from .draw_statics import draw_signal_with_mask +from .draw_label import draw_psg_bcg_label, draw_resp_label \ No newline at end of file diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py new file mode 100644 index 0000000..605de63 --- /dev/null +++ b/draw_tools/draw_label.py @@ -0,0 +1,230 @@ +from matplotlib.axes import Axes +from matplotlib.gridspec import GridSpec +from matplotlib.colors import PowerNorm +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import seaborn as sns +import numpy as np + +import utils + +# 添加with_prediction参数 + +psg_chn_name2ax = { + "SpO2": 0, + "Flow T": 1, + "Flow P": 2, + "Effort Tho": 3, + "Effort Abd": 4, + "HR": 5, + "resp": 6, + "bcg": 7, + "Stage": 8 +} + +resp_chn_name2ax = { + "resp": 0, + "bcg": 1, +} + + +def create_psg_bcg_figure(): + fig = plt.figure(figsize=(12, 8), dpi=100) + gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(9): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + + axes[0].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[0].set_ylim((85, 100)) + axes[0].tick_params(axis='x', colors="white") + + axes[1].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[1].tick_params(axis='x', colors="white") + + axes[2].grid(True) + # axes[2].xaxis.set_major_formatter(Params.FORMATTER) + axes[2].tick_params(axis='x', colors="white") + + axes[3].grid(True) + # axes[3].xaxis.set_major_formatter(Params.FORMATTER) + axes[3].tick_params(axis='x', colors="white") + + axes[4].grid(True) + # axes[4].xaxis.set_major_formatter(Params.FORMATTER) + axes[4].tick_params(axis='x', colors="white") + + axes[5].grid(True) + axes[5].tick_params(axis='x', colors="white") + + axes[6].grid(True) + # axes[5].xaxis.set_major_formatter(Params.FORMATTER) + axes[6].tick_params(axis='x', colors="white") + + axes[7].grid(True) + # axes[6].xaxis.set_major_formatter(Params.FORMATTER) + axes[7].tick_params(axis='x', colors="white") + + axes[8].grid(True) + # axes[7].xaxis.set_major_formatter(Params.FORMATTER) + + return fig, axes + + +def create_resp_figure(): + fig = plt.figure(figsize=(12, 6), dpi=100) + gs = GridSpec(2, 1, height_ratios=[3, 2]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(2): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + + axes[0].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[0].tick_params(axis='x', colors="white") + + axes[1].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[1].tick_params(axis='x', colors="white") + + return fig, axes + + +def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None, + event_codes: list[int] = None, 
multi_labels=None): + signal_fs = signal_data["fs"] + chn_signal = signal_data["data"] + time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs) + ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black', + label=signal_data["name"]) + if event_mask is not None: + if multi_labels is None and event_codes is not None: + for event_code in event_codes: + mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code + y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float) + np.place(y, y == 0, np.nan) + ax.plot(time_axis, y, color=utils.ColorCycle[event_code]) + elif multi_labels == "resp" and event_codes is not None: + ax.set_ylim(-6, 6) + # 建立第二个y轴坐标 + ax2 = ax.twinx() + ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, + color='blue', alpha=0.8, label='Low Amplitude Mask') + ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, + color='orange', alpha=0.8, label='Movement Mask') + ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3, + color='green', alpha=0.8, label='Amplitude Change Mask') + for event_code in event_codes: + sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code + score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs) + y = (sa_mask * score_mask).astype(float) + np.place(y, y == 0, np.nan) + ax2.plot(time_axis, y, color=utils.ColorCycle[event_code]) + ax2.set_ylim(-4, 5) + elif multi_labels == "bcg" and event_codes is not None: + # 建立第二个y轴坐标 + ax2 = ax.twinx() + ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, + color='blue', alpha=0.8, label='Low Amplitude Mask') + ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, + color='orange', alpha=0.8, label='Movement Mask') + ax2.plot(time_axis, event_mask["BCG_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3, + color='green', alpha=0.8, label='Amplitude Change Mask') + + ax2.set_ylim(-4, 4) + + ax.set_ylabel("Amplitude") + ax.legend(loc=1) + + +def plt_stage_on_ax(ax, stage_data, segment_start, segment_end): + stage_signal = stage_data["data"] + stage_fs = stage_data["fs"] + time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start) + ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"]) + ax.set_ylim(0, 6) + ax.set_yticks([1, 2, 3, 4, 5]) + ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"]) + ax.set_ylabel("Stage") + ax.legend(loc=1) + + +def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end): + spo2_signal = spo2_data["data"] + spo2_fs = spo2_data["fs"] + time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start) + ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"]) + + if spo2_signal[segment_start:segment_end].min() < 85: + ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100)) + else: + ax.set_ylim((85, 100)) + ax.set_ylabel("SpO2 (%)") + ax.legend(loc=1) + + +def score_mask2alpha(score_mask): + alpha_mask = np.zeros_like(score_mask, dtype=float) + alpha_mask[score_mask <= 0] = 0 + alpha_mask[score_mask == 1] = 0.9 + alpha_mask[score_mask == 2] = 0.6 
+ alpha_mask[score_mask == 3] = 0.1 + return alpha_mask + + +def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): + for mask in event_mask.keys(): + if mask.startswith("Resp_") or mask.endswith("BCG_"): + event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) + + event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) + event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) + + fig, axes = create_psg_bcg_figure() + for segment_start, segment_end in segment_list: + print(f"Drawing segment: {segment_start} to {segment_end} seconds") + for ax in axes: + ax.cla() + + plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end) + plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4]) + plt.show() + print(f"Finished drawing segment: {segment_start} to {segment_end} seconds") + + +def draw_resp_label(resp_data, resp_label, segment_list): + for mask in resp_label.keys(): + if mask.startswith("Resp_"): + resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0) + + resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"]) + resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0) + + fig, axes = create_resp_figure() + for segment_start, segment_end in segment_list: + for ax in axes: + ax.cla() + + plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end, + resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end, + resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4]) + plt.show() diff --git a/event_mask_process/HYS_process.py b/event_mask_process/HYS_process.py index 1054dba..dc33461 100644 --- a/event_mask_process/HYS_process.py +++ b/event_mask_process/HYS_process.py @@ -18,15 +18,19 @@ # 高幅值连续体动规则标定与剔除 # 手动标定不可用区间提剔除 """ - +import sys from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + import shutil import draw_tools import utils import numpy as np import signal_method import os -from matplotlib import pyplot as plt + os.environ['DISPLAY'] = "localhost:10.0" @@ -48,56 +52,14 @@ def process_one_signal(samp_id, show=False): save_samp_path = save_path / f"{samp_id}" save_samp_path.mkdir(parents=True, exist_ok=True) - 
signal_data_raw = utils.read_signal_txt(signal_path) - signal_length = len(signal_data_raw) - print(f"signal_length: {signal_length}") - signal_fs = int(signal_path.stem.split("_")[-1]) - print(f"signal_fs: {signal_fs}") - signal_second = signal_length // signal_fs - print(f"signal_second: {signal_second}") + signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True) - # 根据采样率进行截断 - signal_data_raw = signal_data_raw[:signal_second * signal_fs] - - # 滤波 - # 50Hz陷波滤波器 - # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - print("Applying 50Hz notch filter...") - signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) - - resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) - resp_fs = conf["resp"]["downsample_fs_1"] - resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) - resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) - resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], - low_cut=conf["resp_filter"]["low_cut"], - high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], - sample_rate=resp_fs) - print("Begin plotting signal data...") - - # fig = plt.figure(figsize=(12, 8)) - # # 绘制三个图raw_data、resp_data_1、resp_data_2 - # ax0 = fig.add_subplot(3, 1, 1) - # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') - # ax0.set_title('Raw Signal Data') - # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) - # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') - # ax1.set_title('Resp Data after Average Filtering') - # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) - # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') - # ax2.set_title('Resp Data after Butterworth Filtering') - # plt.tight_layout() - # plt.show() - - bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], - low_cut=conf["bcg_filter"]["low_cut"], - high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], - sample_rate=signal_fs) + signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs) # 降采样 old_resp_fs = resp_fs resp_fs = conf["resp"]["downsample_fs_2"] - resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) + resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs) bcg_fs = conf["bcg"]["downsample_fs"] bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) @@ -214,26 +176,26 @@ def process_one_signal(samp_id, show=False): target_fs=100) signal_fs = 100 - draw_tools.draw_signal_with_mask(samp_id=samp_id, - signal_data=signal_data, - signal_fs=signal_fs, - resp_data=resp_data, - resp_fs=resp_fs, - bcg_data=bcg_data, - bcg_fs=bcg_fs, - signal_disable_mask=manual_disable_mask, - resp_low_amp_mask=resp_low_amp_mask, - resp_movement_mask=resp_movement_mask, - resp_change_mask=resp_amp_change_mask, - resp_sa_mask=event_mask, - bcg_low_amp_mask=bcg_low_amp_mask, - 
bcg_movement_mask=bcg_movement_mask, - bcg_change_mask=bcg_amp_change_mask, - show=show, - save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=event_mask, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask, + show=show, + save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") # 复制事件文件 到保存路径 - sa_label_save_name = f"{samp_id}" + label_path.name + sa_label_save_name = f"{samp_id}_" + label_path.name shutil.copyfile(label_path, save_samp_path / sa_label_save_name) # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 @@ -247,10 +209,10 @@ def process_one_signal(samp_id, show=False): dtype=int), "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, + "BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, + "BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) } @@ -259,13 +221,13 @@ def process_one_signal(samp_id, show=False): if __name__ == '__main__': - yaml_path = Path("../dataset_config/HYS_config.yaml") - disable_df_path = Path("../排除区间.xlsx") + yaml_path = project_root_path / "dataset_config/HYS_config.yaml" + disable_df_path = project_root_path / "排除区间.xlsx" conf = utils.load_dataset_conf(yaml_path) select_ids = conf["select_ids"] root_path = Path(conf["root_path"]) - save_path = Path(conf["save_path"]) + save_path = Path(conf["mask_save_path"]) print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") @@ -276,9 +238,9 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[6], show=True) + # process_one_signal(select_ids[6], show=True) # - # for samp_id in select_ids: - # print(f"Processing sample ID: {samp_id}") - # process_one_signal(samp_id, show=False) - # print(f"Finished processing sample ID: {samp_id}\n\n") + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + process_one_signal(samp_id, show=False) + print(f"Finished processing sample ID: {samp_id}\n\n") diff --git a/event_mask_process/SHHS_process.py b/event_mask_process/SHHS_process.py deleted file mode 100644 index e69de29..0000000 diff --git a/event_mask_process/ZD5Y_process.py b/event_mask_process/ZD5Y_process.py deleted file mode 100644 index 8701d65..0000000 --- a/event_mask_process/ZD5Y_process.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -本脚本完成对呼研所数据的处理,包含以下功能: -1. 数据读取与预处理 - 从传入路径中,进行数据和标签的读取,并进行初步的预处理 - 预处理包括为数据进行滤波、去噪等操作 -2. 数据清洗与异常值处理 -3. 输出清晰后的统计信息 -4. 数据保存 - 将处理后的数据保存到指定路径,便于后续使用 - 主要是保存切分后的数据位置和标签 -5. 
可视化 - 提供数据处理前后的可视化对比,帮助理解数据变化 - 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 - -todo: 使用mask 屏蔽无用区间 - - -# 低幅值区间规则标定与剔除 -# 高幅值连续体动规则标定与剔除 -# 手动标定不可用区间提剔除 -""" - -from pathlib import Path -import shutil -import draw_tools -import utils -import numpy as np -import signal_method -import os -from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:10.0" - -def process_one_signal(samp_id, show=False): - signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) - if not signal_path: - raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") - signal_path = signal_path[0] - print(f"Processing OrgBCG_Sync signal file: {signal_path}") - - label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv") - if not label_path: - raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") - label_path = list(label_path)[0] - print(f"Processing Label_corrected file: {label_path}") - - signal_data_raw = utils.read_signal_txt(signal_path) - signal_length = len(signal_data_raw) - print(f"signal_length: {signal_length}") - signal_fs = int(signal_path.stem.split("_")[-1]) - print(f"signal_fs: {signal_fs}") - signal_second = signal_length // signal_fs - print(f"signal_second: {signal_second}") - - # 根据采样率进行截断 - signal_data_raw = signal_data_raw[:signal_second * signal_fs] - - # 滤波 - # 50Hz陷波滤波器 - # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - print("Applying 50Hz notch filter...") - signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) - - resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) - resp_fs = conf["resp"]["downsample_fs_1"] - resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) - resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) - resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], - low_cut=conf["resp_filter"]["low_cut"], - high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], - sample_rate=resp_fs) - print("Begin plotting signal data...") - - - # fig = plt.figure(figsize=(12, 8)) - # # 绘制三个图raw_data、resp_data_1、resp_data_2 - # ax0 = fig.add_subplot(3, 1, 1) - # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') - # ax0.set_title('Raw Signal Data') - # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) - # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') - # ax1.set_title('Resp Data after Average Filtering') - # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) - # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') - # ax2.set_title('Resp Data after Butterworth Filtering') - # plt.tight_layout() - # plt.show() - - bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], - low_cut=conf["bcg_filter"]["low_cut"], - high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], - sample_rate=signal_fs) - - # 降采样 - old_resp_fs = resp_fs - resp_fs = conf["resp"]["downsample_fs_2"] - resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) - bcg_fs = conf["bcg"]["downsample_fs"] - bcg_data = 
utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) - - label_data = utils.read_label_csv(path=label_path) - event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) - - manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ - all_samp_disable_df["id"] == samp_id]) - print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") - - # 分析Resp的低幅值区间 - resp_low_amp_conf = conf.get("resp_low_amp", None) - if resp_low_amp_conf is not None: - resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( - signal_data=resp_data, - sampling_rate=resp_fs, - **resp_low_amp_conf - ) - print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") - else: - resp_low_amp_mask, resp_low_amp_position_list = None, None - print("resp_low_amp_mask is None") - - # 分析Resp的高幅值伪迹区间 - resp_movement_conf = conf.get("resp_movement", None) - if resp_movement_conf is not None: - raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( - signal_data=resp_data, - sampling_rate=resp_fs, - **resp_movement_conf - ) - print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") - else: - resp_movement_mask, resp_movement_position_list = None, None - print("resp_movement_mask is None") - - resp_movement_revise_conf = conf.get("resp_movement_revise", None) - if resp_movement_mask is not None and resp_movement_revise_conf is not None: - resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( - signal_data=resp_data, - movement_mask=resp_movement_mask, - movement_list=resp_movement_position_list, - sampling_rate=resp_fs, - **resp_movement_revise_conf, - verbose=False - ) - print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") - else: - print("resp_movement_mask revise is skipped") - - - # 分析Resp的幅值突变区间 - resp_amp_change_conf = conf.get("resp_amp_change", None) - if resp_amp_change_conf is not None and resp_movement_mask is not None: - resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( - signal_data=resp_data, - movement_mask=resp_movement_mask, - movement_list=resp_movement_position_list, - sampling_rate=resp_fs, - **resp_amp_change_conf) - print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") - else: - resp_amp_change_mask = None - print("amp_change_mask is None") - - - - # 分析Bcg的低幅值区间 - bcg_low_amp_conf = conf.get("bcg_low_amp", None) - if bcg_low_amp_conf is not None: - bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( - signal_data=bcg_data, - sampling_rate=bcg_fs, - **bcg_low_amp_conf - ) - print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") - else: - bcg_low_amp_mask, bcg_low_amp_position_list = None, None - print("bcg_low_amp_mask is None") - # 
分析Bcg的高幅值伪迹区间 - bcg_movement_conf = conf.get("bcg_movement", None) - if bcg_movement_conf is not None: - raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( - signal_data=bcg_data, - sampling_rate=bcg_fs, - **bcg_movement_conf - ) - print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") - else: - bcg_movement_mask = None - print("bcg_movement_mask is None") - # 分析Bcg的幅值突变区间 - if bcg_movement_mask is not None: - bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( - signal_data=bcg_data, - movement_mask=bcg_movement_mask, - sampling_rate=bcg_fs) - print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") - else: - bcg_amp_change_mask = None - print("bcg_amp_change_mask is None") - - - # 如果signal_data采样率过,进行降采样 - if signal_fs == 1000: - signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) - signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) - signal_fs = 100 - if show: - draw_tools.draw_signal_with_mask(samp_id=samp_id, - signal_data=signal_data, - signal_fs=signal_fs, - resp_data=resp_data, - resp_fs=resp_fs, - bcg_data=bcg_data, - bcg_fs=bcg_fs, - signal_disable_mask=manual_disable_mask, - resp_low_amp_mask=resp_low_amp_mask, - resp_movement_mask=resp_movement_mask, - resp_change_mask=resp_amp_change_mask, - resp_sa_mask=event_mask, - bcg_low_amp_mask=bcg_low_amp_mask, - bcg_movement_mask=bcg_movement_mask, - bcg_change_mask=bcg_amp_change_mask) - - - # 保存处理后的数据和标签 - save_samp_path = save_path / f"{samp_id}" - save_samp_path.mkdir(parents=True, exist_ok=True) - - # 复制事件文件 到保存路径 - sa_label_save_name = f"{samp_id}" + label_path.name - shutil.copyfile(label_path, save_samp_path / sa_label_save_name) - - # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 - save_dict = { - "Second": np.arange(signal_second), - "SA_Label": event_mask, - "SA_Score": score_mask, - "Disable_Label": manual_disable_mask, - "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) - } - - mask_label_save_name = f"{samp_id}_Processed_Labels.csv" - utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) - - - - - - - -if __name__ == '__main__': - yaml_path = Path("../dataset_config/ZD5Y_config.yaml") - disable_df_path = Path("../排除区间.xlsx") - - conf = utils.load_dataset_conf(yaml_path) - select_ids = conf["select_ids"] - root_path = Path(conf["root_path"]) - save_path = Path(conf["save_path"]) - - print(f"select_ids: {select_ids}") - 
print(f"root_path: {root_path}") - print(f"save_path: {save_path}") - - org_signal_root_path = root_path / "OrgBCG_Aligned" - label_root_path = root_path / "Label" - - all_samp_disable_df = utils.read_disable_excel(disable_df_path) - - process_one_signal(select_ids[1], show=True) - - # for samp_id in select_ids: - # print(f"Processing sample ID: {samp_id}") - # process_one_signal(samp_id, show=False) - # print(f"Finished processing sample ID: {samp_id}\n\n") \ No newline at end of file diff --git a/signal_method/__init__.py b/signal_method/__init__.py index eaea6ea..7ce8cdb 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1,4 +1,6 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3 from .rule_base_event import movement_revise -from .time_metrics import calc_mav_by_slide_windows \ No newline at end of file +from .time_metrics import calc_mav_by_slide_windows +from .signal_process import signal_filter_split, rpeak2hr +from .normalize_method import normalize_resp_signal \ No newline at end of file diff --git a/signal_method/normalize_method.py b/signal_method/normalize_method.py new file mode 100644 index 0000000..8ed89ce --- /dev/null +++ b/signal_method/normalize_method.py @@ -0,0 +1,36 @@ +import utils +import pandas as pd +import numpy as np +from scipy import signal + +def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list): + # 根据呼吸信号的幅值改变区间,对每段进行Z-Score标准化 + normalized_resp_signal = np.zeros_like(resp_signal) + # 全部填成nan + normalized_resp_signal[:] = np.nan + + resp_signal_no_movement = resp_signal.copy() + + + resp_signal_no_movement[np.array(movement_mask == 1).repeat(resp_fs)] = np.nan + + + for i in range(len(enable_list)): + enable_start = enable_list[i][0] * resp_fs + enable_end = enable_list[i][1] * resp_fs + segment = resp_signal_no_movement[enable_start:enable_end] + + # print(f"Normalizing segment {i+1}/{len(enable_list)}: start={enable_start}, end={enable_end}, length={len(segment)}") + + segment_mean = np.nanmean(segment) + segment_std = np.nanstd(segment) + if segment_std == 0: + raise ValueError(f"segment_std is zero! 
segment_start: {enable_start}, segment_end: {enable_end}") + + # 同下一个enable区间的体动一起进行标准化 + if i <= len(enable_list) - 2: + enable_end = enable_list[i + 1][0] * resp_fs + raw_segment = resp_signal[enable_start:enable_end] + normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std + + return normalized_resp_signal diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py new file mode 100644 index 0000000..eaaea59 --- /dev/null +++ b/signal_method/signal_process.py @@ -0,0 +1,62 @@ +import numpy as np + +import utils + +def signal_filter_split(conf, signal_data_raw, signal_fs): + # 滤波 + # 50Hz陷波滤波器 + # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) + print("Applying 50Hz notch filter...") + signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + print("Begin plotting signal data...") + + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + + return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs + + + +def rpeak2hr(rpeak_indices, signal_length): + hr_signal = np.zeros(signal_length) + for i in range(1, len(rpeak_indices)): + rri = rpeak_indices[i] - rpeak_indices[i - 1] + if rri == 0: + continue + hr = 60 * 1000 / rri # 心率,单位:bpm + if hr > 120: + hr = 120 + elif hr < 30: + hr = 30 + hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = hr + # 填充最后一个R峰之后的心率值 + if len(rpeak_indices) > 1: + hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]] + return hr_signal + diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index dd65bab..6f7d95a 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -1,9 +1,11 @@ from pathlib import Path from typing import Union +import utils +from .event_map import N2Chn import numpy as np import pandas as pd - +from .operation_tools import event_mask_2_list # 尝试导入 Polars try: import polars as pl @@ -13,15 +15,17 @@ except ImportError: HAS_POLARS = 
False -def read_signal_txt(path: Union[str, Path]) -> np.ndarray: +def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False): """ Read a txt file and return the first column as a numpy array. Args: - path (str | Path): Path to the txt file. - + :param path: + :param verbose: + :param dtype: Returns: np.ndarray: The first column of the txt file as a numpy array. + """ path = Path(path) if not path.exists(): @@ -29,10 +33,30 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray: if HAS_POLARS: df = pl.read_csv(path, has_header=False, infer_schema_length=0) - return df[:, 0].to_numpy().astype(float) + signal_data_raw = df[:, 0].to_numpy().astype(dtype) else: - df = pd.read_csv(path, header=None, dtype=float) - return df.iloc[:, 0].to_numpy() + df = pd.read_csv(path, header=None, dtype=dtype) + signal_data_raw = df.iloc[:, 0].to_numpy() + + signal_original_length = len(signal_data_raw) + signal_fs = int(path.stem.split("_")[-1]) + if is_peak: + signal_second = None + signal_length = None + else: + signal_second = signal_original_length // signal_fs + # 根据采样率进行截断 + signal_data_raw = signal_data_raw[:signal_second * signal_fs] + signal_length = len(signal_data_raw) + + if verbose: + print(f"Signal file read from {path}") + print(f"signal_fs: {signal_fs}") + print(f"signal_original_length: {signal_original_length}") + print(f"signal_after_cut_off_length: {signal_length}") + print(f"signal_second: {signal_second}") + + return signal_data_raw, signal_length, signal_fs, signal_second def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: @@ -172,3 +196,99 @@ def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame: df["start"] = df["start"].astype(int) df["end"] = df["end"].astype(int) return df + + +def read_mask_execl(path: Union[str, Path]): + """ + Read an Excel file and return the mask as a numpy array. + Args: + path (str | Path): Path to the Excel file. + Returns: + np.ndarray: The mask as a numpy array. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + df = pd.read_csv(path) + event_mask = df.to_dict(orient="list") + for key in event_mask: + event_mask[key] = np.array(event_mask[key]) + + event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]), + "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]), + "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),} + + + return event_mask, event_list + + + +def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]): + """ + 读取PSG文件中特定通道的数据。 + + 参数: + path_str (Union[str, Path]): 存放PSG文件的文件夹路径。 + channel_name (str): 需要读取的通道名称。 + 返回: + np.ndarray: 指定通道的数据数组。 + """ + path = Path(path_str) + if not path.exists(): + raise FileNotFoundError(f"PSG Dir not found: {path}") + + if not path.is_dir(): + raise NotADirectoryError(f"PSG Dir not found: {path}") + channel_data = {} + # 遍历检查通道对应的文件是否存在 + for ch_id in channel_number: + ch_name = N2Chn[ch_id] + ch_path = list(path.glob(f"{ch_name}*.txt")) + + if not any(ch_path): + raise FileNotFoundError(f"PSG Channel file not found: {ch_path}") + + if len(ch_path) > 1: + print(f"Warning!!! 
PSG Channel file more than one: {ch_path}") + + if ch_id == 8: + # sleep stage 特例 读取为整数 + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True) + # 转换为整数数组 + for stage_str, stage_number in utils.Stage2N.items(): + np.place(ch_signal, ch_signal == stage_str, stage_number) + ch_signal = ch_signal.astype(int) + elif ch_id == 1: + # Rpeak 特例 读取为整数 + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True) + else: + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True) + channel_data[ch_name] = { + "name": ch_name, + "path": ch_path[0], + "data": ch_signal, + "length": ch_length, + "fs": ch_fs, + "second": ch_second + } + + return channel_data + + +def read_psg_label(path: Union[str, Path], verbose=True): + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # 直接用pandas读取 包含中文 故指定编码 + df = pd.read_csv(path, encoding="gbk") + if verbose: + print(f"Label file read from {path}, number of rows: {len(df)}") + + # 丢掉Event type为空的行 + df = df.dropna(subset=["Event type"], how='all').reset_index(drop=True) + + return df + + diff --git a/utils/__init__.py b/utils/__init__.py index 68e7772..362297e 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,7 +1,10 @@ -from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel +from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list from .operation_tools import merge_short_gaps, remove_short_durations from .operation_tools import collect_values from .operation_tools import save_process_label -from .event_map import E2N -from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file +from .operation_tools import none_to_nan_mask +from .split_method import resp_split +from .HYS_FileReader import read_mask_execl, read_psg_channel +from .event_map import E2N, N2Chn, Stage2N, ColorCycle +from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/event_map.py b/utils/event_map.py index c85a027..20c6c58 100644 --- a/utils/event_map.py +++ b/utils/event_map.py @@ -4,4 +4,39 @@ E2N = { "Central apnea": 2, "Obstructive apnea": 3, "Mixed apnea": 4 -} \ No newline at end of file +} + +N2Chn = { + 1: "Rpeak", + 2: "ECG_Sync", + 3: "Effort Tho", + 4: "Effort Abd", + 5: "Flow P", + 6: "Flow T", + 7: "SpO2", + 8: "5_class" +} + +Stage2N = { + "W": 5, + "N1": 3, + "N2": 2, + "N3": 1, + "R": 4, +} + +# 设定事件和其对应颜色 +# event_code color event +# 0 黑色 背景 +# 1 粉色 低通气 +# 2 蓝色 中枢性 +# 3 红色 阻塞型 +# 4 灰色 混合型 +# 5 绿色 血氧饱和度下降 +# 6 橙色 大体动 +# 7 橙色 小体动 +# 8 橙色 深呼吸 +# 9 橙色 脉冲体动 +# 10 橙色 无效片段 +ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange", + "orange"] \ No newline at end of file diff --git a/utils/signal_process.py b/utils/filter_func.py similarity index 100% rename from utils/signal_process.py rename to utils/filter_func.py diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 095dd5c..866029d 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -198,16 +198,25 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: return disable_mask -def generate_event_mask(signal_second: int, 
event_df): +def generate_event_mask(signal_second: int, event_df, use_correct=True): event_mask = np.zeros(signal_second, dtype=int) score_mask = np.zeros(signal_second, dtype=int) + if use_correct: + start_name = "correct_Start" + end_name = "correct_End" + event_type_name = "correct_EventsType" + else: + start_name = "Start" + end_name = "End" + event_type_name = "Event type" + # Drop rows with start = -1 - event_df = event_df[event_df["correct_Start"] >= 0] + event_df = event_df[event_df[start_name] >= 0] for _, row in event_df.iterrows(): - start = row["correct_Start"] - end = row["correct_End"] + 1 - event_mask[start:end] = E2N[row["correct_EventsType"]] + start = row[start_name] + end = row[end_name] + 1 + event_mask[start:end] = E2N[row[event_type_name]] score_mask[start:end] = row["score"] return event_mask, score_mask @@ -243,3 +252,12 @@ def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None def save_process_label(save_path: Path, save_dict: dict): save_df = pd.DataFrame(save_dict) save_df.to_csv(save_path, index=False) + +def none_to_nan_mask(mask, ref): + """Convert a None mask into a NaN mask with the same shape as ref""" + if mask is None: + return np.full_like(ref, np.nan) + else: + # Replace zeros in the mask with NaN and keep the other values + mask = np.where(mask == 0, np.nan, mask) + return mask \ No newline at end of file diff --git a/utils/split_method.py b/utils/split_method.py new file mode 100644 index 0000000..e9c151b --- /dev/null +++ b/utils/split_method.py @@ -0,0 +1,27 @@ + + + +def resp_split(dataset_config, event_mask, event_list): + # Usable (enable) segments, i.e. the regions outside the disable mask + enable_list = event_list["EnableSegment"] + + # Read the dataset configuration + window_sec = dataset_config["window_sec"] + stride_sec = dataset_config["stride_sec"] + + segment_list = [] + + # Walk each enable segment; if the remaining tail is shorter than half a stride it is dropped, otherwise one extra window ending at enable_end is taken + for enable_start, enable_end in enable_list: + current_start = enable_start + while current_start + window_sec <= enable_end: + segment_list.append((current_start, current_start + window_sec)) + current_start += stride_sec + # Check whether a final window anchored at enable_end should be added + if (enable_end - current_start >= stride_sec / 2) and (enable_end - enable_start >= window_sec): + segment_list.append((enable_end - window_sec, enable_end)) + + return segment_list + + 
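
For reference, a minimal standalone sketch of the windowing scheme introduced in utils/split_method.py: fixed-length windows stepped forward by the stride, plus one tail window anchored at the segment end when at least half a stride remains. The 400 s enable segment below is hypothetical; window_sec=180 and stride_sec=60 follow the new dataset_config block in dataset_config/HYS_config.yaml.

def split_enable_segment(enable_start, enable_end, window_sec=180, stride_sec=60):
    # Slide a window_sec-long window forward by stride_sec while it still fits.
    segments = []
    current_start = enable_start
    while current_start + window_sec <= enable_end:
        segments.append((current_start, current_start + window_sec))
        current_start += stride_sec
    # If at least half a stride remains and the segment can hold a full window,
    # add one final window anchored at enable_end.
    if (enable_end - current_start >= stride_sec / 2) and (enable_end - enable_start >= window_sec):
        segments.append((enable_end - window_sec, enable_end))
    return segments

# A hypothetical 400 s enable segment yields
# [(0, 180), (60, 240), (120, 300), (180, 360), (220, 400)]:
print(split_enable_segment(0, 400))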