DataPrepare/dataset_builder/HYS_dataset.py
marques 8ee5980906 feat: Add utility functions for signal processing and event mapping
- Created a new module `utils/__init__.py` to consolidate utility imports.
- Added `event_map.py` for mapping apnea event types to numerical values and colors.
- Implemented various filtering functions in `filter_func.py`, including Butterworth, Bessel, downsampling, and notch filters.
- Developed `operation_tools.py` for dataset configuration loading, event mask generation, and signal processing utilities.
- Introduced `split_method.py` for segmenting data based on movement and amplitude criteria.
- Added `statistics_metrics.py` for calculating amplitude metrics and generating confusion matrices.
- Included a new Excel file for additional data storage.
2026-03-24 21:15:05 +08:00

233 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import multiprocessing
import sys
from pathlib import Path
import os
import numpy as np
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import signal_method
import draw_tools
import shutil
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    """Build and save the processed signals, labels and segment list of one sample.

    The function reads the sample's OrgBCG_Sync text file and its processed
    label CSV, filters/splits the signal, normalizes the respiration signal,
    extracts segments with ``utils.resp_split`` and writes:

    * a copy of ``<samp_id>_Processed_Labels.csv`` (and, when present,
      ``<samp_id>_SA Label_corrected.csv``) into ``save_processed_label_path``;
    * ``<samp_id>_Processed_Signals.npz`` into ``save_processed_signal_path``;
    * ``<samp_id>_Segment_List.npz`` into ``save_segment_list_path``.

    It depends on module-level names that this file assigns only inside its
    ``if __name__ == '__main__':`` block: ``conf``, ``dataset_config``,
    ``mask_path``, ``org_signal_root_path``, ``psg_signal_root_path``,
    ``save_processed_signal_path``, ``save_segment_list_path`` and
    ``save_processed_label_path``. They must exist before the call.
    NOTE(review): pool workers only see these names if they inherit the
    parent's globals (e.g. the "fork" start method) - confirm before running
    on a platform whose default start method is "spawn".

    Args:
        samp_id: Sample identifier used to build directory and file names.
        show: Currently unused; only the disabled drawing call refers to it.
        draw_segment: When true, additionally load the PSG channels and PSG
            labels needed for segment drawing (the drawing calls themselves
            are currently commented out).
        verbose: When true, print progress information.
        multi_p: Currently unused; only the disabled drawing calls refer to it.
        multi_task_id: Currently unused; only the disabled drawing calls
            refer to it.

    Raises:
        FileNotFoundError: If no ``OrgBCG_Sync_*.txt`` file exists in the
            sample's directory.
    """
    # Locate the synchronized BCG recording; the first glob match is used.
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    if verbose:
        print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")

    event_mask, event_list = utils.read_mask_execl(mask_excel_path)

    # Only the signal and its sampling rate are needed from the reader's result.
    bcg_signal_raw, _, signal_fs, _ = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)

    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(
        conf, bcg_signal_raw, signal_fs, verbose=verbose)
    normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(
        resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])

    # Downsample when the sampling rate is too high: signals at a rate of
    # exactly 1000 are reduced to 100, any other rate is left untouched.
    # The raw signal is not downsampled here because nothing reads it after
    # the filter/split step above.
    if signal_fs == 1000:
        bcg_signal_notch = utils.downsample_signal_fast(
            original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
        signal_fs = 100

    if bcg_fs == 1000:
        bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
        bcg_fs = 100

    # Disabled overview plot, kept for reference:
    # draw_tools.draw_signal_with_mask(samp_id=samp_id,
    #                                  signal_data=resp_signal,
    #                                  signal_fs=resp_fs,
    #                                  resp_data=normalized_resp_signal,
    #                                  resp_fs=resp_fs,
    #                                  bcg_data=bcg_signal,
    #                                  bcg_fs=bcg_fs,
    #                                  signal_disable_mask=event_mask["Disable_Label"],
    #                                  resp_low_amp_mask=event_mask["Resp_LowAmp_Label"],
    #                                  resp_movement_mask=event_mask["Resp_Movement_Label"],
    #                                  resp_change_mask=event_mask["Resp_AmpChange_Label"],
    #                                  resp_sa_mask=event_mask["SA_Label"],
    #                                  bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"],
    #                                  bcg_movement_mask=event_mask["BCG_Movement_Label"],
    #                                  bcg_change_mask=event_mask["BCG_AmpChange_Label"],
    #                                  show=show,
    #                                  save_path=None)

    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")

    # Copy the mask file into the processed-labels folder.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)

    # Copy "SA Label_corrected.csv" into the processed-labels folder when it exists.
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
    sa_label_available = sa_label_corrected_path.exists()
    if sa_label_available:
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")

    # Save the processed signals and the extracted segment lists.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"

    bcg_data = {
        "bcg_signal_notch": {
            "name": "BCG_Signal_Notch",
            "data": bcg_signal_notch,
            "fs": signal_fs,
            "length": len(bcg_signal_notch),
            "second": len(bcg_signal_notch) // signal_fs
        },
        "bcg_signal": {
            # NOTE(review): the name says "Raw" but the data is the BCG output
            # of signal_filter_split, not bcg_signal_raw - confirm which one
            # downstream readers expect before renaming.
            "name": "BCG_Signal_Raw",
            "data": bcg_signal,
            "fs": bcg_fs,
            "length": len(bcg_signal),
            "second": len(bcg_signal) // bcg_fs
        },
        "resp_signal": {
            "name": "Resp_Signal",
            "data": normalized_resp_signal,
            "fs": resp_fs,
            "length": len(normalized_resp_signal),
            "second": len(normalized_resp_signal) // resp_fs
        }
    }

    # Every value is a dict, so NumPy stores each entry as a pickled object
    # array; reading the file back requires np.load(..., allow_pickle=True).
    np.savez_compressed(save_signal_path, **bcg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")

    if not draw_segment:
        return

    if not sa_label_available:
        # The PSG labels are read from this file; do not hand the reader a
        # path that is already known to be missing. All outputs above have
        # been written, so only the drawing preparation is skipped.
        print(f"Warning: skip drawing for sample ID {samp_id}, {sa_label_corrected_path} does not exist.")
        return

    psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
    # Derive a heart-rate channel from the R peaks, described with the
    # metadata (fs / length / second) of the ECG_Sync channel.
    psg_data["HR"] = {
        "name": "HR",
        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
                                       psg_data["Rpeak"]["fs"]),
        "fs": psg_data["ECG_Sync"]["fs"],
        "length": psg_data["ECG_Sync"]["length"],
        "second": psg_data["ECG_Sync"]["second"]
    }

    psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
    psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"],
                                                  use_correct=False)
    total_len = len(segment_list) + len(disable_segment_list)
    if verbose:
        print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")

    # Disabled segment drawing, kept for reference:
    # draw_tools.draw_psg_bcg_label(psg_data=psg_data,
    #                               psg_label=psg_event_mask,
    #                               bcg_data=bcg_data,
    #                               event_mask=event_mask,
    #                               segment_list=segment_list,
    #                               save_path=visual_path / f"{samp_id}" / "enable",
    #                               verbose=verbose,
    #                               multi_p=multi_p,
    #                               multi_task_id=multi_task_id
    #                               )
    #
    # draw_tools.draw_psg_bcg_label(
    #     psg_data=psg_data,
    #     psg_label=psg_event_mask,
    #     bcg_data=bcg_data,
    #     event_mask=event_mask,
    #     segment_list=disable_segment_list,
    #     save_path=visual_path / f"{samp_id}" / "disable",
    #     verbose=verbose,
    #     multi_p=multi_p,
    #     multi_task_id=multi_task_id
    # )
