diff --git a/README.md b/README.md
index a83c153..24e92a3 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,5 @@
 # DataPrepare
-
+## Workflow
+1. Signal preprocessing
+2. Dataset construction
+3. Data visualization (optional)
{save_segment_path}") if draw_segment: - psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8]) + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) psg_data["HR"] = { "name": "HR", "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]), @@ -124,14 +130,86 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): } - psg_label = utils.read_psg_label(sa_label_corrected_path) + psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose) psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False) + total_len = len(segment_list) + len(disable_segment_list) + if verbose: + print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}") + draw_tools.draw_psg_bcg_label(psg_data=psg_data, psg_label=psg_event_mask, bcg_data=bcg_data, event_mask=event_mask, segment_list=segment_list, - save_path=visual_path / f"{samp_id}") + save_path=visual_path / f"{samp_id}" / "enable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + draw_tools.draw_psg_bcg_label( + psg_data=psg_data, + psg_label=psg_event_mask, + bcg_data=bcg_data, + event_mask=event_mask, + segment_list=disable_segment_list, + save_path=visual_path / f"{samp_id}" / "disable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + + +def multiprocess_entry(_progress, task_id, _id): + build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) + + +def multiprocess_with_tqdm(args_list, n_processes): + from concurrent.futures import ProcessPoolExecutor + from rich import progress + + with progress.Progress( + "[progress.description]{task.description}", + progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + progress.MofNCompleteColumn(), + progress.TimeRemainingColumn(), + progress.TimeElapsedColumn(), + refresh_per_second=1, # bit slower updates + transient=False + ) as progress: + futures = [] + with multiprocessing.Manager() as manager: + _progress = manager.dict() + overall_progress_task = progress.add_task("[green]All jobs progress:") + with ProcessPoolExecutor(max_workers=n_processes) as executor: + for i_args in range(len(args_list)): + args = args_list[i_args] + task_id = progress.add_task(f"task {i_args}", visible=True) + futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args])) + # monitor the progress: + while (n_finished := sum([future.done() for future in futures])) < len( + futures + ): + progress.update( + overall_progress_task, completed=n_finished, total=len(futures) + ) + for task_id, update_data in _progress.items(): + desc = update_data.get("desc", "") + # update the progress bar for this task: + progress.update( + task_id, + completed=update_data.get("progress", 0), + total=update_data.get("total", 0), + description=desc + ) + + # raise any errors: + for future in futures: + future.result() + + if __name__ == '__main__': @@ -156,16 +234,18 @@ if __name__ == '__main__': save_processed_label_path = save_path / "Labels" save_processed_label_path.mkdir(parents=True, exist_ok=True) - print(f"select_ids: {select_ids}") - print(f"root_path: {root_path}") - print(f"save_path: {save_path}") - print(f"visual_path: {visual_path}") + # print(f"select_ids: {select_ids}") + # print(f"root_path: 
{root_path}") + # print(f"save_path: {save_path}") + # print(f"visual_path: {visual_path}") org_signal_root_path = root_path / "OrgBCG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned" - # build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) + build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) - for samp_id in select_ids: - print(f"Processing sample ID: {samp_id}") - build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) \ No newline at end of file + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) + + # multiprocess_with_tqdm(args_list=select_ids, n_processes=16) \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 80b8165..43544ea 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -19,8 +19,8 @@ resp: resp_filter: filter_type: bandpass - low_cut: 0.01 - high_cut: 0.7 + low_cut: 0.05 + high_cut: 0.6 order: 3 resp_low_amp: diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py index 6b0dcb9..b93cc6c 100644 --- a/draw_tools/draw_label.py +++ b/draw_tools/draw_label.py @@ -72,7 +72,6 @@ def create_psg_bcg_figure(): axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white") axes.append(axes[psg_chn_name2ax["bcg"]].twinx()) - axes[psg_chn_name2ax["Stage"]].grid(True) # axes[7].xaxis.set_major_formatter(Params.FORMATTER) @@ -183,7 +182,8 @@ def score_mask2alpha(score_mask): return alpha_mask -def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None): +def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True, + multi_p=None, multi_task_id=None): if save_path is not None: save_path.mkdir(parents=True, exist_ok=True) @@ -191,13 +191,13 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, if mask.startswith("Resp_") or mask.startswith("BCG_"): event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) - event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0) + event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0) # event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) fig, axes = create_psg_bcg_figure() - for segment_start, segment_end in tqdm(segment_list): + for i, (segment_start, segment_end) in enumerate(segment_list): for ax in axes: ax.cla() @@ -213,17 +213,21 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, psg_label, event_codes=[1, 2, 3, 4]) plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, - event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]]) + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], + ax2=axes[psg_chn_name2ax["resp_twinx"]]) plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, - event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]]) - + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], + ax2=axes[psg_chn_name2ax["bcg_twinx"]]) if save_path is not None: fig.savefig(save_path / 
f"Segment_{segment_start}_{segment_end}.png") - tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + if multi_p is not None: + multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"} + + def draw_resp_label(resp_data, resp_label, segment_list): for mask in resp_label.keys(): if mask.startswith("Resp_"): diff --git a/event_mask_process/HYS_process.py b/event_mask_process/HYS_process.py index dc33461..45fa76d 100644 --- a/event_mask_process/HYS_process.py +++ b/event_mask_process/HYS_process.py @@ -52,7 +52,7 @@ def process_one_signal(samp_id, show=False): save_samp_path = save_path / f"{samp_id}" save_samp_path.mkdir(parents=True, exist_ok=True) - signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True) + signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True) signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs) diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py index eaaea59..c0c6699 100644 --- a/signal_method/signal_process.py +++ b/signal_method/signal_process.py @@ -2,11 +2,12 @@ import numpy as np import utils -def signal_filter_split(conf, signal_data_raw, signal_fs): +def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True): # 滤波 # 50Hz陷波滤波器 # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - print("Applying 50Hz notch filter...") + if verbose: + print("Applying 50Hz notch filter...") signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) @@ -17,7 +18,8 @@ def signal_filter_split(conf, signal_data_raw, signal_fs): low_cut=conf["resp_filter"]["low_cut"], high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], sample_rate=resp_fs) - print("Begin plotting signal data...") + if verbose: + print("Begin plotting signal data...") # fig = plt.figure(figsize=(12, 8)) # # 绘制三个图raw_data、resp_data_1、resp_data_2 diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index 6f7d95a..f41d16a 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -20,6 +20,7 @@ def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False): Read a txt file and return the first column as a numpy array. 
diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py
index 6f7d95a..f41d16a 100644
--- a/utils/HYS_FileReader.py
+++ b/utils/HYS_FileReader.py
@@ -20,6 +20,7 @@ def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
     Read a txt file and return the first column as a numpy array.
 
     Args:
+    :param is_peak:
     :param path:
     :param verbose:
     :param dtype:
@@ -217,14 +218,15 @@ def read_mask_execl(path: Union[str, Path]):
 
     event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
                   "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
-                  "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),}
+                  "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
+                  "DisableSegment": event_mask_2_list(event_mask["Disable_Label"])}
 
     return event_mask, event_list
 
 
-def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
+def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
     """
     Read the data of the specified channels from the PSG files.
 
@@ -254,16 +256,16 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
 
         if ch_id == 8:
             # Special case for sleep stage: read as integers
-            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True)
+            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
             # Convert to an integer array
             for stage_str, stage_number in utils.Stage2N.items():
                 np.place(ch_signal, ch_signal == stage_str, stage_number)
             ch_signal = ch_signal.astype(int)
         elif ch_id == 1:
             # Special case for Rpeak: read as integers
-            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True)
+            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
         else:
-            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True)
+            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)
         channel_data[ch_name] = {
             "name": ch_name,
             "path": ch_path[0],
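# Illustration (not part of the patch): how the new "DisableSegment" entry
# relates to "EnableSegment". mask_to_segments is a hypothetical stand-in for
# event_mask_2_list (whose implementation is not shown here): it turns a binary
# per-second mask into (start, end) intervals.
import numpy as np


def mask_to_segments(mask):
    # (start, end) pairs covering the runs where mask == 1
    edges = np.flatnonzero(np.diff(np.concatenate(([0], mask, [0]))))
    return list(zip(edges[::2], edges[1::2]))


disable_label = np.array([0, 0, 1, 1, 1, 0, 0, 1, 0, 0])
enable_segments = mask_to_segments(1 - disable_label)   # [(0, 2), (5, 7), (8, 10)]
disable_segments = mask_to_segments(disable_label)      # [(2, 5), (7, 8)]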
diff --git a/utils/split_method.py b/utils/split_method.py
index e9c151b..f113013 100644
--- a/utils/split_method.py
+++ b/utils/split_method.py
@@ -1,27 +1,58 @@
+def check_split(event_mask, current_start, window_sec, verbose=False):
+    # Check whether the current window falls inside a disabled (body-movement) or low-amplitude interval
+    resp_movement_mask = event_mask["Resp_Movement_Label"][current_start: current_start + window_sec]
+    resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start: current_start + window_sec]
+
+    # OR the body-movement mask with the low-amplitude mask
+    low_move_mask = resp_movement_mask | resp_low_amp_mask
+    if low_move_mask.sum() > 2/3 * window_sec:
+        if verbose:
+            print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3")
+        return False
+    return True
+
 
-
-def resp_split(dataset_config, event_mask, event_list):
+def resp_split(dataset_config, event_mask, event_list, verbose=False):
     # Extract the body-movement intervals and low-amplitude respiration intervals
     enable_list = event_list["EnableSegment"]
+    disable_list = event_list["DisableSegment"]
 
     # Read the dataset configuration
     window_sec = dataset_config["window_sec"]
     stride_sec = dataset_config["stride_sec"]
 
     segment_list = []
+    disable_segment_list = []
     # Iterate over each enabled interval; if the last window is shorter than half a stride, drop it, otherwise take one final window ending at enable_end
     for enable_start, enable_end in enable_list:
         current_start = enable_start
         while current_start + window_sec <= enable_end:
-            segment_list.append((current_start, current_start + window_sec))
+            if check_split(event_mask, current_start, window_sec, verbose):
+                segment_list.append((current_start, current_start + window_sec))
+            else:
+                disable_segment_list.append((current_start, current_start + window_sec))
             current_start += stride_sec
         # Check whether the final window should be added
         if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec):
-            segment_list.append((enable_end - window_sec, enable_end))
+            if check_split(event_mask, enable_end - window_sec, window_sec, verbose):
+                segment_list.append((enable_end - window_sec, enable_end))
+            else:
+                disable_segment_list.append((enable_end - window_sec, enable_end))
 
-    return segment_list
+    # Iterate over each disabled interval; if the last window is shorter than half a stride, drop it, otherwise take one final window ending at disable_end
+    for disable_start, disable_end in disable_list:
+        current_start = disable_start
+        while current_start + window_sec <= disable_end:
+            disable_segment_list.append((current_start, current_start + window_sec))
+            current_start += stride_sec
+        # Check whether the final window should be added
+        if (disable_end - current_start >= stride_sec / 2) and (disable_end - current_start >= window_sec):
+            disable_segment_list.append((disable_end - window_sec, disable_end))
+
+
+    return segment_list, disable_segment_list
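# Illustration (not part of the patch): the acceptance rule that check_split
# applies to each candidate window, reproduced here as a standalone function on
# toy per-second masks (window_sec and the labels below are made up). A window
# is rejected, and routed into disable_segment_list instead of segment_list,
# when body movement OR low respiration amplitude covers more than 2/3 of it.
import numpy as np

window_sec = 30
event_mask = {
    "Resp_Movement_Label": np.zeros(120, dtype=int),
    "Resp_LowAmp_Label": np.zeros(120, dtype=int),
}
event_mask["Resp_Movement_Label"][30:55] = 1   # 25 s of body movement


def window_ok(event_mask, start, window_sec):
    # Same rule as check_split: keep the window only if the combined mask
    # covers at most 2/3 of it.
    move = event_mask["Resp_Movement_Label"][start:start + window_sec]
    low = event_mask["Resp_LowAmp_Label"][start:start + window_sec]
    return (move | low).sum() <= 2 / 3 * window_sec


print(window_ok(event_mask, 0, window_sec))    # True  -> segment_list
print(window_ok(event_mask, 30, window_sec))   # False (25/30 > 2/3) -> disable_segment_list
print(window_ok(event_mask, 60, window_sec))   # True  -> segment_list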