# DataPrepare/dataset_builder/HYS_PSG_dataset.py
import multiprocessing
import sys
import os
import shutil
import gc
from pathlib import Path

import numpy as np

# Assumption: DISPLAY points at an X11 forward so matplotlib can render when run over SSH.
os.environ['DISPLAY'] = "localhost:10.0"

# Make the project-level modules importable before the local imports below.
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

from utils import N2Chn
import utils
import signal_method
import draw_tools
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    """Build the processed dataset for a single PSG recording.

    Reads the aligned PSG channels for ``samp_id``, filters and normalizes the
    respiratory signals, derives HR/RRI from the R-peaks, resamples everything
    to 100 Hz, applies the event masks, splits the recording into segments and
    saves the processed signals, segment lists and label files to disk.
    """
    psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
    # Trim every channel (except the R-peak indices) to the shortest common duration.
    total_seconds = min(
        psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
    )
    for i in N2Chn.values():
        if i == "Rpeak":
            continue
        length = int(total_seconds * psg_data[i]["fs"])
        psg_data[i]["data"] = psg_data[i]["data"][:length]
        psg_data[i]["length"] = length
        psg_data[i]["second"] = total_seconds
    # Derive a heart-rate channel from the R-peaks, sampled like the synchronized ECG.
    psg_data["HR"] = {
        "name": "HR",
        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
                                       psg_data["Rpeak"]["fs"]),
        "fs": psg_data["ECG_Sync"]["fs"],
        "length": psg_data["ECG_Sync"]["length"],
        "second": psg_data["ECG_Sync"]["second"]
    }
    # Preprocessing and filtering
    tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
    abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
    flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
    flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])
    rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")
    event_mask, event_list = utils.read_mask_execl(mask_excel_path)
    # Normalize each respiratory signal over the whole recording (a single enable segment).
    enable_list = [[0, psg_data["Effort Tho"]["second"]]]
    normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
    normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
    normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
    normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)
    # Resample everything to a common 100 Hz rate.
    target_fs = 100
    normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
    normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
    normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
    normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
    spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
    # Combined effort channel as the mean of thoracic and abdominal effort.
    normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
    rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)
    # Truncate all signals to the same length.
    min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal),
                     len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal),
                     len(rri))
    min_length = min_length - min_length % target_fs  # keep an integer number of seconds
    normalized_tho_signal = normalized_tho_signal[:min_length]
    normalized_abd_signal = normalized_abd_signal[:min_length]
    normalized_flowp_signal = normalized_flowp_signal[:min_length]
    normalized_flowt_signal = normalized_flowt_signal[:min_length]
    spo2_data_filt = spo2_data_filt[:min_length]
    normalized_effort_signal = normalized_effort_signal[:min_length]
    rri = rri[:min_length]
    tho_second = min_length / target_fs
    # Trim the per-second event masks to the same duration.
    for i in event_mask.keys():
        event_mask[i] = event_mask[i][:int(tho_second)]
    # Fill SpO2 anomalies where possible; spo2_disable_mask flags the regions that remain unusable.
    spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
                                                                     spo2_fs=target_fs,
                                                                     max_fill_duration=30,
                                                                     min_gap_duration=10)
    # Overview plot with the SpO2 trace before anomaly filling.
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
        show=show
    )
    # Same overview plot with the anomaly-filled SpO2 trace.
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt_fill,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
        show=show
    )
    # Downsample the SpO2 disable mask to one value per second and merge it with the label mask.
    spo2_disable_mask = spo2_disable_mask[::target_fs]
    min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
    if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
        print(f"Warning: Data length mismatch! Truncating to {min_len}.")
    event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]
    event_list = {
        "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
        "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
    # Replace any remaining NaNs in the filled SpO2 trace with a nominal 95%.
    spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95)
    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
    # Copy the processed mask CSV into the dataset Labels folder.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)
    # Copy the synchronized SA label CSV into the Labels folder as *_SA Label_corrected.csv.
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")
    # Save the processed signals and the extracted segment lists.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
    # Rebuild psg_data with the processed signals; the "name" fields use underscores
    # in place of spaces.
    psg_data = {
        "Effort Tho": {
            "name": "Effort_Tho",
            "data": normalized_tho_signal,
            "fs": target_fs,
            "length": len(normalized_tho_signal),
            "second": len(normalized_tho_signal) / target_fs
        },
        "Effort Abd": {
            "name": "Effort_Abd",
            "data": normalized_abd_signal,
            "fs": target_fs,
            "length": len(normalized_abd_signal),
            "second": len(normalized_abd_signal) / target_fs
        },
        "Effort": {
            "name": "Effort",
            "data": normalized_effort_signal,
            "fs": target_fs,
            "length": len(normalized_effort_signal),
            "second": len(normalized_effort_signal) / target_fs
        },
        "Flow P": {
            "name": "Flow_P",
            "data": normalized_flowp_signal,
            "fs": target_fs,
            "length": len(normalized_flowp_signal),
            "second": len(normalized_flowp_signal) / target_fs
        },
        "Flow T": {
            "name": "Flow_T",
            "data": normalized_flowt_signal,
            "fs": target_fs,
            "length": len(normalized_flowt_signal),
            "second": len(normalized_flowt_signal) / target_fs
        },
        "SpO2": {
            "name": "SpO2",
            "data": spo2_data_filt_fill,
            "fs": target_fs,
            "length": len(spo2_data_filt_fill),
            "second": len(spo2_data_filt_fill) / target_fs
        },
        "HR": {
            "name": "HR",
            "data": psg_data["HR"]["data"],
            "fs": psg_data["HR"]["fs"],
            "length": psg_data["HR"]["length"],
            "second": psg_data["HR"]["second"]
        },
        "RRI": {
            "name": "RRI",
            "data": rri,
            "fs": target_fs,
            "length": len(rri),
            "second": len(rri) / target_fs
        },
        "5_class": {
            "name": "Stage",
            "data": psg_data["5_class"]["data"],
            "fs": psg_data["5_class"]["fs"],
            "length": psg_data["5_class"]["length"],
            "second": psg_data["5_class"]["second"]
        }
    }
    np.savez_compressed(save_signal_path, **psg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")
    if draw_segment:
        total_len = len(segment_list) + len(disable_segment_list)
        if verbose:
            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=segment_list,
            save_path=visual_path / f"{samp_id}" / "enable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )
        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=disable_segment_list,
            save_path=visual_path / f"{samp_id}" / "disable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )
    # Explicitly drop references to the large arrays before returning.
    try:
        del psg_data
        del normalized_tho_signal, normalized_abd_signal
        del normalized_flowp_signal, normalized_flowt_signal
        del normalized_effort_signal
        del spo2_data_filt, spo2_data_filt_fill
        del rri
        del event_mask, event_list
        del segment_list, disable_segment_list
    except NameError:
        # A name can only be missing if an earlier step failed; nothing to clean up then.
        pass
    # Force a garbage-collection pass to release memory in long-running workers.
    gc.collect()
def multiprocess_entry(_progress, task_id, _id):
    """Worker entry point used by multiprocess_with_tqdm; forwards the shared progress dict."""
    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def multiprocess_with_tqdm(args_list, n_processes):
    """Run one build job per sample ID in a ProcessPoolExecutor while showing rich progress bars."""
    from concurrent.futures import ProcessPoolExecutor
    from rich import progress
    with progress.Progress(
        "[progress.description]{task.description}",
        progress.BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        progress.MofNCompleteColumn(),
        progress.TimeRemainingColumn(),
        progress.TimeElapsedColumn(),
        refresh_per_second=1,  # slightly slower updates
        transient=False
    ) as progress:
        futures = []
        with multiprocessing.Manager() as manager:
            # Shared dict the workers write their per-task progress into.
            _progress = manager.dict()
            overall_progress_task = progress.add_task("[green]All jobs progress:")
            with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
                for i_args in range(len(args_list)):
                    args = args_list[i_args]
                    task_id = progress.add_task(f"task {i_args}", visible=True)
                    futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
                # Monitor the progress until every future is done.
                while (n_finished := sum([future.done() for future in futures])) < len(futures):
                    progress.update(
                        overall_progress_task, completed=n_finished, total=len(futures)
                    )
                    for task_id, update_data in _progress.items():
                        desc = update_data.get("desc", "")
                        # Update the progress bar for this task.
                        progress.update(
                            task_id,
                            completed=update_data.get("progress", 0),
                            total=update_data.get("total", 0),
                            description=desc
                        )
                # Raise any errors from the workers.
                for future in futures:
                    future.result()
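# Worker-side progress protocol (a sketch; the updates are assumed to be written by
# draw_tools.draw_psg_label via the multi_p/multi_task_id arguments). The monitor loop
# above only reads the "progress", "total" and "desc" keys, so a worker update would
# look like:
#   multi_p[multi_task_id] = {"progress": n_done, "total": n_segments, "desc": f"{samp_id}"}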
def multiprocess_with_pool(args_list, n_processes):
    """Process samples with a multiprocessing.Pool whose workers restart after a fixed number of tasks."""
    from multiprocessing import Pool
    # maxtasksperchild: how many tasks each worker handles before it is restarted,
    # so memory held by a worker is released periodically.
    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        results = []
        for samp_id in args_list:
            result = pool.apply_async(
                build_HYS_dataset_segment,
                args=(samp_id, False, True, False, None, None)
            )
            results.append(result)
        # Wait for all tasks to finish.
        for i, result in enumerate(results):
            try:
                result.get()
                print(f"Completed {i + 1}/{len(args_list)}")
            except Exception as e:
                print(f"Error processing task {i}: {e}")
        pool.close()
        pool.join()
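# Note on globals: build_HYS_dataset_segment reads module-level paths (conf, mask_path,
# save_* and visual_path) that are only assigned under the __main__ guard below. With the
# default "fork" start method on Linux the Pool workers inherit those globals; a "spawn"
# context would need them passed explicitly or re-initialized in each worker.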
if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(conf["dataset_config"]["dataset_save_path"])
    visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
    dataset_config = conf["dataset_config"]
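    # Expected HYS_PSG_config.yaml layout (a sketch inferred from the keys read above;
    # additional fields consumed by utils/signal_method are omitted):
    #   select_ids: [ ... ]                 # sample IDs to process
    #   root_path: /path/to/HYS             # contains PSG_Aligned/ and OrgBCG_Aligned/
    #   mask_save_path: /path/to/masks      # per-sample *_Processed_Labels.csv
    #   dataset_config:
    #     dataset_save_path: /path/to/out   # Signals/, Segments_List/, Labels/ are created here
    #     dataset_visual_path: /path/to/vis # per-sample segment plots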
    visual_path.mkdir(parents=True, exist_ok=True)
    save_processed_signal_path = save_path / "Signals"
    save_processed_signal_path.mkdir(parents=True, exist_ok=True)
    save_segment_list_path = save_path / "Segments_List"
    save_segment_list_path.mkdir(parents=True, exist_ok=True)
    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)
    # print(f"select_ids: {select_ids}")
    # print(f"root_path: {root_path}")
    # print(f"save_path: {save_path}")
    # print(f"visual_path: {visual_path}")
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
    print(select_ids)
    # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
    # multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
    multiprocess_with_pool(args_list=select_ids, n_processes=8)
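    # Downstream loading sketch (hypothetical; the signal npz stores one dict per channel,
    # so allow_pickle is required and .item() unwraps each 0-d object array):
    #   sig = np.load(save_path / "Signals" / f"{select_ids[0]}_Processed_Signals.npz", allow_pickle=True)
    #   tho = sig["Effort Tho"].item()["data"]   # normalized thoracic effort at 100 Hz
    #   seg = np.load(save_path / "Segments_List" / f"{select_ids[0]}_Segment_List.npz", allow_pickle=True)
    #   segment_list = seg["segment_list"]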