From 097c9cbf0ba9ffcf29b5887e42f2b82251a818d0 Mon Sep 17 00:00:00 2001
From: marques
Date: Mon, 19 Jan 2026 14:27:26 +0800
Subject: [PATCH] Optimize the data-processing module, add PSG signal
 plotting, and refactor several functions for readability and maintainability
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 dataset_builder/HYS_PSG_dataset.py    | 395 ++++++++++++++++++++++++++
 dataset_builder/HYS_dataset.py        |  53 ++--
 dataset_config/HYS_PSG_config.yaml    | 129 +++++++++
 draw_tools/__init__.py                |   4 +-
 draw_tools/draw_label.py              | 173 ++++++++---
 draw_tools/draw_statics.py            |  67 +++++
 event_mask_process/HYS_PSG_process.py | 183 ++++++++++++
 signal_method/__init__.py             |   4 +-
 signal_method/normalize_method.py     |  18 +-
 signal_method/signal_process.py       |  47 ++-
 utils/HYS_FileReader.py               |  58 ++++
 utils/__init__.py                     |   6 +-
 utils/filter_func.py                  |  47 +++
 utils/operation_tools.py              | 125 +++++++-
 utils/split_method.py                 |   2 -
 15 files changed, 1228 insertions(+), 83 deletions(-)
 create mode 100644 dataset_builder/HYS_PSG_dataset.py
 create mode 100644 dataset_config/HYS_PSG_config.yaml
 create mode 100644 event_mask_process/HYS_PSG_process.py

diff --git a/dataset_builder/HYS_PSG_dataset.py b/dataset_builder/HYS_PSG_dataset.py
new file mode 100644
index 0000000..e061e25
--- /dev/null
+++ b/dataset_builder/HYS_PSG_dataset.py
@@ -0,0 +1,395 @@
+import multiprocessing
+import sys
+from pathlib import Path
+
+import os
+import numpy as np
+
+from utils import N2Chn
+
+os.environ['DISPLAY'] = "localhost:10.0"
+
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+project_root_path = Path(__file__).resolve().parent.parent
+
+import utils
+import signal_method
+import draw_tools
+import shutil
+import gc
+
+
+def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
+    psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
+
+    total_seconds = min(
+        psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
+    )
+    for i in N2Chn.values():
+        if i == "Rpeak":
+            continue
+        length = int(total_seconds * psg_data[i]["fs"])
+        psg_data[i]["data"] = psg_data[i]["data"][:length]
+        psg_data[i]["length"] = length
+        psg_data[i]["second"] = total_seconds
+
+    psg_data["HR"] = {
+        "name": "HR",
+        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
+                                       psg_data["Rpeak"]["fs"]),
+        "fs": psg_data["ECG_Sync"]["fs"],
+        "length": psg_data["ECG_Sync"]["length"],
+        "second": psg_data["ECG_Sync"]["second"]
+    }
+    # Preprocessing and filtering
+    tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
+    abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
+    flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
+    flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])
+
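[Reviewer note, not part of the patch: rpeak2rri_interpolation used below is added later in this patch (signal_method/signal_process.py) and re-exported from signal_method. A minimal usage sketch, assuming R-peak indices given in samples of a 1000 Hz ECG:

    import numpy as np
    from signal_method import rpeak2rri_interpolation

    rpeaks = np.arange(0, 30_000, 800)   # one beat every 0.8 s at 1000 Hz
    rri, rri_fs = rpeak2rri_interpolation(rpeak_indices=rpeaks, ecg_fs=1000, rri_fs=100)
    # rri is an evenly resampled (100 Hz) series of ~0.8 s beat intervals
]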
+    rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
+
+    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
+    if verbose:
+        print(f"mask_excel_path: {mask_excel_path}")
+
+    event_mask, event_list = utils.read_mask_execl(mask_excel_path)
+
+    enable_list = [[0, psg_data["Effort Tho"]["second"]]]
+    normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
+    normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
+    normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
+    normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)
+
+    # Resample everything to 100 Hz
+    target_fs = 100
+    normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
+    normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
+    normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
+    normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
+    spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
+    normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
+    rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)
+
+    # Trim everything to the same length
+    min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal),
+                     len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal), len(rri))
+    min_length = min_length - min_length % target_fs  # keep a whole number of seconds
+    normalized_tho_signal = normalized_tho_signal[:min_length]
+    normalized_abd_signal = normalized_abd_signal[:min_length]
+    normalized_flowp_signal = normalized_flowp_signal[:min_length]
+    normalized_flowt_signal = normalized_flowt_signal[:min_length]
+    spo2_data_filt = spo2_data_filt[:min_length]
+    normalized_effort_signal = normalized_effort_signal[:min_length]
+    rri = rri[:min_length]
+
+    tho_second = min_length / target_fs
+    for i in event_mask.keys():
+        event_mask[i] = event_mask[i][:int(tho_second)]
+
+    spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
+                                                                     spo2_fs=target_fs,
+                                                                     max_fill_duration=30,
+                                                                     min_gap_duration=10)
+
+    draw_tools.draw_psg_signal(
+        samp_id=samp_id,
+        tho_signal=normalized_tho_signal,
+        abd_signal=normalized_abd_signal,
+        flowp_signal=normalized_flowp_signal,
+        flowt_signal=normalized_flowt_signal,
+        spo2_signal=spo2_data_filt,
+        effort_signal=normalized_effort_signal,
+        rri_signal=rri,
+        fs=target_fs,
+        event_mask=event_mask["SA_Label"],
+        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
+        show=show
+    )
+
+    draw_tools.draw_psg_signal(
+        samp_id=samp_id,
+        tho_signal=normalized_tho_signal,
+        abd_signal=normalized_abd_signal,
+        flowp_signal=normalized_flowp_signal,
+        flowt_signal=normalized_flowt_signal,
+        spo2_signal=spo2_data_filt_fill,
+        effort_signal=normalized_effort_signal,
+        rri_signal=rri,
+        fs=target_fs,
+        event_mask=event_mask["SA_Label"],
+        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
+        show=show
+    )
+
+    spo2_disable_mask = spo2_disable_mask[::target_fs]
+    min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
+
+    if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
+        print(f"Warning: Data length mismatch! Truncating to {min_len}.")
+    event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]
+
+    event_list = {
+        "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
+        "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
+
+    spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95)
+
+    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
+    if verbose:
+        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
+
+    # Copy the mask into the processed_Labels folder
+    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
+    shutil.copyfile(mask_excel_path, save_mask_excel_path)
+
+    # Copy SA Label_corrected.csv into the processed_Labels folder
+    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
+    if sa_label_corrected_path.exists():
+        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
+        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
+    else:
+        if verbose:
+            print(f"Warning: {sa_label_corrected_path} does not exist.")
+
+    # Save the processed signals and the extracted segment list
+    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
+    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
+
+    # Update psg_data with the processed signals;
+    # the "name" fields replace the spaces in the keys with underscores
+    psg_data = {
+        "Effort Tho": {
+            "name": "Effort_Tho",
+            "data": normalized_tho_signal,
+            "fs": target_fs,
+            "length": len(normalized_tho_signal),
+            "second": len(normalized_tho_signal) / target_fs
+        },
+        "Effort Abd": {
+            "name": "Effort_Abd",
+            "data": normalized_abd_signal,
+            "fs": target_fs,
+            "length": len(normalized_abd_signal),
+            "second": len(normalized_abd_signal) / target_fs
+        },
+        "Effort": {
+            "name": "Effort",
+            "data": normalized_effort_signal,
+            "fs": target_fs,
+            "length": len(normalized_effort_signal),
+            "second": len(normalized_effort_signal) / target_fs
+        },
+        "Flow P": {
+            "name": "Flow_P",
+            "data": normalized_flowp_signal,
+            "fs": target_fs,
+            "length": len(normalized_flowp_signal),
+            "second": len(normalized_flowp_signal) / target_fs
+        },
+        "Flow T": {
+            "name": "Flow_T",
+            "data": normalized_flowt_signal,
+            "fs": target_fs,
+            "length": len(normalized_flowt_signal),
+            "second": len(normalized_flowt_signal) / target_fs
+        },
+        "SpO2": {
+            "name": "SpO2",
+            "data": spo2_data_filt_fill,
+            "fs": target_fs,
+            "length": len(spo2_data_filt_fill),
+            "second": len(spo2_data_filt_fill) / target_fs
+        },
+        "HR": {
+            "name": "HR",
+            "data": psg_data["HR"]["data"],
+            "fs": psg_data["HR"]["fs"],
+            "length": psg_data["HR"]["length"],
+            "second": psg_data["HR"]["second"]
+        },
+        "RRI": {
+            "name": "RRI",
+            "data": rri,
+            "fs": target_fs,
+            "length": len(rri),
+            "second": len(rri) / target_fs
+        },
+        "5_class": {
+            "name": "Stage",
+            "data": psg_data["5_class"]["data"],
+            "fs": psg_data["5_class"]["fs"],
+            "length": psg_data["5_class"]["length"],
+            "second": psg_data["5_class"]["second"]
+        }
+    }
+
+    np.savez_compressed(save_signal_path, **psg_data)
+    np.savez_compressed(save_segment_path,
+                        segment_list=segment_list,
+                        disable_segment_list=disable_segment_list)
+    if verbose:
+        print(f"Saved processed signals to: {save_signal_path}")
+        print(f"Saved segment list to: {save_segment_path}")
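[Reviewer note, not part of the patch: np.savez_compressed is called with dict values, so each channel is stored as a 0-d object array. A sketch of reading one archive back (sample id 54 from the config, matching the naming scheme above):

    import numpy as np

    arx = np.load("54_Processed_Signals.npz", allow_pickle=True)  # dicts need allow_pickle
    effort_tho = arx["Effort Tho"].item()   # .item() recovers the dict
    print(effort_tho["name"], effort_tho["fs"], effort_tho["second"])
]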
+
+    if draw_segment:
+        total_len = len(segment_list) + len(disable_segment_list)
+        if verbose:
+            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
+
+        draw_tools.draw_psg_label(
+            psg_data=psg_data,
+            psg_label=event_mask["SA_Label"],
+            segment_list=segment_list,
+            save_path=visual_path / f"{samp_id}" / "enable",
+            verbose=verbose,
+            multi_p=multi_p,
+            multi_task_id=multi_task_id
+        )
+
+        draw_tools.draw_psg_label(
+            psg_data=psg_data,
+            psg_label=event_mask["SA_Label"],
+            segment_list=disable_segment_list,
+            save_path=visual_path / f"{samp_id}" / "disable",
+            verbose=verbose,
+            multi_p=multi_p,
+            multi_task_id=multi_task_id
+        )
+
+    # Explicitly delete the large objects
+    try:
+        del psg_data
+        del normalized_tho_signal, normalized_abd_signal
+        del normalized_flowp_signal, normalized_flowt_signal
+        del normalized_effort_signal
+        del spo2_data_filt, spo2_data_filt_fill
+        del rri
+        del event_mask, event_list
+        del segment_list, disable_segment_list
+    except:
+        pass
+
+    # Force garbage collection
+    gc.collect()
+
+
+def multiprocess_entry(_progress, task_id, _id):
+    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
+
+
+def multiprocess_with_tqdm(args_list, n_processes):
+    from concurrent.futures import ProcessPoolExecutor
+    from rich import progress
+
+    with progress.Progress(
+            "[progress.description]{task.description}",
+            progress.BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            progress.MofNCompleteColumn(),
+            progress.TimeRemainingColumn(),
+            progress.TimeElapsedColumn(),
+            refresh_per_second=1,  # bit slower updates
+            transient=False
+    ) as progress:
+        futures = []
+        with multiprocessing.Manager() as manager:
+            _progress = manager.dict()
+            overall_progress_task = progress.add_task("[green]All jobs progress:")
+            with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
+                for i_args in range(len(args_list)):
+                    args = args_list[i_args]
+                    task_id = progress.add_task(f"task {i_args}", visible=True)
+                    futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
+                # monitor the progress:
+                while (n_finished := sum([future.done() for future in futures])) < len(
+                        futures
+                ):
+                    progress.update(
+                        overall_progress_task, completed=n_finished, total=len(futures)
+                    )
+                    for task_id, update_data in _progress.items():
+                        desc = update_data.get("desc", "")
+                        # update the progress bar for this task:
+                        progress.update(
+                            task_id,
+                            completed=update_data.get("progress", 0),
+                            total=update_data.get("total", 0),
+                            description=desc
+                        )
+
+                # raise any errors:
+                for future in futures:
+                    future.result()
+
+
+def multiprocess_with_pool(args_list, n_processes):
+    """Use a multiprocessing Pool; each worker restarts after a fixed number of tasks."""
+    from multiprocessing import Pool
+
+    # maxtasksperchild: restart each worker after this many tasks (releases memory)
+    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
+        results = []
+        for samp_id in args_list:
+            result = pool.apply_async(
+                build_HYS_dataset_segment,
+                args=(samp_id, False, True, False, None, None)
+            )
+            results.append(result)
+
+        # Wait for all tasks to finish
+        for i, result in enumerate(results):
+            try:
+                result.get()
+                print(f"Completed {i + 1}/{len(args_list)}")
+            except Exception as e:
+                print(f"Error processing task {i}: {e}")
+
+        pool.close()
+        pool.join()
+
+
+if __name__ == '__main__':
+    yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
+
+    conf = utils.load_dataset_conf(yaml_path)
+    select_ids = conf["select_ids"]
+    root_path = Path(conf["root_path"])
+    mask_path = Path(conf["mask_save_path"])
+    save_path = Path(conf["dataset_config"]["dataset_save_path"])
+    visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
+    dataset_config = conf["dataset_config"]
+
+    visual_path.mkdir(parents=True, exist_ok=True)
+
+    save_processed_signal_path = save_path / "Signals"
+    save_processed_signal_path.mkdir(parents=True, exist_ok=True)
+
+    save_segment_list_path = save_path / "Segments_List"
+    save_segment_list_path.mkdir(parents=True, exist_ok=True)
+
+    save_processed_label_path = save_path / "Labels"
+    save_processed_label_path.mkdir(parents=True, exist_ok=True)
+
+    # print(f"select_ids: {select_ids}")
+    # print(f"root_path: {root_path}")
+    # print(f"save_path: {save_path}")
+    # print(f"visual_path: {visual_path}")
+
+    org_signal_root_path = root_path / "OrgBCG_Aligned"
+    psg_signal_root_path = root_path / "PSG_Aligned"
+    print(select_ids)
+
+    # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
+
+    # for samp_id in select_ids:
+    #     print(f"Processing sample ID: {samp_id}")
+    #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
+
+    # multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
+    multiprocess_with_pool(args_list=select_ids, n_processes=8)
\ No newline at end of file
diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py
index 20024f5..f944539 100644
--- a/dataset_builder/HYS_dataset.py
+++ b/dataset_builder/HYS_dataset.py
@@ -33,7 +33,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
     bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
     bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
 
-    normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
+    normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
 
 
     # If the signal sample rate is too high, downsample
@@ -123,7 +123,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
     psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
     psg_data["HR"] = {
         "name": "HR",
-        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
+        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]),
         "fs": psg_data["ECG_Sync"]["fs"],
         "length": psg_data["ECG_Sync"]["length"],
         "second": psg_data["ECG_Sync"]["second"]
@@ -136,28 +136,28 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
         if verbose:
             print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
 
-        draw_tools.draw_psg_bcg_label(psg_data=psg_data,
-                                      psg_label=psg_event_mask,
-                                      bcg_data=bcg_data,
-                                      event_mask=event_mask,
-                                      segment_list=segment_list,
-                                      save_path=visual_path / f"{samp_id}" / "enable",
-                                      verbose=verbose,
-                                      multi_p=multi_p,
-                                      multi_task_id=multi_task_id
-                                      )
-
-        draw_tools.draw_psg_bcg_label(
-            psg_data=psg_data,
-            psg_label=psg_event_mask,
-            bcg_data=bcg_data,
-            event_mask=event_mask,
-            segment_list=disable_segment_list,
-            save_path=visual_path / f"{samp_id}" / "disable",
-            verbose=verbose,
-            multi_p=multi_p,
-            multi_task_id=multi_task_id
-        )
+        # draw_tools.draw_psg_bcg_label(psg_data=psg_data,
+        #                               psg_label=psg_event_mask,
+        #                               bcg_data=bcg_data,
+        #                               event_mask=event_mask,
+        #                               segment_list=segment_list,
+        #                               save_path=visual_path / f"{samp_id}" / "enable",
+        #                               verbose=verbose,
+        #                               multi_p=multi_p,
+        #                               multi_task_id=multi_task_id
+        #                               )
+        #
+        # draw_tools.draw_psg_bcg_label(
+        #     psg_data=psg_data,
+        #     psg_label=psg_event_mask,
+        #     bcg_data=bcg_data,
+        #     event_mask=event_mask,
+        #     segment_list=disable_segment_list,
+        #     save_path=visual_path / f"{samp_id}" / "disable",
+        #     verbose=verbose,
+        #     multi_p=multi_p,
+        #     multi_task_id=multi_task_id
+        # )
@@ -241,11 +241,12 @@ if __name__ == '__main__':
 
     org_signal_root_path = root_path / "OrgBCG_Aligned"
     psg_signal_root_path = root_path / "PSG_Aligned"
+    print(select_ids)
 
-    build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
+    # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
 
     # for samp_id in select_ids:
     #     print(f"Processing sample ID: {samp_id}")
     #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
 
-    # multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
\ No newline at end of file
+    multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
\ No newline at end of file
diff --git a/dataset_config/HYS_PSG_config.yaml b/dataset_config/HYS_PSG_config.yaml
new file mode 100644
index 0000000..63233df
--- /dev/null
+++ b/dataset_config/HYS_PSG_config.yaml
@@ -0,0 +1,129 @@
+select_ids:
+  - 54
+  - 88
+  - 220
+  - 221
+  - 229
+  - 282
+  - 286
+  - 541
+  - 579
+  - 582
+  - 670
+  - 671
+  - 683
+  - 684
+  - 735
+  - 933
+  - 935
+  - 950
+  - 952
+  - 960
+  - 962
+  - 967
+  - 1302
+
+root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
+mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
+
+dataset_config:
+  window_sec: 180
+  stride_sec: 60
+  dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset
+  dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization
+
+
+effort:
+  downsample_fs: 10
+
+effort_filter:
+  filter_type: bandpass
+  low_cut: 0.05
+  high_cut: 0.5
+  order: 3
+
+flow:
+  downsample_fs: 10
+
+flow_filter:
+  filter_type: bandpass
+  low_cut: 0.05
+  high_cut: 0.5
+  order: 3
+
+#ecg:
+#  downsample_fs: 100
+#
+#ecg_filter:
+#  filter_type: bandpass
+#  low_cut: 0.5
+#  high_cut: 40
+#  order: 5
+
+
+#resp:
+#  downsample_fs_1: None
+#  downsample_fs_2: 10
+#
+#resp_filter:
+#  filter_type: bandpass
+#  low_cut: 0.05
+#  high_cut: 0.5
+#  order: 3
+#
+#resp_low_amp:
+#  window_size_sec: 30
+#  stride_sec:
+#  amplitude_threshold: 3
+#  merge_gap_sec: 60
+#  min_duration_sec: 60
+#
+#resp_movement:
+#  window_size_sec: 20
+#  stride_sec: 1
+#  std_median_multiplier: 4
+#  compare_intervals_sec:
+#    - 60
+#    - 120
+##    - 180
+#  interval_multiplier: 3
+#  merge_gap_sec: 30
+#  min_duration_sec: 1
+#
+#resp_movement_revise:
+#  up_interval_multiplier: 3
+#  down_interval_multiplier: 2
+#  compare_intervals_sec: 30
+#  merge_gap_sec: 10
+#  min_duration_sec: 1
+#
+#resp_amp_change:
+#  mav_calc_window_sec: 4
+#  threshold_amplitude: 0.25
+#  threshold_energy: 0.4
+#
+#
+#bcg:
+#  downsample_fs: 100
+#
+#bcg_filter:
+#  filter_type: bandpass
+#  low_cut: 1
+#  high_cut: 10
+#  order: 10
+#
+#bcg_low_amp:
+#  window_size_sec: 1
+#  stride_sec:
+#  amplitude_threshold: 8
+#  merge_gap_sec: 20
+#  min_duration_sec: 3
+#
+#
+#bcg_movement:
+#  window_size_sec: 2
+#  stride_sec:
+#  merge_gap_sec: 20
+#  min_duration_sec: 4
+
+
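[Reviewer note, not part of the patch: load_dataset_conf is assumed here to be a thin YAML wrapper; a minimal sketch of how the new config feeds the filter settings consumed by signal_method.psg_effort_filter later in this patch:

    import yaml

    with open("dataset_config/HYS_PSG_config.yaml", "r", encoding="utf-8") as f:
        conf = yaml.safe_load(f)

    flt = conf["effort_filter"]
    # -> {'filter_type': 'bandpass', 'low_cut': 0.05, 'high_cut': 0.5, 'order': 3},
    # passed straight through to utils.bessel(...) inside psg_effort_filter
]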
diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py
index 3386b90..ab3b0ae 100644
--- a/draw_tools/__init__.py
+++ b/draw_tools/__init__.py
@@ -1,2 +1,2 @@
-from .draw_statics import draw_signal_with_mask
-from .draw_label import draw_psg_bcg_label, draw_resp_label
\ No newline at end of file
+from .draw_statics import draw_signal_with_mask, draw_psg_signal
+from .draw_label import draw_psg_bcg_label, draw_psg_label
\ No newline at end of file
diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py
index b93cc6c..29daa5d 100644
--- a/draw_tools/draw_label.py
+++ b/draw_tools/draw_label.py
@@ -7,10 +7,10 @@ import seaborn as sns
 import numpy as np
 from tqdm.rich import tqdm
 import utils
-
+import gc
 
 # add a with_prediction parameter
-psg_chn_name2ax = {
+psg_bcg_chn_name2ax = {
     "SpO2": 0,
     "Flow T": 1,
     "Flow P": 2,
@@ -24,6 +24,19 @@
     "bcg_twinx": 10,
 }
 
+psg_chn_name2ax = {
+    "SpO2": 0,
+    "Flow T": 1,
+    "Flow P": 2,
+    "Effort Tho": 3,
+    "Effort Abd": 4,
+    "Effort": 5,
+    "HR": 6,
+    "RRI": 7,
+    "Stage": 8,
+}
+
+
 resp_chn_name2ax = {
     "resp": 0,
     "bcg": 1,
@@ -39,6 +52,54 @@ def create_psg_bcg_figure():
         ax = fig.add_subplot(gs[i])
         axes.append(ax)
 
+    axes[psg_bcg_chn_name2ax["SpO2"]].grid(True)
+    # axes[0].xaxis.set_major_formatter(Params.FORMATTER)
+    axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100))
+    axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
+
+    axes[psg_bcg_chn_name2ax["Flow T"]].grid(True)
+    # axes[1].xaxis.set_major_formatter(Params.FORMATTER)
+    axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
+
+    axes[psg_bcg_chn_name2ax["Flow P"]].grid(True)
+    # axes[2].xaxis.set_major_formatter(Params.FORMATTER)
+    axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
+
+    axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True)
+    # axes[3].xaxis.set_major_formatter(Params.FORMATTER)
+    axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
+
+    axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True)
+    # axes[4].xaxis.set_major_formatter(Params.FORMATTER)
+    axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
+
+    axes[psg_bcg_chn_name2ax["HR"]].grid(True)
+    axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
+
+    axes[psg_bcg_chn_name2ax["resp"]].grid(True)
+    axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
+    axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx())
+
+    axes[psg_bcg_chn_name2ax["bcg"]].grid(True)
+    # axes[5].xaxis.set_major_formatter(Params.FORMATTER)
+    axes[psg_bcg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
+    axes.append(axes[psg_bcg_chn_name2ax["bcg"]].twinx())
+
+    axes[psg_bcg_chn_name2ax["Stage"]].grid(True)
+    # axes[7].xaxis.set_major_formatter(Params.FORMATTER)
+
+    return fig, axes
+
+
+def create_psg_figure():
+    fig = plt.figure(figsize=(12, 8), dpi=200)
+    gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
+    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
+    axes = []
+    for i in range(9):
+        ax = fig.add_subplot(gs[i])
+        axes.append(ax)
+
     axes[psg_chn_name2ax["SpO2"]].grid(True)
     # axes[0].xaxis.set_major_formatter(Params.FORMATTER)
     axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
@@ -60,24 +121,21 @@ def create_psg_bcg_figure():
     # axes[4].xaxis.set_major_formatter(Params.FORMATTER)
     axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
 
+    axes[psg_chn_name2ax["Effort"]].grid(True)
+    axes[psg_chn_name2ax["Effort"]].tick_params(axis='x', colors="white")
+
     axes[psg_chn_name2ax["HR"]].grid(True)
     axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
 
-    axes[psg_chn_name2ax["resp"]].grid(True)
-    axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
-    axes.append(axes[psg_chn_name2ax["resp"]].twinx())
+    axes[psg_chn_name2ax["RRI"]].grid(True)
+    axes[psg_chn_name2ax["RRI"]].tick_params(axis='x', colors="white")
 
-    axes[psg_chn_name2ax["bcg"]].grid(True)
-    # axes[5].xaxis.set_major_formatter(Params.FORMATTER)
-    axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
-    axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
 
     axes[psg_chn_name2ax["Stage"]].grid(True)
-    # axes[7].xaxis.set_major_formatter(Params.FORMATTER)
+
     return fig, axes
 
-
 def create_resp_figure():
     fig = plt.figure(figsize=(12, 6), dpi=100)
     gs = GridSpec(2, 1, height_ratios=[3, 2])
@@ -150,8 +208,8 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
 def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
     stage_signal = stage_data["data"]
     stage_fs = stage_data["fs"]
-    time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start)
-    ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"])
+    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * stage_fs)
+    ax.plot(time_axis, stage_signal[segment_start * stage_fs:segment_end * stage_fs], color='black', label=stage_data["name"])
     ax.set_ylim(0, 6)
     ax.set_yticks([1, 2, 3, 4, 5])
     ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
@@ -162,11 +220,11 @@ def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
     spo2_signal = spo2_data["data"]
     spo2_fs = spo2_data["fs"]
-    time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start)
-    ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"])
+    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * spo2_fs)
+    ax.plot(time_axis, spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs], color='black', label=spo2_data["name"])
 
-    if spo2_signal[segment_start:segment_end].min() < 85:
-        ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100))
+    if spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() < 85:
+        ax.set_ylim((spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() - 5, 100))
     else:
         ax.set_ylim((85, 100))
     ax.set_ylabel("SpO2 (%)")
@@ -197,6 +255,56 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
     # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
 
     fig, axes = create_psg_bcg_figure()
+    for i, (segment_start, segment_end) in enumerate(segment_list):
+        for ax in axes:
+            ax.cla()
+
+        plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
+        plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
+                               psg_label, event_codes=[1, 2, 3, 4])
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
+                               psg_label, event_codes=[1, 2, 3, 4])
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
+                               psg_label, event_codes=[1, 2, 3, 4])
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
+                               psg_label, event_codes=[1, 2, 3, 4])
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
+                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
+                               ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]])
+        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
+                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
+                               ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]])
+
+        if save_path is not None:
+            fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
+            # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
+
+        if multi_p is not None:
+            multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
+
+    plt.close(fig)
+    plt.close('all')
+    gc.collect()
+
+
+def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True,
+                   multi_p=None, multi_task_id=None):
+    if save_path is not None:
+        save_path.mkdir(parents=True, exist_ok=True)
+
+    if multi_p is None:
+        # Print the length and fs of every signal in psg_data
+        for i in range(len(psg_data.keys())):
+            chn_name = list(psg_data.keys())[i]
+            print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}")
+        # and the length of psg_label
+        print(f"psg_label length: {len(psg_label)}")
+
+    fig, axes = create_psg_figure()
     for i, (segment_start, segment_end) in enumerate(segment_list):
         for ax in axes:
             ax.cla()
@@ -211,14 +319,10 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
                                psg_label, event_codes=[1, 2, 3, 4])
         plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
                                psg_label, event_codes=[1, 2, 3, 4])
+        plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end,
+                               psg_label, event_codes=[1, 2, 3, 4])
         plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
-        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
-                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
-                               ax2=axes[psg_chn_name2ax["resp_twinx"]])
-        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
-                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
-                               ax2=axes[psg_chn_name2ax["bcg_twinx"]])
-
+        plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end)
 
         if save_path is not None:
             fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
@@ -226,23 +330,8 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
         if multi_p is not None:
             multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
 
+    plt.close(fig)
+    plt.close('all')
+    gc.collect()
 
-
-def draw_resp_label(resp_data, resp_label, segment_list):
-    for mask in resp_label.keys():
-        if mask.startswith("Resp_"):
-            resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0)
-
-    # resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
-    # resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0)
-
-    fig, axes = create_resp_figure()
-    for segment_start, segment_end in segment_list:
-        for ax in axes:
-            ax.cla()
-
-        plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
-                               resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
-        plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
-                               resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])
-        plt.show()
diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py
index f94679d..06e39aa 100644
--- a/draw_tools/draw_statics.py
+++ b/draw_tools/draw_statics.py
@@ -247,6 +247,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
     ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
     ax1.set_title(f'Sample {samp_id} - Respiration Component')
 
+
+
     ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
     ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
     ax2.set_ylabel('Amplitude')
@@ -300,5 +302,70 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
     plt.show()
 
 
+def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, spo2_signal, effort_signal, rri_signal, event_mask, fs,
+                    show=False, save_path=None):
+    sa_mask = event_mask.repeat(fs)
+    fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True)
+    time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal))
+    axs[0].plot(time_axis, tho_signal, label='THO', color='black')
+    axs[0].set_title(f'Sample {samp_id} - PSG Signal Data')
+    axs[0].set_ylabel('THO Amplitude')
+    axs[0].legend(loc='upper right')
+
+    ax0_twin = axs[0].twinx()
+    ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
+    ax0_twin.autoscale(enable=False, axis='y', tight=True)
+    ax0_twin.set_ylim((-4, 5))
+
+    axs[1].plot(time_axis, abd_signal, label='ABD', color='black')
+    axs[1].set_ylabel('ABD Amplitude')
+    axs[1].legend(loc='upper right')
+
+    ax1_twin = axs[1].twinx()
+    ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
+    ax1_twin.autoscale(enable=False, axis='y', tight=True)
+    ax1_twin.set_ylim((-4, 5))
+
+    axs[2].plot(time_axis, effort_signal, label='EFFO', color='black')
+    axs[2].set_ylabel('EFFO Amplitude')
+    axs[2].legend(loc='upper right')
+
+    ax2_twin = axs[2].twinx()
+    ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
+    ax2_twin.autoscale(enable=False, axis='y', tight=True)
+    ax2_twin.set_ylim((-4, 5))
+
+    axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black')
+    axs[3].set_ylabel('FLOWP Amplitude')
+    axs[3].legend(loc='upper right')
+
+    ax3_twin = axs[3].twinx()
+    ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
+    ax3_twin.autoscale(enable=False, axis='y', tight=True)
+    ax3_twin.set_ylim((-4, 5))
+
+    axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black')
+    axs[4].set_ylabel('FLOWT Amplitude')
+    axs[4].legend(loc='upper right')
+
+    ax4_twin = axs[4].twinx()
+    ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
+    ax4_twin.autoscale(enable=False, axis='y', tight=True)
+    ax4_twin.set_ylim((-4, 5))
+
+    axs[5].plot(time_axis, rri_signal, label='RRI', color='black')
+    axs[5].set_ylabel('RRI Amplitude')
+    axs[5].legend(loc='upper right')
+
+    axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black')
+    axs[6].set_ylabel('SPO2 Amplitude')
+    axs[6].set_xlabel('Time (s)')
+    axs[6].legend(loc='upper right')
+
+
+    if save_path is not None:
+        plt.savefig(save_path, dpi=300)
+    if show:
+        plt.show()
+
diff --git a/event_mask_process/HYS_PSG_process.py b/event_mask_process/HYS_PSG_process.py
new file mode 100644
index 0000000..bd04f63
--- /dev/null
+++ b/event_mask_process/HYS_PSG_process.py
@@ -0,0 +1,183 @@
+"""
+This script processes the HYS (呼研所) data and covers the following steps:
+1. Data reading and preprocessing
+   Read the signals and labels from the given path and apply basic
+   preprocessing such as filtering and denoising.
+2. Data cleaning and outlier handling.
+3. Output summary statistics after cleaning.
+4. Data saving
+   Save the processed data to the target path for later use; mainly the
+   positions of the extracted segments and their labels.
+5. Visualization
+   Provide before/after comparisons and plot several usability trend curves
+   showing usable intervals, movement intervals, low-amplitude intervals, etc.
+
+
+# Rule-based labelling and removal of low-amplitude intervals
+# Rule-based labelling and removal of sustained high-amplitude movement intervals
+# Removal of manually labelled unusable intervals
+"""
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+project_root_path = Path(__file__).resolve().parent.parent
+
+import shutil
+import draw_tools
+import utils
+import numpy as np
+import signal_method
+import os
+
+
+os.environ['DISPLAY'] = "localhost:10.0"
+
+
+def process_one_signal(samp_id, show=False):
+    tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
+    abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
+    flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
+    flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
+    spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
+    stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))
+
+    if not tho_signal_path:
+        raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
+    tho_signal_path = tho_signal_path[0]
+    print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
+    if not abd_signal_path:
+        raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
+    abd_signal_path = abd_signal_path[0]
+    print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
+    if not flowp_signal_path:
+        raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
+    flowp_signal_path = flowp_signal_path[0]
+    print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
+    if not flowt_signal_path:
+        raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
+    flowt_signal_path = flowt_signal_path[0]
+    print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
+    if not spo2_signal_path:
+        raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
+    spo2_signal_path = spo2_signal_path[0]
+    print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
+    if not stage_signal_path:
+        raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
+    stage_signal_path = stage_signal_path[0]
+    print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
+
+    # Materialize the glob so the emptiness check below actually works
+    label_path = list((label_root_path / f"{samp_id}").glob("SA Label_Sync.csv"))
+    if not label_path:
+        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
+    label_path = label_path[0]
+    print(f"Processing Label_corrected file: {label_path}")
+
+    # Save the processed data and labels
+    save_samp_path = save_path / f"{samp_id}"
+    save_samp_path.mkdir(parents=True, exist_ok=True)
+
+    # Read the signal data
+    tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
+    # abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
+    # flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
+    # flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
+    # spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
+    stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
+
+    # # Preprocessing and filtering
+    # tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs)
+    # abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs)
+    # flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs)
+    # flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs)
+
+    # Downsampling
+    # old_tho_fs = tho_fs
+    # tho_fs = conf["effort"]["downsample_fs"]
+    # tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs)
+    # old_abd_fs = abd_fs
+    # abd_fs = conf["effort"]["downsample_fs"]
+    # abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs)
+    # old_flowp_fs = flowp_fs
+    # flowp_fs = conf["effort"]["downsample_fs"]
+    # flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs)
+    # old_flowt_fs = flowt_fs
+    # flowt_fs = conf["effort"]["downsample_fs"]
+    # flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs)
+
+    # Do not downsample SpO2
+    # spo2_data_filt = spo2_data_raw
+    # spo2_fs = spo2_fs
+
+    label_data = utils.read_raw_psg_label(path=label_path)
+    event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
+    # 1 where event_mask > 0, otherwise 0
+    score_mask = np.where(event_mask > 0, 1, 0)
+
+    # Generate unusable intervals from the sleep stages
+    wake_mask = utils.get_wake_mask(stage_data_raw)
+    # Drop wake intervals shorter than 60 seconds
+    wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60)
+    # Merge wake intervals separated by short gaps
+    wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
+
+    disable_label = wake_mask
+    disable_label = disable_label[:tho_second]
+
+    # Copy the event file to the save path
+    sa_label_save_name = f"{samp_id}_" + label_path.name
+    shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
+
+    # Build a new dataframe: seconds, SA labels and the usability masks
+    save_dict = {
+        "Second": np.arange(tho_second),
+        "SA_Label": event_mask,
+        "SA_Score": score_mask,
+        "Disable_Label": disable_label,
+        "Resp_LowAmp_Label": np.zeros_like(event_mask),
+        "Resp_Movement_Label": np.zeros_like(event_mask),
+        "Resp_AmpChange_Label": np.zeros_like(event_mask),
+        "BCG_LowAmp_Label": np.zeros_like(event_mask),
+        "BCG_Movement_Label": np.zeros_like(event_mask),
+        "BCG_AmpChange_Label": np.zeros_like(event_mask)
+    }
+
+    mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
+    utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
+
+
+if __name__ == '__main__':
+    yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
+    # disable_df_path = project_root_path / "排除区间.xlsx"
+
+    conf = utils.load_dataset_conf(yaml_path)
+
+    root_path = Path(conf["root_path"])
+    save_path = Path(conf["mask_save_path"])
+    select_ids = conf["select_ids"]
+
+    print(f"select_ids: {select_ids}")
+    print(f"root_path: {root_path}")
+    print(f"save_path: {save_path}")
+
+    org_signal_root_path = root_path / "PSG_Aligned"
+    label_root_path = root_path / "PSG_Aligned"
+
+    # all_samp_disable_df = utils.read_disable_excel(disable_df_path)
+
+    # process_one_signal(select_ids[0], show=True)
+
+    for samp_id in select_ids:
+        print(f"Processing sample ID: {samp_id}")
+        process_one_signal(samp_id, show=False)
+        print(f"Finished processing sample ID: {samp_id}\n\n")
\ No newline at end of file
diff --git a/signal_method/__init__.py b/signal_method/__init__.py
index 7ce8cdb..eb9e1e2 100644
--- a/signal_method/__init__.py
+++ b/signal_method/__init__.py
@@ -2,5 +2,5 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement
 from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
 from .rule_base_event import movement_revise
 from .time_metrics import calc_mav_by_slide_windows
-from .signal_process import signal_filter_split, rpeak2hr
-from .normalize_method import normalize_resp_signal
\ No newline at end of file
+from .signal_process import signal_filter_split, rpeak2hr, psg_effort_filter, rpeak2rri_interpolation
+from .normalize_method import normalize_resp_signal_by_segment
diff --git a/signal_method/normalize_method.py b/signal_method/normalize_method.py
index 8ed89ce..095f16e 100644
--- a/signal_method/normalize_method.py
+++ b/signal_method/normalize_method.py
@@ -3,7 +3,7 @@ import pandas as pd
 import numpy as np
 from scipy import signal
 
-def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
+def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
     # Z-score normalize each segment according to the amplitude-change intervals of the respiratory signal
     normalized_resp_signal = np.zeros_like(resp_signal)
     # fill everything with nan
@@ -33,4 +33,20 @@ def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enabl
         raw_segment = resp_signal[enable_start:enable_end]
         normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std
 
+    # If the first enable interval does not start at 0, normalize the leading part as well
+    if enable_list[0][0] > 0:
+        new_enable_start = 0
+        enable_start = enable_list[0][0] * resp_fs
+        enable_end = enable_list[0][1] * resp_fs
+        segment = resp_signal_no_movement[enable_start:enable_end]
+
+        segment_mean = np.nanmean(segment)
+        segment_std = np.nanstd(segment)
+        if segment_std == 0:
+            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")
+
+        raw_segment = resp_signal[new_enable_start:enable_start]
+        normalized_resp_signal[new_enable_start:enable_start] = (raw_segment - segment_mean) / segment_std
+
     return normalized_resp_signal
diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py
index c0c6699..a303ca4 100644
--- a/signal_method/signal_process.py
+++ b/signal_method/signal_process.py
@@ -1,4 +1,5 @@
 import numpy as np
+from scipy.interpolate import interp1d
 
 import utils
 
@@ -44,14 +45,24 @@ def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
     return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs
 
 
+def psg_effort_filter(conf, effort_data_raw, effort_fs):
+    # Filtering
+    effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"],
+                                 low_cut=conf["effort_filter"]["low_cut"],
+                                 high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
+                                 sample_rate=effort_fs)
+    # Moving average
+    effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20)
+    return effort_data_raw, effort_data_2, effort_fs
+
+
-def rpeak2hr(rpeak_indices, signal_length):
+def rpeak2hr(rpeak_indices, signal_length, ecg_fs):
     hr_signal = np.zeros(signal_length)
     for i in range(1, len(rpeak_indices)):
         rri = rpeak_indices[i] - rpeak_indices[i - 1]
         if rri == 0:
             continue
-        hr = 60 * 1000 / rri  # heart rate in bpm
+        hr = 60 * ecg_fs / rri  # heart rate in bpm
         if hr > 120:
             hr = 120
         elif hr < 30:
@@ -62,3 +73,35 @@ def rpeak2hr(rpeak_indices, signal_length):
         hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
     return hr_signal
 
+
+def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs):
+    rri_signal = np.zeros(signal_length)
+    for i in range(1, len(rpeak_indices)):
+        rri = rpeak_indices[i] - rpeak_indices[i - 1]
+        rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri
+    # Fill the RRI values after the last R peak
+    if len(rpeak_indices) > 1:
+        rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]]
+
+    # Zero out outlier intervals
+    for i in range(1, len(rpeak_indices)):
+        rri = rpeak_indices[i] - rpeak_indices[i - 1]
+        if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs:
+            rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0
+
+    return rri_signal
+
+
+def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs):
+    r_peak_time = np.asarray(rpeak_indices) / ecg_fs
+    rri = np.diff(r_peak_time)
+    t_rri = r_peak_time[1:]
+
+    mask = (rri > 0.3) & (rri < 2.0)
+    rri_clean = rri[mask]
+    t_rri_clean = t_rri[mask]
+
+    t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1 / rri_fs)
+    f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate")
+    rri_uniform = f(t_uniform)
+
+    return rri_uniform, rri_fs
+
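[Reviewer note, not part of the patch: a quick sanity check of the rpeak2hr change above. The R-peak indices are in samples, so the beat interval must be scaled by the ECG sample rate rather than a hard-coded 1000:

    ecg_fs = 200             # Hz (example value)
    rri = 160                # samples between two R peaks -> 0.8 s
    hr = 60 * ecg_fs / rri   # 75.0 bpm
    # the old 60 * 1000 / rri was only correct when ecg_fs == 1000
]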
diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py
index f41d16a..d2fb03f 100644
--- a/utils/HYS_FileReader.py
+++ b/utils/HYS_FileReader.py
@@ -178,6 +178,55 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
     return df
 
 
+def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame:
+    """
+    Read a raw PSG label CSV file and return it as a pandas DataFrame.
+
+    Args:
+        path (str | Path): Path to the CSV file.
+        verbose (bool): If True, print event statistics.
+
+    Returns:
+        pd.DataFrame: The content of the CSV file as a pandas DataFrame.
+    """
+    path = Path(path)
+    if not path.exists():
+        raise FileNotFoundError(f"File not found: {path}")
+
+    # Read directly with pandas; the file contains Chinese, so specify the encoding
+    df = pd.read_csv(path, encoding="gbk")
+    if verbose:
+        print(f"Label file read from {path}, number of rows: {len(df)}")
+
+    num_psg_events = np.sum(df["Event type"].notna())
+    # Event statistics
+    num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
+    num_psg_csa = np.sum(df["Event type"] == "Central apnea")
+    num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
+    num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")
+
+    if verbose:
+        print("Event Statistics:")
+        # Fixed-width formatted output
+        print(f"Type {'Total':^8s}")
+        print(f"Hyp: {num_psg_hyp:^8d} ")
+        print(f"CSA: {num_psg_csa:^8d} ")
+        print(f"OSA: {num_psg_osa:^8d} ")
+        print(f"MSA: {num_psg_msa:^8d} ")
+        print(f"All: {num_psg_events:^8d}")
+
+    df["Start"] = df["Start"].astype(int)
+    df["End"] = df["End"].astype(int)
+    return df
+
+
 def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
     """
     Read an Excel file and return it as a pandas DataFrame.
@@ -225,6 +274,15 @@ def read_mask_execl(path: Union[str, Path]):
     return event_mask, event_list
 
 
+def read_psg_mask_excel(path: Union[str, Path]):
+
+    df = pd.read_csv(path)
+    event_mask = df.to_dict(orient="list")
+    for key in event_mask:
+        event_mask[key] = np.array(event_mask[key])
+
+    return event_mask
+
 
 def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
     """
diff --git a/utils/__init__.py b/utils/__init__.py
index 362297e..c449802 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,10 +1,12 @@
-from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label
+from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label, read_raw_psg_label, read_psg_mask_excel
 from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
 from .operation_tools import merge_short_gaps, remove_short_durations
 from .operation_tools import collect_values
 from .operation_tools import save_process_label
 from .operation_tools import none_to_nan_mask
+from .operation_tools import get_wake_mask
+from .operation_tools import fill_spo2_anomaly
 from .split_method import resp_split
 from .HYS_FileReader import read_mask_execl, read_psg_channel
 from .event_map import E2N, N2Chn, Stage2N, ColorCycle
-from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
\ No newline at end of file
+from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel, adjust_sample_rate
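[Reviewer note, not part of the patch: adjust_sample_rate, added below and re-exported from utils above, dispatches between scipy-based upsampling and the existing downsample_signal_fast. A minimal sketch:

    import numpy as np
    from utils import adjust_sample_rate

    x = np.sin(2 * np.pi * 0.2 * np.arange(0, 60, 1 / 25))    # 25 Hz test signal
    y = adjust_sample_rate(x, original_fs=25, target_fs=100)  # upsample path
    z = adjust_sample_rate(x, original_fs=25, target_fs=10)   # downsample path
    assert len(y) == 4 * len(x)
]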
diff --git a/utils/filter_func.py b/utils/filter_func.py
index e690c33..5a00c55 100644
--- a/utils/filter_func.py
+++ b/utils/filter_func.py
@@ -20,6 +20,7 @@ def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=10
         raise ValueError("Please choose a type of filter")
 
 
+@timing_decorator()
 def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
     if _type == "lowpass":  # low-pass filtering
         b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
@@ -89,6 +90,52 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1
     return downsampled_signal
 
 
+def upsample_signal(original_signal, original_fs, target_fs):
+    """
+    Upsample a signal.
+
+    Args:
+        original_signal : array-like, the original signal
+        original_fs : float, original sample rate (Hz)
+        target_fs : float, target sample rate (Hz)
+
+    Returns:
+        upsampled_signal : array-like, the upsampled signal
+    """
+    if not isinstance(original_signal, np.ndarray):
+        original_signal = np.array(original_signal)
+    if target_fs <= original_fs:
+        raise ValueError("The target sample rate must be greater than the original sample rate")
+    if target_fs <= 0 or original_fs <= 0:
+        raise ValueError("Sample rates must be positive")
+
+    upsample_factor = target_fs / original_fs
+    num_output_samples = int(len(original_signal) * upsample_factor)
+
+    upsampled_signal = signal.resample(original_signal, num_output_samples)
+
+    return upsampled_signal
+
+
+def adjust_sample_rate(signal_data, original_fs, target_fs):
+    """
+    Automatically choose up- or downsampling based on the original and target sample rates.
+
+    Args:
+        signal_data : array-like, the original signal
+        original_fs : float, original sample rate (Hz)
+        target_fs : float, target sample rate (Hz)
+
+    Returns:
+        adjusted_signal : array-like, the signal at the new sample rate
+    """
+    if original_fs == target_fs:
+        return signal_data
+    elif original_fs > target_fs:
+        return downsample_signal_fast(signal_data, original_fs, target_fs)
+    else:
+        return upsample_signal(signal_data, original_fs, target_fs)
+
 
 @timing_decorator()
 def average_filter(raw_data, sample_rate, window_size_sec=20):
diff --git a/utils/operation_tools.py b/utils/operation_tools.py
index 866029d..d6d7051 100644
--- a/utils/operation_tools.py
+++ b/utils/operation_tools.py
@@ -6,6 +6,7 @@ import pandas as pd
 from matplotlib import pyplot as plt
 import yaml
 from numpy.ma.core import append
+from scipy.interpolate import PchipInterpolator
 
 from utils.event_map import E2N
 
 plt.rcParams['font.sans-serif'] = ['SimHei']  # so that Chinese labels render correctly
@@ -198,9 +199,12 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
     return disable_mask
 
 
-def generate_event_mask(signal_second: int, event_df, use_correct=True):
+def generate_event_mask(signal_second: int, event_df, use_correct=True, with_score=True):
     event_mask = np.zeros(signal_second, dtype=int)
-    score_mask = np.zeros(signal_second, dtype=int)
+    if with_score:
+        score_mask = np.zeros(signal_second, dtype=int)
+    else:
+        score_mask = None
     if use_correct:
         start_name = "correct_Start"
         end_name = "correct_End"
@@ -217,7 +221,8 @@ def generate_event_mask(signal_second: int, event_df, use_correct=True):
         start = row[start_name]
         end = row[end_name] + 1
         event_mask[start:end] = E2N[row[event_type_name]]
-        score_mask[start:end] = row["score"]
+        if with_score:
+            score_mask[start:end] = row["score"]
 
     return event_mask, score_mask
 
@@ -260,4 +265,116 @@ def none_to_nan_mask(mask, ref):
     else:
         # replace 0 in the mask with nan, keep everything else
         mask = np.where(mask == 0, np.nan, mask)
-    return mask
\ No newline at end of file
+    return mask
+
+
+def get_wake_mask(sleep_stage_mask):
+    # Treat N1, N2, N3 and REM as sleep (0) and everything else as wake (1).
+    # The input is an array of strings such as 'W', 'N1', 'N2', 'N3', 'R'.
+    wake_mask = np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)
+    return wake_mask
+
+
+def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
+    anomaly = np.zeros(len(spo2), dtype=bool)
+
+    # Physiological range
+    anomaly |= (spo2 < 50) | (spo2 > 100)
+
+    # Sudden jumps
+    diff = np.abs(np.diff(spo2, prepend=spo2[0]))
+    anomaly |= diff > diff_thresh
+
+    # NaN
+    anomaly |= np.isnan(spo2)
+
+    return anomaly
+
+
+def merge_close_anomalies(anomaly, fs, min_gap_duration):
+    min_gap = int(min_gap_duration * fs)
+    merged = anomaly.copy()
+
+    i = 0
+    n = len(anomaly)
+
+    while i < n:
+        if not anomaly[i]:
+            i += 1
+            continue
+
+        # Current anomalous segment
+        start = i
+        while i < n and anomaly[i]:
+            i += 1
+        end = i
+
+        # Look ahead at the gap to the next anomaly
+        j = end
+        while j < n and not anomaly[j]:
+            j += 1
+
+        if j < n and (j - end) < min_gap:
+            merged[end:j] = True
+
+    return merged
+
+
+def fill_spo2_anomaly(
+        spo2_data,
+        spo2_fs,
+        max_fill_duration,
+        min_gap_duration,
+):
+    spo2 = spo2_data.astype(float).copy()
+    n = len(spo2)
+
+    anomaly = detect_spo2_anomaly(spo2, spo2_fs)
+    anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)
+
+    max_len = int(max_fill_duration * spo2_fs)
+
+    valid_mask = ~anomaly
+
+    i = 0
+    while i < n:
+        if not anomaly[i]:
+            i += 1
+            continue
+
+        start = i
+        while i < n and anomaly[i]:
+            i += 1
+        end = i
+
+        seg_len = end - start
+
+        # Over-long anomalous segment
+        if seg_len > max_len:
+            spo2[start:end] = np.nan
+            valid_mask[start:end] = False
+            continue
+
+        has_left = start > 0 and valid_mask[start - 1]
+        has_right = end < n and valid_mask[end]
+
+        # Anomaly at the very beginning: one-sided fill
+        if not has_left and has_right:
+            spo2[start:end] = spo2[end]
+            continue
+
+        # Anomaly at the very end: one-sided fill
+        if has_left and not has_right:
+            spo2[start:end] = spo2[start - 1]
+            continue
+
+        # Valid samples on both sides -> PCHIP interpolation
+        if has_left and has_right:
+            x = np.array([start - 1, end])
+            y = np.array([spo2[start - 1], spo2[end]])
+
+            interp = PchipInterpolator(x, y)
+            spo2[start:end] = interp(np.arange(start, end))
+            continue
+
+        # No valid samples on either side (extreme case)
+        spo2[start:end] = np.nan
+        valid_mask[start:end] = False
+
+    return spo2, valid_mask
\ No newline at end of file
diff --git a/utils/split_method.py b/utils/split_method.py
index f113013..e5c796c 100644
--- a/utils/split_method.py
+++ b/utils/split_method.py
@@ -54,5 +54,3 @@ def resp_split(dataset_config, event_mask, event_list, verbose=False):
 
     return segment_list, disable_segment_list
-
-
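[Closing reviewer note, not part of the patch: a minimal behavioural sketch of the new fill_spo2_anomaly, using a synthetic per-second SpO2 trace (fs = 1) and the utils re-export added above:

    import numpy as np
    from utils import fill_spo2_anomaly

    spo2 = np.full(120, 96.0)
    spo2[40:50] = 0.0  # 10 s dropout, outside the 50-100 physiological range
    filled, valid = fill_spo2_anomaly(spo2, spo2_fs=1,
                                      max_fill_duration=30, min_gap_duration=10)
    # the gap is shorter than max_fill_duration and has valid samples on both
    # sides, so it is PCHIP-interpolated; valid stays False over the filled span
    assert np.allclose(filled, 96.0)
]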