"""Build the HYS BCG dataset: filter, segment, and export raw recordings.

For each sample ID this script reads the synchronized raw BCG text file,
splits it into notch-filtered / respiration / BCG components, downsamples
1 kHz signals to 100 Hz, copies the label CSVs alongside the outputs, and
saves the processed signals and extracted segment lists as ``.npz``
archives.  With ``draw_segment=True`` it also renders per-segment PSG/BCG
comparison plots.

NOTE(review): the worker functions read several module-level names
(``conf``, ``mask_path``, ``org_signal_root_path``, the ``save_*`` paths,
``visual_path``) that are only bound inside the ``__main__`` block, so
they cannot be imported and called from another module without first
setting those globals.
"""
import multiprocessing
import os
import shutil
import sys
import time
from pathlib import Path

import numpy as np

# Point GUI plotting back-ends at a forwarded X11 display.
os.environ['DISPLAY'] = "localhost:10.0"

# Make the project root importable before pulling in project modules.
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

import utils
import signal_method
import draw_tools


def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    """Process a single recording identified by ``samp_id``.

    Reads the ``OrgBCG_Sync_*.txt`` signal and its processed-label CSV,
    filters and (if needed) downsamples the signals, extracts the
    enable/disable segment lists, and writes three outputs under the
    dataset save path: the copied label CSVs, a ``*_Processed_Signals.npz``
    archive, and a ``*_Segment_List.npz`` archive.

    Parameters
    ----------
    samp_id : sample identifier used to locate input/output files.
    show : interactive-display flag (currently unused by the active code
        path; kept for interface compatibility).
    draw_segment : when True, additionally load the PSG channels and draw
        one figure per extracted segment.
    verbose : print progress information.
    multi_p, multi_task_id : shared dict and rich task id forwarded to
        ``draw_tools`` so workers can report drawing progress.

    Raises
    ------
    FileNotFoundError : if no ``OrgBCG_Sync_*.txt`` exists for ``samp_id``.
    """
    matches = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not matches:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = matches[0]
    if verbose:
        print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")
    event_mask, event_list = utils.read_mask_execl(mask_excel_path)

    bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(
        signal_path, dtype=float, verbose=verbose)
    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(
        conf, bcg_signal_raw, signal_fs, verbose=verbose)
    normalized_resp_signal = signal_method.normalize_resp_signal(
        resp_signal, resp_fs,
        event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])

    # If the sampling rate is too high (1 kHz), downsample to 100 Hz to
    # keep the exported dataset compact.
    if signal_fs == 1000:
        bcg_signal_notch = utils.downsample_signal_fast(
            original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
        bcg_signal_raw = utils.downsample_signal_fast(
            original_signal=bcg_signal_raw, original_fs=signal_fs, target_fs=100)
        signal_fs = 100
    if bcg_fs == 1000:
        bcg_signal = utils.downsample_signal_fast(
            original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
        bcg_fs = 100

    # (optional) a full-night overview plot can be rendered here via
    # draw_tools.draw_signal_with_mask(...) using the event_mask labels.

    segment_list, disable_segment_list = utils.resp_split(
        dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")

    # Copy the processed-label CSV into the dataset's Labels folder.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)

    # Copy the corrected SA label CSV as well, when it exists.
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    elif verbose:
        print(f"Warning: {sa_label_corrected_path} does not exist.")

    # Persist the processed signals and the extracted segment lists.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
    bcg_data = {
        "bcg_signal_notch": {
            "name": "BCG_Signal_Notch",
            "data": bcg_signal_notch,
            "fs": signal_fs,
            "length": len(bcg_signal_notch),
            "second": len(bcg_signal_notch) // signal_fs
        },
        # NOTE(review): the "name" field says Raw, but the stored array is
        # the band-filtered bcg_signal from signal_filter_split — confirm
        # this is intended.
        "bcg_signal": {
            "name": "BCG_Signal_Raw",
            "data": bcg_signal,
            "fs": bcg_fs,
            "length": len(bcg_signal),
            "second": len(bcg_signal) // bcg_fs
        },
        "resp_signal": {
            "name": "Resp_Signal",
            "data": normalized_resp_signal,
            "fs": resp_fs,
            "length": len(normalized_resp_signal),
            "second": len(normalized_resp_signal) // resp_fs
        }
    }
    np.savez_compressed(save_signal_path, **bcg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")

    if draw_segment:
        psg_data = utils.read_psg_channel(
            psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
        # Derive a heart-rate channel from the R-peak annotations, aligned
        # to the synchronized ECG channel's timing metadata.
        psg_data["HR"] = {
            "name": "HR",
            "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
            "fs": psg_data["ECG_Sync"]["fs"],
            "length": psg_data["ECG_Sync"]["length"],
            "second": psg_data["ECG_Sync"]["second"]
        }
        # NOTE(review): this raises if the corrected SA label CSV is
        # missing — its absence only produced a warning above. Confirm
        # whether drawing should be skipped instead in that case.
        psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
        psg_event_mask, _ = utils.generate_event_mask(
            event_df=psg_label,
            signal_second=psg_data["ECG_Sync"]["second"],
            use_correct=False)
        total_len = len(segment_list) + len(disable_segment_list)
        if verbose:
            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
        # Draw the usable ("enable") segments, then the rejected ones.
        for segments, subdir in ((segment_list, "enable"), (disable_segment_list, "disable")):
            draw_tools.draw_psg_bcg_label(psg_data=psg_data,
                                          psg_label=psg_event_mask,
                                          bcg_data=bcg_data,
                                          event_mask=event_mask,
                                          segment_list=segments,
                                          save_path=visual_path / f"{samp_id}" / subdir,
                                          verbose=verbose,
                                          multi_p=multi_p,
                                          multi_task_id=multi_task_id)


def multiprocess_entry(_progress, task_id, _id):
    """Worker entry point: process one sample, reporting progress via the
    shared Manager dict keyed by this worker's rich task id."""
    build_HYS_dataset_segment(samp_id=_id,
                              show=False,
                              draw_segment=True,
                              verbose=False,
                              multi_p=_progress,
                              multi_task_id=task_id)


def multiprocess_with_tqdm(args_list, n_processes):
    """Process all sample IDs in parallel with a rich progress display.

    Spawns ``n_processes`` workers.  Each worker writes its progress into
    a Manager dict keyed by its rich task id (presumably via the
    ``multi_p`` hook inside ``draw_tools`` — confirm), which the parent
    polls to update one bar per sample plus an overall bar.  Worker
    exceptions are re-raised in the parent via ``future.result()``.
    """
    from concurrent.futures import ProcessPoolExecutor
    from rich.progress import (Progress, BarColumn, MofNCompleteColumn,
                               TimeRemainingColumn, TimeElapsedColumn)

    with Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn(),
            refresh_per_second=1,  # slower refresh keeps terminal output light
            transient=False
    ) as progress_ui:
        futures = []
        with multiprocessing.Manager() as manager:
            shared_progress = manager.dict()
            overall_task = progress_ui.add_task("[green]All jobs progress:")
            with ProcessPoolExecutor(max_workers=n_processes) as executor:
                for i, samp_id in enumerate(args_list):
                    worker_task = progress_ui.add_task(f"task {i}", visible=True)
                    futures.append(executor.submit(multiprocess_entry, shared_progress, worker_task, samp_id))
                # Poll until every future is done, mirroring the workers'
                # reported progress into the display.  The short sleep
                # avoids burning a full CPU core on this poll loop.
                while (n_finished := sum(future.done() for future in futures)) < len(futures):
                    progress_ui.update(overall_task, completed=n_finished, total=len(futures))
                    for worker_task_id, update_data in shared_progress.items():
                        progress_ui.update(worker_task_id,
                                           completed=update_data.get("progress", 0),
                                           total=update_data.get("total", 0),
                                           description=update_data.get("desc", ""))
                    time.sleep(0.2)
                # Surface any worker exception in the parent process.
                for future in futures:
                    future.result()


if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(conf["dataset_config"]["dataset_save_path"])
    visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
    dataset_config = conf["dataset_config"]

    # Create the output directory tree.
    visual_path.mkdir(parents=True, exist_ok=True)
    save_processed_signal_path = save_path / "Signals"
    save_processed_signal_path.mkdir(parents=True, exist_ok=True)
    save_segment_list_path = save_path / "Segments_List"
    save_segment_list_path.mkdir(parents=True, exist_ok=True)
    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"

    # Single-sample run (debug).  Alternative run modes:
    #   for samp_id in select_ids:
    #       print(f"Processing sample ID: {samp_id}")
    #       build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
    #   multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
    build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)