From 3adcf00abb215bbfc2e8304ed8664559a04bd764 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 30 Mar 2026 09:37:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E4=BF=A1=E5=8F=B7=E8=AF=BB=E5=8F=96=E4=B8=8E=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E4=BB=A5=E6=94=AF=E6=8C=81=E6=96=B0=E6=A0=B7?= =?UTF-8?q?=E6=9C=ACID?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 4 + dataset_builder/HYS_PSG_dataset.py | 275 +++++++++++++++++++++----- dataset_config/HYS_PSG_config.yaml | 31 ++- event_mask_process/HYS_PSG_process.py | 116 ++++++----- utils/HYS_FileReader.py | 5 +- 5 files changed, 329 insertions(+), 102 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..4b5a294 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda" +} \ No newline at end of file diff --git a/dataset_builder/HYS_PSG_dataset.py b/dataset_builder/HYS_PSG_dataset.py index 54e8d05..1d5e801 100644 --- a/dataset_builder/HYS_PSG_dataset.py +++ b/dataset_builder/HYS_PSG_dataset.py @@ -1,25 +1,163 @@ import multiprocessing +import signal import sys +import time from pathlib import Path import os import numpy as np -from utils import N2Chn - os.environ['DISPLAY'] = "localhost:10.0" sys.path.append(str(Path(__file__).resolve().parent.parent)) project_root_path = Path(__file__).resolve().parent.parent import utils +from utils import N2Chn import signal_method import draw_tools import shutil import gc +DEFAULT_YAML_PATH = project_root_path / "dataset_config/HYS_PSG_config.yaml" + +conf = None +select_ids = None +root_path = None +mask_path = None +save_path = None +visual_path = None +dataset_config = None +org_signal_root_path = None +psg_signal_root_path = None +save_processed_signal_path = None +save_segment_list_path = None +save_processed_label_path = None + + +def get_missing_psg_channels(samp_id, channel_number=None): + ensure_runtime_initialized() + + if channel_number is None: + channel_number = [1, 2, 3, 4, 5, 6, 7, 8] + + sample_path = psg_signal_root_path / f"{samp_id}" + if not sample_path.exists(): + return [f"PSG dir missing: {sample_path}"] + + missing_channels = [] + for ch_id in channel_number: + ch_name = N2Chn[ch_id] + if not any(sample_path.glob(f"{ch_name}*.txt")): + missing_channels.append(ch_name) + + return missing_channels + + +def filter_valid_psg_samples(sample_ids, verbose=True): + valid_ids = [] + skipped_ids = [] + + for samp_id in sample_ids: + missing_channels = get_missing_psg_channels(samp_id) + if missing_channels: + skipped_ids.append((samp_id, missing_channels)) + if verbose: + print( + f"Skipping sample {samp_id}: missing PSG channels {missing_channels}", + flush=True, + ) + continue + + valid_ids.append(samp_id) + + if verbose and skipped_ids: + print( + f"Filtered out {len(skipped_ids)} sample(s) with incomplete PSG inputs.", + flush=True, + ) + + return valid_ids, skipped_ids + + +def initialize_runtime(yaml_path=DEFAULT_YAML_PATH): + global conf + global select_ids + global root_path + global mask_path + global save_path + global visual_path + global dataset_config + global org_signal_root_path + global psg_signal_root_path + global save_processed_signal_path + global save_segment_list_path + global save_processed_label_path + + yaml_path = Path(yaml_path) + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + mask_path = Path(conf["mask_save_path"]) + save_path = Path(conf["dataset_config"]["dataset_save_path"]) + visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) + dataset_config = conf["dataset_config"] + + visual_path.mkdir(parents=True, exist_ok=True) + + save_processed_signal_path = save_path / "Signals" + save_processed_signal_path.mkdir(parents=True, exist_ok=True) + + save_segment_list_path = save_path / "Segments_List" + save_segment_list_path.mkdir(parents=True, exist_ok=True) + + save_processed_label_path = save_path / "Labels" + save_processed_label_path.mkdir(parents=True, exist_ok=True) + + org_signal_root_path = root_path / "OrgBCG_Aligned" + psg_signal_root_path = root_path / "PSG_Aligned" + + +def ensure_runtime_initialized(): + if conf is None or psg_signal_root_path is None: + initialize_runtime() + + +def sanitize_rpeak_indices(samp_id, rpeak_indices, signal_length, verbose=True): + rpeak_indices = np.asarray(rpeak_indices, dtype=int) + + valid_mask = (rpeak_indices >= 0) & (rpeak_indices < signal_length) + invalid_count = int((~valid_mask).sum()) + if invalid_count > 0: + invalid_indices = rpeak_indices[~valid_mask] + print( + f"Sample {samp_id}: dropping {invalid_count} invalid Rpeak index/indices " + f"outside [0, {signal_length - 1}]. " + f"min_invalid={invalid_indices.min()}, max_invalid={invalid_indices.max()}", + flush=True, + ) + + rpeak_indices = rpeak_indices[valid_mask] + if rpeak_indices.size == 0: + raise ValueError(f"Sample {samp_id}: no valid Rpeak indices remain after bounds check.") + + rpeak_indices = np.unique(rpeak_indices) + if rpeak_indices.size < 2: + raise ValueError(f"Sample {samp_id}: fewer than 2 valid Rpeak indices remain after bounds check.") + + return rpeak_indices + + def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): + ensure_runtime_initialized() psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) + psg_data["Rpeak"]["data"] = sanitize_rpeak_indices( + samp_id=samp_id, + rpeak_indices=psg_data["Rpeak"]["data"], + signal_length=psg_data["ECG_Sync"]["length"], + verbose=verbose, + ) total_seconds = min( psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak" @@ -283,10 +421,28 @@ def multiprocess_entry(_progress, task_id, _id): build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) +def _init_pool_worker(yaml_path=DEFAULT_YAML_PATH): + # 让主进程统一响应 Ctrl+C,避免父子进程同时处理中断导致退出卡住。 + signal.signal(signal.SIGINT, signal.SIG_IGN) + initialize_runtime(yaml_path) + + +def _shutdown_executor(executor, wait, cancel_futures=False): + if executor is None: + return + + try: + executor.shutdown(wait=wait, cancel_futures=cancel_futures) + except TypeError: + executor.shutdown(wait=wait) + + def multiprocess_with_tqdm(args_list, n_processes): from concurrent.futures import ProcessPoolExecutor from rich import progress + yaml_path = str(DEFAULT_YAML_PATH) + with progress.Progress( "[progress.description]{task.description}", progress.BarColumn(), @@ -301,7 +457,13 @@ def multiprocess_with_tqdm(args_list, n_processes): with multiprocessing.Manager() as manager: _progress = manager.dict() overall_progress_task = progress.add_task("[green]All jobs progress:") - with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor: + executor = ProcessPoolExecutor( + max_workers=n_processes, + mp_context=multiprocessing.get_context("spawn"), + initializer=_init_pool_worker, + initargs=(yaml_path,), + ) + try: for i_args in range(len(args_list)): args = args_list[i_args] task_id = progress.add_task(f"task {i_args}", visible=True) @@ -322,69 +484,90 @@ def multiprocess_with_tqdm(args_list, n_processes): total=update_data.get("total", 0), description=desc ) + time.sleep(0.2) # raise any errors: for future in futures: future.result() + except KeyboardInterrupt: + print("\nKeyboardInterrupt received, cancelling pending jobs...", flush=True) + for future in futures: + future.cancel() + _shutdown_executor(executor, wait=False, cancel_futures=True) + executor = None + raise SystemExit(130) + finally: + _shutdown_executor(executor, wait=True) def multiprocess_with_pool(args_list, n_processes): """使用Pool,每个worker处理固定数量任务后重启""" - from multiprocessing import Pool + if not args_list: + return - # maxtasksperchild 设置每个worker处理多少任务后重启(释放内存) - with Pool(processes=n_processes, maxtasksperchild=2) as pool: - results = [] + ctx = multiprocessing.get_context("spawn") + yaml_path = str(DEFAULT_YAML_PATH) + pool = ctx.Pool( + processes=n_processes, + maxtasksperchild=2, + initializer=_init_pool_worker, + initargs=(yaml_path,), + ) + + pending_results = {} + completed = 0 + + try: for samp_id in args_list: - result = pool.apply_async( + pending_results[samp_id] = pool.apply_async( build_HYS_dataset_segment, args=(samp_id, False, True, False, None, None) ) - results.append(result) - - # 等待所有任务完成 - for i, result in enumerate(results): - try: - result.get() - print(f"Completed {i + 1}/{len(args_list)}") - except Exception as e: - print(f"Error processing task {i}: {e}") pool.close() + + while pending_results: + finished_ids = [] + for samp_id, result in pending_results.items(): + if not result.ready(): + continue + + try: + result.get() + completed += 1 + print(f"Completed {completed}/{len(args_list)}: {samp_id}", flush=True) + except Exception as e: + completed += 1 + print(f"Error processing {samp_id}: {e}", flush=True) + + finished_ids.append(samp_id) + + for samp_id in finished_ids: + pending_results.pop(samp_id, None) + + if pending_results: + time.sleep(0.5) + pool.join() + except KeyboardInterrupt: + print("\nKeyboardInterrupt received, terminating worker processes...", flush=True) + pool.terminate() + pool.join() + raise SystemExit(130) + except Exception: + pool.terminate() + pool.join() + raise if __name__ == '__main__': - yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml" - - conf = utils.load_dataset_conf(yaml_path) - select_ids = conf["select_ids"] - root_path = Path(conf["root_path"]) - mask_path = Path(conf["mask_save_path"]) - save_path = Path(conf["dataset_config"]["dataset_save_path"]) - visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) - dataset_config = conf["dataset_config"] - - visual_path.mkdir(parents=True, exist_ok=True) - - save_processed_signal_path = save_path / "Signals" - save_processed_signal_path.mkdir(parents=True, exist_ok=True) - - save_segment_list_path = save_path / "Segments_List" - save_segment_list_path.mkdir(parents=True, exist_ok=True) - - save_processed_label_path = save_path / "Labels" - save_processed_label_path.mkdir(parents=True, exist_ok=True) - - # print(f"select_ids: {select_ids}") - # print(f"root_path: {root_path}") - # print(f"save_path: {save_path}") - # print(f"visual_path: {visual_path}") - - org_signal_root_path = root_path / "OrgBCG_Aligned" - psg_signal_root_path = root_path / "PSG_Aligned" + yaml_path = DEFAULT_YAML_PATH + initialize_runtime(yaml_path) print(select_ids) + valid_select_ids, skipped_select_ids = filter_valid_psg_samples(select_ids, verbose=True) + print(f"Valid PSG samples: {len(valid_select_ids)} / {len(select_ids)}", flush=True) + # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) # for samp_id in select_ids: @@ -392,4 +575,4 @@ if __name__ == '__main__': # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) # multiprocess_with_tqdm(args_list=select_ids, n_processes=8) - multiprocess_with_pool(args_list=select_ids, n_processes=8) \ No newline at end of file + # multiprocess_with_pool(args_list=valid_select_ids, n_processes=8) diff --git a/dataset_config/HYS_PSG_config.yaml b/dataset_config/HYS_PSG_config.yaml index 6496a72..b252807 100644 --- a/dataset_config/HYS_PSG_config.yaml +++ b/dataset_config/HYS_PSG_config.yaml @@ -1,27 +1,48 @@ select_ids: - - 54 - - 88 + - 1000 + - 1004 + - 1006 + - 1009 + - 1010 + - 1300 + - 1302 + - 1308 + - 1314 + - 1354 + - 1378 - 220 - 221 - 229 - 282 + - 285 - 286 + - 54 - 541 - - 579 - 582 - 670 - - 671 - 683 - 684 + - 686 + - 703 + - 704 + - 726 - 735 + - 736 + - 88 + - 893 - 933 - 935 - 950 - 952 + - 954 + - 955 + - 956 - 960 + - 961 - 962 - 967 - - 1302 + - 971 + - 972 root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG diff --git a/event_mask_process/HYS_PSG_process.py b/event_mask_process/HYS_PSG_process.py index bd04f63..10df8a0 100644 --- a/event_mask_process/HYS_PSG_process.py +++ b/event_mask_process/HYS_PSG_process.py @@ -35,59 +35,77 @@ import os os.environ['DISPLAY'] = "localhost:10.0" +def resolve_sample_file(sample_dir: Path, prefix: str, suffix=".txt", prefer_tokens=("Sync", "RoughCut")) -> Path: + candidates = sorted(sample_dir.glob(f"{prefix}*{suffix}")) + if not candidates: + if sample_dir.exists(): + available_files = ", ".join(sorted(path.name for path in sample_dir.iterdir())) + else: + available_files = "" + raise FileNotFoundError( + f"{prefix} file not found in {sample_dir}. " + f"searched pattern: {prefix}*{suffix}. available: {available_files}" + ) + + for token in prefer_tokens: + preferred = [ + path for path in candidates + if f"_{token}_" in path.name or f"_{token}." in path.name + ] + if preferred: + if len(preferred) > 1: + print(f"Warning!!! multiple preferred files found for {prefix}: {preferred}") + return preferred[0] + + if len(candidates) > 1: + print(f"Warning!!! multiple files found for {prefix}: {candidates}") + return candidates[0] + + +def get_signal_duration_second(signal_path: Path) -> int: + signal_fs = int(signal_path.stem.split("_")[-1]) + with signal_path.open("r", encoding="utf-8", errors="ignore") as file_obj: + signal_length = sum(1 for _ in file_obj) + return signal_length // signal_fs + + def process_one_signal(samp_id, show=False): - pass + sample_dir = org_signal_root_path / f"{samp_id}" + label_dir = label_root_path / f"{samp_id}" - tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt")) - abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt")) - flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt")) - flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt")) - spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt")) - stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt")) + tho_signal_path = resolve_sample_file(sample_dir, "Effort Tho") + abd_signal_path = resolve_sample_file(sample_dir, "Effort Abd") + flowp_signal_path = resolve_sample_file(sample_dir, "Flow P") + flowt_signal_path = resolve_sample_file(sample_dir, "Flow T") + spo2_signal_path = resolve_sample_file(sample_dir, "SpO2") + stage_signal_path = resolve_sample_file(sample_dir, "5_class") + label_path = resolve_sample_file(label_dir, "SA Label", suffix=".csv") - if not tho_signal_path: - raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}") - tho_signal_path = tho_signal_path[0] - print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}") - if not abd_signal_path: - raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}") - abd_signal_path = abd_signal_path[0] - print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}") - if not flowp_signal_path: - raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}") - flowp_signal_path = flowp_signal_path[0] - print(f"Processing Flow P_Sync signal file: {flowp_signal_path}") - if not flowt_signal_path: - raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}") - flowt_signal_path = flowt_signal_path[0] - print(f"Processing Flow T_Sync signal file: {flowt_signal_path}") - if not spo2_signal_path: - raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}") - spo2_signal_path = spo2_signal_path[0] - print(f"Processing SpO2_Sync signal file: {spo2_signal_path}") - if not stage_signal_path: - raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}") - stage_signal_path = stage_signal_path[0] - print(f"Processing 5_class_Sync signal file: {stage_signal_path}") - - - - label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv") - if not label_path: - raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") - label_path = list(label_path)[0] - print(f"Processing Label_corrected file: {label_path}") + print(f"Processing Effort Tho signal file: {tho_signal_path}") + print(f"Processing Effort Abd signal file: {abd_signal_path}") + print(f"Processing Flow P signal file: {flowp_signal_path}") + print(f"Processing Flow T signal file: {flowt_signal_path}") + print(f"Processing SpO2 signal file: {spo2_signal_path}") + print(f"Processing 5_class signal file: {stage_signal_path}") + print(f"Processing SA Label file: {label_path}") # # # 保存处理后的数据和标签 save_samp_path = save_path / f"{samp_id}" save_samp_path.mkdir(parents=True, exist_ok=True) + signal_seconds = { + "Effort Tho": get_signal_duration_second(tho_signal_path), + "Effort Abd": get_signal_duration_second(abd_signal_path), + "Flow P": get_signal_duration_second(flowp_signal_path), + "Flow T": get_signal_duration_second(flowt_signal_path), + "SpO2": get_signal_duration_second(spo2_signal_path), + "5_class": get_signal_duration_second(stage_signal_path), + } + common_second = min(signal_seconds.values()) + print(f"Sample {samp_id} signal seconds: {signal_seconds}") + print(f"Sample {samp_id} common_second: {common_second}") + # # # 读取信号数据 - tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True) - # abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True) - # flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True) - # flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True) - # spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True) stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True) @@ -117,7 +135,7 @@ def process_one_signal(samp_id, show=False): # spo2_fs = spo2_fs label_data = utils.read_raw_psg_label(path=label_path) - event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False) + event_mask, score_mask = utils.generate_event_mask(signal_second=common_second, event_df=label_data, use_correct=False, with_score=False) # event_mask > 0 的部分为1,其他为0 score_mask = np.where(event_mask > 0, 1, 0) @@ -128,9 +146,7 @@ def process_one_signal(samp_id, show=False): # 合并短于120秒的觉醒区间 wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60) - disable_label = wake_mask - - disable_label = disable_label[:tho_second] + disable_label = wake_mask[:common_second] # 复制事件文件 到保存路径 @@ -139,7 +155,7 @@ def process_one_signal(samp_id, show=False): # # 新建一个dataframe,分别是秒数、SA标签, save_dict = { - "Second": np.arange(tho_second), + "Second": np.arange(common_second), "SA_Label": event_mask, "SA_Score": score_mask, "Disable_Label": disable_label, @@ -180,4 +196,4 @@ if __name__ == '__main__': print(f"Processing sample ID: {samp_id}") process_one_signal(samp_id, show=False) print(f"Finished processing sample ID: {samp_id}\n\n") - pass \ No newline at end of file + pass diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index d2fb03f..97ce9a2 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -307,7 +307,10 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verb ch_path = list(path.glob(f"{ch_name}*.txt")) if not any(ch_path): - raise FileNotFoundError(f"PSG Channel file not found: {ch_path}") + raise FileNotFoundError( + f"PSG channel '{ch_name}' file not found in {path} " + f"(glob pattern: {ch_name}*.txt)" + ) if len(ch_path) > 1: print(f"Warning!!! PSG Channel file more than one: {ch_path}")