def multiprocess_entry(_progress, task_id, _id):
    """Process one sample with segment drawing on and console output off.

    ``_id`` is forwarded as the sample ID; ``_progress`` and ``task_id`` are
    forwarded as ``multi_p`` and ``multi_task_id``.
    """
    call_kwargs = {
        "samp_id": _id,
        "show": False,
        "draw_segment": True,
        "verbose": False,
        "multi_p": _progress,
        "multi_task_id": task_id,
    }
    build_HYS_dataset_segment(**call_kwargs)
def multiprocess_with_pool(args_list, n_processes):
    """Process every sample ID in a worker pool and report the outcome of each.

    Each worker process is replaced after two tasks (``maxtasksperchild=2``)
    so that the memory it accumulated is released.

    Args:
        args_list: Iterable of sample IDs. Each one is passed to
            ``build_HYS_dataset_segment`` with ``draw_segment=True`` and
            ``verbose=False``.
        n_processes: Number of worker processes in the pool.

    Returns:
        list: The sample IDs whose task raised an exception, in submission
        order; empty when every task succeeded.
    """
    from multiprocessing import Pool

    # Materialize once so generators work and the total is known up front.
    sample_ids = list(args_list)
    total = len(sample_ids)
    failed_ids = []

    # maxtasksperchild: number of tasks a worker handles before it is restarted.
    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        pending = [
            (samp_id, pool.apply_async(build_HYS_dataset_segment,
                                       args=(samp_id, False, True, False, None, None)))
            for samp_id in sample_ids
        ]

        # Wait for all tasks; one failing sample must not stop the others.
        for position, (samp_id, result) in enumerate(pending, start=1):
            try:
                result.get()
            except Exception as e:
                failed_ids.append(samp_id)
                print(f"Error processing task {position}/{total} (sample ID {samp_id}): {e}")
            else:
                print(f"Completed {position}/{total}")

        pool.close()
        pool.join()

    return failed_ids
if __name__ == '__main__':
    # The names assigned here are module-level globals that
    # build_HYS_dataset_segment reads directly, so they keep their exact names.
    yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
    conf = utils.load_dataset_conf(yaml_path)

    # Input locations and sample selection.
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"

    # Output locations.
    dataset_config = conf["dataset_config"]
    save_path = Path(dataset_config["dataset_save_path"])
    visual_path = Path(dataset_config["dataset_visual_path"])
    save_processed_signal_path = save_path / "Signals"
    save_segment_list_path = save_path / "Segments_List"
    save_processed_label_path = save_path / "Labels"

    for out_dir in (visual_path,
                    save_processed_signal_path,
                    save_segment_list_path,
                    save_processed_label_path):
        out_dir.mkdir(parents=True, exist_ok=True)

    print(select_ids)

    # To debug a single sample serially, call e.g.
    #     build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
    # or loop over select_ids instead of using the pool below.
    multiprocess_with_pool(args_list=select_ids, n_processes=16)