import multiprocessing import signal import sys import time from pathlib import Path import os import numpy as np os.environ['DISPLAY'] = "localhost:10.0" sys.path.append(str(Path(__file__).resolve().parent.parent)) project_root_path = Path(__file__).resolve().parent.parent import utils from utils import N2Chn import signal_method import draw_tools import shutil import gc DEFAULT_YAML_PATH = project_root_path / "dataset_config/HYS_PSG_config.yaml" conf = None select_ids = None root_path = None mask_path = None save_path = None visual_path = None dataset_config = None org_signal_root_path = None psg_signal_root_path = None save_processed_signal_path = None save_segment_list_path = None save_processed_label_path = None def get_missing_psg_channels(samp_id, channel_number=None): ensure_runtime_initialized() if channel_number is None: channel_number = [1, 2, 3, 4, 5, 6, 7, 8] sample_path = psg_signal_root_path / f"{samp_id}" if not sample_path.exists(): return [f"PSG dir missing: {sample_path}"] missing_channels = [] for ch_id in channel_number: ch_name = N2Chn[ch_id] if not any(sample_path.glob(f"{ch_name}*.txt")): missing_channels.append(ch_name) return missing_channels def filter_valid_psg_samples(sample_ids, verbose=True): valid_ids = [] skipped_ids = [] for samp_id in sample_ids: missing_channels = get_missing_psg_channels(samp_id) if missing_channels: skipped_ids.append((samp_id, missing_channels)) if verbose: print( f"Skipping sample {samp_id}: missing PSG channels {missing_channels}", flush=True, ) continue valid_ids.append(samp_id) if verbose and skipped_ids: print( f"Filtered out {len(skipped_ids)} sample(s) with incomplete PSG inputs.", flush=True, ) return valid_ids, skipped_ids def initialize_runtime(yaml_path=DEFAULT_YAML_PATH): global conf global select_ids global root_path global mask_path global save_path global visual_path global dataset_config global org_signal_root_path global psg_signal_root_path global save_processed_signal_path global save_segment_list_path global save_processed_label_path yaml_path = Path(yaml_path) conf = utils.load_dataset_conf(yaml_path) select_ids = conf["select_ids"] root_path = Path(conf["root_path"]) mask_path = Path(conf["mask_save_path"]) save_path = Path(conf["dataset_config"]["dataset_save_path"]) visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) dataset_config = conf["dataset_config"] visual_path.mkdir(parents=True, exist_ok=True) save_processed_signal_path = save_path / "Signals" save_processed_signal_path.mkdir(parents=True, exist_ok=True) save_segment_list_path = save_path / "Segments_List" save_segment_list_path.mkdir(parents=True, exist_ok=True) save_processed_label_path = save_path / "Labels" save_processed_label_path.mkdir(parents=True, exist_ok=True) org_signal_root_path = root_path / "OrgBCG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned" def ensure_runtime_initialized(): if conf is None or psg_signal_root_path is None: initialize_runtime() def sanitize_rpeak_indices(samp_id, rpeak_indices, signal_length, verbose=True): rpeak_indices = np.asarray(rpeak_indices, dtype=int) valid_mask = (rpeak_indices >= 0) & (rpeak_indices < signal_length) invalid_count = int((~valid_mask).sum()) if invalid_count > 0: invalid_indices = rpeak_indices[~valid_mask] print( f"Sample {samp_id}: dropping {invalid_count} invalid Rpeak index/indices " f"outside [0, {signal_length - 1}]. " f"min_invalid={invalid_indices.min()}, max_invalid={invalid_indices.max()}", flush=True, ) rpeak_indices = rpeak_indices[valid_mask] if rpeak_indices.size == 0: raise ValueError(f"Sample {samp_id}: no valid Rpeak indices remain after bounds check.") rpeak_indices = np.unique(rpeak_indices) if rpeak_indices.size < 2: raise ValueError(f"Sample {samp_id}: fewer than 2 valid Rpeak indices remain after bounds check.") return rpeak_indices def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): ensure_runtime_initialized() psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) psg_data["Rpeak"]["data"] = sanitize_rpeak_indices( samp_id=samp_id, rpeak_indices=psg_data["Rpeak"]["data"], signal_length=psg_data["ECG_Sync"]["length"], verbose=verbose, ) total_seconds = min( psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak" ) for i in N2Chn.values(): if i == "Rpeak": continue length = int(total_seconds * psg_data[i]["fs"]) psg_data[i]["data"] = psg_data[i]["data"][:length] psg_data[i]["length"] = length psg_data[i]["second"] = total_seconds psg_data["HR"] = { "name": "HR", "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]), "fs": psg_data["ECG_Sync"]["fs"], "length": psg_data["ECG_Sync"]["length"], "second": psg_data["ECG_Sync"]["second"] } # 预处理与滤波 tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"]) abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"]) flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"]) flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"]) rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100) mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") if verbose: print(f"mask_excel_path: {mask_excel_path}") event_mask, event_list = utils.read_mask_execl(mask_excel_path) enable_list = [[0, psg_data["Effort Tho"]["second"]]] normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list) normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list) normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list) normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list) # 都调整至100Hz采样率 target_fs = conf["target_fs"] normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs) normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs) normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs) normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs) spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs) normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2 rri = utils.adjust_sample_rate(rri, rri_fs, target_fs) # 调整至相同长度 min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal), len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal) ,len(rri)) min_length = min_length - min_length % target_fs # 保证是整数秒 normalized_tho_signal = normalized_tho_signal[:min_length] normalized_abd_signal = normalized_abd_signal[:min_length] normalized_flowp_signal = normalized_flowp_signal[:min_length] normalized_flowt_signal = normalized_flowt_signal[:min_length] spo2_data_filt = spo2_data_filt[:min_length] normalized_effort_signal = normalized_effort_signal[:min_length] rri = rri[:min_length] tho_second = min_length / target_fs for i in event_mask.keys(): event_mask[i] = event_mask[i][:int(tho_second)] spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt, spo2_fs=target_fs, max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"], min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"]) draw_tools.draw_psg_signal( samp_id=samp_id, tho_signal=normalized_tho_signal, abd_signal=normalized_abd_signal, flowp_signal=normalized_flowp_signal, flowt_signal=normalized_flowt_signal, spo2_signal=spo2_data_filt, effort_signal=normalized_effort_signal, rri_signal = rri, fs=target_fs, event_mask=event_mask["SA_Label"], save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png", show=show ) draw_tools.draw_psg_signal( samp_id=samp_id, tho_signal=normalized_tho_signal, abd_signal=normalized_abd_signal, flowp_signal=normalized_flowp_signal, flowt_signal=normalized_flowt_signal, spo2_signal=spo2_data_filt_fill, effort_signal=normalized_effort_signal, rri_signal = rri, fs=target_fs, event_mask=event_mask["SA_Label"], save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png", show=show ) spo2_disable_mask = spo2_disable_mask[::target_fs] min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask)) if len(event_mask["Disable_Label"]) != len(spo2_disable_mask): print(f"Warning: Data length mismatch! Truncating to {min_len}.") event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len] event_list = { "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]), "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])} spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"]) segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) if verbose: print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") # 复制mask到processed_Labels文件夹 save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv" shutil.copyfile(mask_excel_path, save_mask_excel_path) # 复制SA Label_corrected.csv到processed_Labels文件夹 sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv") if sa_label_corrected_path.exists(): save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) else: if verbose: print(f"Warning: {sa_label_corrected_path} does not exist.") # 保存处理后的信号和截取的片段列表 save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz" # psg_data更新为处理后的信号 # 用下划线替换键里面的空格 psg_data = { "Effort Tho": { "name": "Effort_Tho", "data": normalized_tho_signal, "fs": target_fs, "length": len(normalized_tho_signal), "second": len(normalized_tho_signal) / target_fs }, "Effort Abd": { "name": "Effort_Abd", "data": normalized_abd_signal, "fs": target_fs, "length": len(normalized_abd_signal), "second": len(normalized_abd_signal) / target_fs }, "Effort": { "name": "Effort", "data": normalized_effort_signal, "fs": target_fs, "length": len(normalized_effort_signal), "second": len(normalized_effort_signal) / target_fs }, "Flow P": { "name": "Flow_P", "data": normalized_flowp_signal, "fs": target_fs, "length": len(normalized_flowp_signal), "second": len(normalized_flowp_signal) / target_fs }, "Flow T": { "name": "Flow_T", "data": normalized_flowt_signal, "fs": target_fs, "length": len(normalized_flowt_signal), "second": len(normalized_flowt_signal) / target_fs }, "SpO2": { "name": "SpO2", "data": spo2_data_filt_fill, "fs": target_fs, "length": len(spo2_data_filt_fill), "second": len(spo2_data_filt_fill) / target_fs }, "HR": { "name": "HR", "data": psg_data["HR"]["data"], "fs": psg_data["HR"]["fs"], "length": psg_data["HR"]["length"], "second": psg_data["HR"]["second"] }, "RRI": { "name": "RRI", "data": rri, "fs": target_fs, "length": len(rri), "second": len(rri) / target_fs }, "5_class": { "name": "Stage", "data": psg_data["5_class"]["data"], "fs": psg_data["5_class"]["fs"], "length": psg_data["5_class"]["length"], "second": psg_data["5_class"]["second"] } } np.savez_compressed(save_signal_path, **psg_data) np.savez_compressed(save_segment_path, segment_list=segment_list, disable_segment_list=disable_segment_list) if verbose: print(f"Saved processed signals to: {save_signal_path}") print(f"Saved segment list to: {save_segment_path}") if draw_segment: total_len = len(segment_list) + len(disable_segment_list) if verbose: print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}") draw_tools.draw_psg_label( psg_data=psg_data, psg_label=event_mask["SA_Label"], segment_list=segment_list, save_path=visual_path / f"{samp_id}" / "enable", verbose=verbose, multi_p=multi_p, multi_task_id=multi_task_id ) draw_tools.draw_psg_label( psg_data=psg_data, psg_label=event_mask["SA_Label"], segment_list=disable_segment_list, save_path=visual_path / f"{samp_id}" / "disable", verbose=verbose, multi_p=multi_p, multi_task_id=multi_task_id ) # 显式删除大型对象 try: del psg_data del normalized_tho_signal, normalized_abd_signal del normalized_flowp_signal, normalized_flowt_signal del normalized_effort_signal del spo2_data_filt, spo2_data_filt_fill del rri del event_mask, event_list del segment_list, disable_segment_list except: pass # 强制垃圾回收 gc.collect() def multiprocess_entry(_progress, task_id, _id): build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) def _init_pool_worker(yaml_path=DEFAULT_YAML_PATH): # 让主进程统一响应 Ctrl+C,避免父子进程同时处理中断导致退出卡住。 signal.signal(signal.SIGINT, signal.SIG_IGN) initialize_runtime(yaml_path) def _shutdown_executor(executor, wait, cancel_futures=False): if executor is None: return try: executor.shutdown(wait=wait, cancel_futures=cancel_futures) except TypeError: executor.shutdown(wait=wait) def multiprocess_with_tqdm(args_list, n_processes): from concurrent.futures import ProcessPoolExecutor from rich import progress yaml_path = str(DEFAULT_YAML_PATH) with progress.Progress( "[progress.description]{task.description}", progress.BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", progress.MofNCompleteColumn(), progress.TimeRemainingColumn(), progress.TimeElapsedColumn(), refresh_per_second=1, # bit slower updates transient=False ) as progress: futures = [] with multiprocessing.Manager() as manager: _progress = manager.dict() overall_progress_task = progress.add_task("[green]All jobs progress:") executor = ProcessPoolExecutor( max_workers=n_processes, mp_context=multiprocessing.get_context("spawn"), initializer=_init_pool_worker, initargs=(yaml_path,), ) try: for i_args in range(len(args_list)): args = args_list[i_args] task_id = progress.add_task(f"task {i_args}", visible=True) futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args])) # monitor the progress: while (n_finished := sum([future.done() for future in futures])) < len( futures ): progress.update( overall_progress_task, completed=n_finished, total=len(futures) ) for task_id, update_data in _progress.items(): desc = update_data.get("desc", "") # update the progress bar for this task: progress.update( task_id, completed=update_data.get("progress", 0), total=update_data.get("total", 0), description=desc ) time.sleep(0.2) # raise any errors: for future in futures: future.result() except KeyboardInterrupt: print("\nKeyboardInterrupt received, cancelling pending jobs...", flush=True) for future in futures: future.cancel() _shutdown_executor(executor, wait=False, cancel_futures=True) executor = None raise SystemExit(130) finally: _shutdown_executor(executor, wait=True) def multiprocess_with_pool(args_list, n_processes): """使用Pool,每个worker处理固定数量任务后重启""" if not args_list: return ctx = multiprocessing.get_context("spawn") yaml_path = str(DEFAULT_YAML_PATH) pool = ctx.Pool( processes=n_processes, maxtasksperchild=2, initializer=_init_pool_worker, initargs=(yaml_path,), ) pending_results = {} completed = 0 try: for samp_id in args_list: pending_results[samp_id] = pool.apply_async( build_HYS_dataset_segment, args=(samp_id, False, True, False, None, None) ) pool.close() while pending_results: finished_ids = [] for samp_id, result in pending_results.items(): if not result.ready(): continue try: result.get() completed += 1 print(f"Completed {completed}/{len(args_list)}: {samp_id}", flush=True) except Exception as e: completed += 1 print(f"Error processing {samp_id}: {e}", flush=True) finished_ids.append(samp_id) for samp_id in finished_ids: pending_results.pop(samp_id, None) if pending_results: time.sleep(0.5) pool.join() except KeyboardInterrupt: print("\nKeyboardInterrupt received, terminating worker processes...", flush=True) pool.terminate() pool.join() raise SystemExit(130) except Exception: pool.terminate() pool.join() raise if __name__ == '__main__': yaml_path = DEFAULT_YAML_PATH initialize_runtime(yaml_path) print(select_ids) valid_select_ids, skipped_select_ids = filter_valid_psg_samples(select_ids, verbose=True) print(f"Valid PSG samples: {len(valid_select_ids)} / {len(select_ids)}", flush=True) # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) # for samp_id in select_ids: # print(f"Processing sample ID: {samp_id}") # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) # multiprocess_with_tqdm(args_list=select_ids, n_processes=8) # multiprocess_with_pool(args_list=valid_select_ids, n_processes=8)