DataPrepare/dataset_builder/HYS_dataset.py

171 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
from pathlib import Path
import os
import numpy as np
# X display target for GUI plotting backends used by draw_tools.
# NOTE(review): hard-coded X11 forwarding address — confirm this is still
# needed (it will point at a stale display on headless or local runs).
os.environ['DISPLAY'] = "localhost:10.0"
# Make the project root importable so the sibling project modules resolve.
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import signal_method
import draw_tools
import shutil
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
    """Build the processed HYS dataset artifacts for a single sample.

    Loads the synchronized raw BCG signal and its label masks, filters and
    splits the signal into a notch-filtered BCG stream, a respiration stream
    and a heartbeat-band BCG stream, normalizes the respiration signal,
    extracts respiration-based segments, and writes the processed signals,
    the segment list and copies of the label CSVs into the configured
    dataset output directories.

    Relies on module-level state assigned in the ``__main__`` block:
    ``conf``, ``dataset_config``, ``mask_path``, ``org_signal_root_path``,
    ``psg_signal_root_path``, ``visual_path`` and the ``save_*`` paths.

    Args:
        samp_id: Sample identifier used to locate inputs and name outputs.
        show: Kept for interface compatibility; currently unused (the
            interactive drawing call that consumed it has been removed).
        draw_segment: If True, additionally load the aligned PSG channels
            and render a PSG/BCG/label visualization for this sample.

    Raises:
        FileNotFoundError: If no ``OrgBCG_Sync_*.txt`` file exists for
            ``samp_id``.
    """
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    print(f"mask_excel_path: {mask_excel_path}")
    event_mask, event_list = utils.read_mask_execl(mask_excel_path)

    # Read the raw signal, then split it into notch-filtered BCG,
    # respiration, and heartbeat-band BCG streams (each with its own fs).
    bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float)
    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs)
    normalized_resp_signal = signal_method.normalize_resp_signal(
        resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])

    # Downsample any 1 kHz stream to 100 Hz so the stored dataset stays compact.
    if signal_fs == 1000:
        bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
        bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs, target_fs=100)
        signal_fs = 100
    if bcg_fs == 1000:
        bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
        bcg_fs = 100

    segment_list = utils.resp_split(dataset_config, event_mask, event_list)
    print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")

    # Copy the processed label CSV next to the generated dataset.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)

    # Copy the corrected SA label CSV when present (it is optional input).
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        print(f"Warning: {sa_label_corrected_path} does not exist.")

    # Persist the processed signals and the extracted segment list.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
    bcg_data = {
        "bcg_signal_notch": {
            "name": "BCG_Signal_Notch",
            "data": bcg_signal_notch,
            "fs": signal_fs,
            "length": len(bcg_signal_notch),
            # NOTE: floor division truncates a trailing partial second.
            "second": len(bcg_signal_notch) // signal_fs
        },
        "bcg_signal": {
            "name": "BCG_Signal_Raw",
            "data": bcg_signal,
            "fs": bcg_fs,
            "length": len(bcg_signal),
            "second": len(bcg_signal) // bcg_fs
        },
        "resp_signal": {
            "name": "Resp_Signal",
            "data": normalized_resp_signal,
            "fs": resp_fs,
            "length": len(normalized_resp_signal),
            "second": len(normalized_resp_signal) // resp_fs
        }
    }
    np.savez_compressed(save_signal_path, **bcg_data)
    np.savez_compressed(save_segment_path, segment_list=segment_list)
    print(f"Saved processed signals to: {save_signal_path}")
    print(f"Saved segment list to: {save_segment_path}")

    if draw_segment:
        # FIX: this branch previously read sa_label_corrected_path even when
        # the file was missing (only a warning had been printed above), which
        # would crash inside read_psg_label. Skip the visualization instead.
        if not sa_label_corrected_path.exists():
            print(f"Skipping segment drawing for sample ID {samp_id}: "
                  f"{sa_label_corrected_path} does not exist.")
            return
        psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8])
        # Derive an HR channel from the R-peak annotations, aligned to the
        # synchronized ECG channel's sampling grid.
        psg_data["HR"] = {
            "name": "HR",
            "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
            "fs": psg_data["ECG_Sync"]["fs"],
            "length": psg_data["ECG_Sync"]["length"],
            "second": psg_data["ECG_Sync"]["second"]
        }
        psg_label = utils.read_psg_label(sa_label_corrected_path)
        psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
        draw_tools.draw_psg_bcg_label(psg_data=psg_data,
                                      psg_label=psg_event_mask,
                                      bcg_data=bcg_data,
                                      event_mask=event_mask,
                                      segment_list=segment_list,
                                      save_path=visual_path / f"{samp_id}")
if __name__ == '__main__':
    # Load the HYS dataset-build configuration.
    yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
    conf = utils.load_dataset_conf(yaml_path)

    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(conf["dataset_config"]["dataset_save_path"])
    visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
    dataset_config = conf["dataset_config"]

    # Output layout: Signals/, Segments_List/ and Labels/ under save_path,
    # plus the visualization directory.
    save_processed_signal_path = save_path / "Signals"
    save_segment_list_path = save_path / "Segments_List"
    save_processed_label_path = save_path / "Labels"
    for out_dir in (visual_path, save_processed_signal_path,
                    save_segment_list_path, save_processed_label_path):
        out_dir.mkdir(parents=True, exist_ok=True)

    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")
    print(f"visual_path: {visual_path}")

    # Roots of the aligned input recordings.
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"

    # Build every selected sample with the PSG/BCG visualization enabled.
    for samp_id in select_ids:
        print(f"Processing sample ID: {samp_id}")
        build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)