# DataPrepare/dataset_builder/HYS_PSG_dataset.py
import multiprocessing
import signal
import sys
import time
from pathlib import Path
import os
import numpy as np
# NOTE(review): hard-codes an X11 forwarding display — presumably so matplotlib
# can render over SSH; verify this is still wanted on headless runs.
os.environ['DISPLAY'] = "localhost:10.0"
# Make the project root importable so the sibling packages below resolve.
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
from utils import N2Chn
import signal_method
import draw_tools
import shutil
import gc
# Default dataset configuration file, relative to the project root.
DEFAULT_YAML_PATH = project_root_path / "dataset_config/HYS_PSG_config.yaml"
# Module-level runtime state. All of these are populated by
# initialize_runtime(); ensure_runtime_initialized() lazily triggers that.
conf = None  # parsed YAML configuration dict
select_ids = None  # sample ids selected for processing
root_path = None  # dataset root directory
mask_path = None  # directory holding per-sample mask/label CSVs
save_path = None  # dataset output root
visual_path = None  # per-segment visualisation output root
dataset_config = None  # the "dataset_config" sub-dict of conf
org_signal_root_path = None  # <root>/OrgBCG_Aligned
psg_signal_root_path = None  # <root>/PSG_Aligned
save_processed_signal_path = None  # <save_path>/Signals
save_segment_list_path = None  # <save_path>/Segments_List
save_processed_label_path = None  # <save_path>/Labels
def get_missing_psg_channels(samp_id, channel_number=None):
    """Return the names of required PSG channels that are missing for a sample.

    A channel counts as present when at least one ``<channel_name>*.txt`` file
    exists inside the sample's PSG directory.

    Parameters
    ----------
    samp_id : sample identifier (sub-directory name under the PSG root).
    channel_number : optional iterable of channel ids; defaults to 1..8.

    Returns
    -------
    list of missing channel names, or a one-element list describing the
    missing sample directory. Empty list means the sample is complete.
    """
    ensure_runtime_initialized()
    channels = list(range(1, 9)) if channel_number is None else channel_number
    sample_dir = psg_signal_root_path / f"{samp_id}"
    if not sample_dir.exists():
        return [f"PSG dir missing: {sample_dir}"]
    missing = []
    for ch_id in channels:
        ch_name = N2Chn[ch_id]
        has_file = any(sample_dir.glob(f"{ch_name}*.txt"))
        if not has_file:
            missing.append(ch_name)
    return missing
def filter_valid_psg_samples(sample_ids, verbose=True):
    """Partition sample ids into PSG-complete and PSG-incomplete sets.

    Parameters
    ----------
    sample_ids : iterable of sample identifiers to check.
    verbose : bool, print a line per skipped sample plus a summary.

    Returns
    -------
    (valid_ids, skipped_ids) where ``skipped_ids`` is a list of
    (sample_id, missing_channel_names) tuples.
    """
    valid_ids, skipped_ids = [], []
    for sid in sample_ids:
        missing = get_missing_psg_channels(sid)
        if not missing:
            valid_ids.append(sid)
            continue
        skipped_ids.append((sid, missing))
        if verbose:
            print(
                f"Skipping sample {sid}: missing PSG channels {missing}",
                flush=True,
            )
    if verbose and skipped_ids:
        print(
            f"Filtered out {len(skipped_ids)} sample(s) with incomplete PSG inputs.",
            flush=True,
        )
    return valid_ids, skipped_ids
def initialize_runtime(yaml_path=DEFAULT_YAML_PATH):
    """Load the dataset YAML config and populate all module-level paths.

    Also creates the visualisation and dataset output directories as a side
    effect. Safe to call repeatedly: later calls reload the configuration and
    the directory creation is idempotent.

    Parameters
    ----------
    yaml_path : path (str or Path) to the dataset configuration YAML.
    """
    global conf, select_ids, root_path, mask_path, save_path, visual_path, \
        dataset_config, org_signal_root_path, psg_signal_root_path, \
        save_processed_signal_path, save_segment_list_path, \
        save_processed_label_path
    cfg = utils.load_dataset_conf(Path(yaml_path))
    conf = cfg
    select_ids = cfg["select_ids"]
    root_path = Path(cfg["root_path"])
    mask_path = Path(cfg["mask_save_path"])
    dataset_config = cfg["dataset_config"]
    save_path = Path(dataset_config["dataset_save_path"])
    visual_path = Path(dataset_config["dataset_visual_path"])
    visual_path.mkdir(parents=True, exist_ok=True)
    save_processed_signal_path = save_path / "Signals"
    save_processed_signal_path.mkdir(parents=True, exist_ok=True)
    save_segment_list_path = save_path / "Segments_List"
    save_segment_list_path.mkdir(parents=True, exist_ok=True)
    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
def ensure_runtime_initialized():
    """Lazily run initialize_runtime() if the module globals are still unset."""
    needs_init = conf is None or psg_signal_root_path is None
    if needs_init:
        initialize_runtime()
def sanitize_rpeak_indices(samp_id, rpeak_indices, signal_length, verbose=True):
    """Validate, deduplicate and sort R-peak sample indices.

    Indices outside ``[0, signal_length)`` are dropped, duplicates removed,
    and the result is returned sorted (``np.unique`` sorts).

    Parameters
    ----------
    samp_id : sample identifier, used only in diagnostics.
    rpeak_indices : array-like of int, candidate R-peak positions (samples).
    signal_length : int, length of the reference ECG signal in samples.
    verbose : bool, print a report when invalid indices are dropped.

    Returns
    -------
    np.ndarray of unique, sorted, in-bounds integer indices.

    Raises
    ------
    ValueError if no valid index remains, or fewer than 2 unique ones
    (downstream HR/RRI derivation needs at least one interval).
    """
    rpeak_indices = np.asarray(rpeak_indices, dtype=int)
    valid_mask = (rpeak_indices >= 0) & (rpeak_indices < signal_length)
    invalid_count = int((~valid_mask).sum())
    if invalid_count > 0:
        # BUGFIX: the diagnostic used to print unconditionally; it is now
        # gated on the verbose flag the callers already pass down.
        if verbose:
            invalid_indices = rpeak_indices[~valid_mask]
            print(
                f"Sample {samp_id}: dropping {invalid_count} invalid Rpeak index/indices "
                f"outside [0, {signal_length - 1}]. "
                f"min_invalid={invalid_indices.min()}, max_invalid={invalid_indices.max()}",
                flush=True,
            )
        rpeak_indices = rpeak_indices[valid_mask]
    if rpeak_indices.size == 0:
        raise ValueError(f"Sample {samp_id}: no valid Rpeak indices remain after bounds check.")
    rpeak_indices = np.unique(rpeak_indices)
    if rpeak_indices.size < 2:
        raise ValueError(f"Sample {samp_id}: fewer than 2 valid Rpeak indices remain after bounds check.")
    return rpeak_indices
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    """Build and persist one sample's processed PSG signals, labels and segments.

    Pipeline: read the aligned PSG channels, sanitize R-peaks, truncate all
    channels to a common duration, filter/normalize the effort and flow
    channels, derive HR and interpolated RRI, resample everything to the
    configured target rate, fill SpO2 anomalies, merge disable masks, split
    the recording into enable/disable segments, and save signals, labels and
    segment lists under the dataset output folders (plus diagnostic plots).

    Parameters
    ----------
    samp_id : sample identifier (sub-directory name under the PSG root).
    show : bool, display the diagnostic plots interactively.
    draw_segment : bool, additionally render per-segment visualisations.
    verbose : bool, print progress information.
    multi_p, multi_task_id : optional shared progress dict / task id,
        forwarded to draw_tools when running inside a process pool.

    Returns
    -------
    None. All results are written to disk as side effects.
    """
    ensure_runtime_initialized()
    psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
    # Drop out-of-range / duplicate R-peak indices before deriving HR and RRI.
    psg_data["Rpeak"]["data"] = sanitize_rpeak_indices(
        samp_id=samp_id,
        rpeak_indices=psg_data["Rpeak"]["data"],
        signal_length=psg_data["ECG_Sync"]["length"],
        verbose=verbose,
    )
    # Truncate every channel (except the R-peak index list) to the duration of
    # the shortest channel so they all share a common time base.
    total_seconds = min(
        psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
    )
    for i in N2Chn.values():
        if i == "Rpeak":
            continue
        length = int(total_seconds * psg_data[i]["fs"])
        psg_data[i]["data"] = psg_data[i]["data"][:length]
        psg_data[i]["length"] = length
        psg_data[i]["second"] = total_seconds
    # Instantaneous heart-rate signal on the ECG time base.
    psg_data["HR"] = {
        "name": "HR",
        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
                                       psg_data["Rpeak"]["fs"]),
        "fs": psg_data["ECG_Sync"]["fs"],
        "length": psg_data["ECG_Sync"]["length"],
        "second": psg_data["ECG_Sync"]["second"]
    }
    # Preprocessing and filtering of effort / flow channels.
    tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
    abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
    flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
    flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])
    # Interpolated RR-interval series at 100 Hz.
    rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")
    event_mask, event_list = utils.read_mask_execl(mask_excel_path)
    # Normalize respiration/flow over the whole recording (one enable segment).
    enable_list = [[0, psg_data["Effort Tho"]["second"]]]
    normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
    normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
    normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
    normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)
    # Resample everything to the configured target rate (100 Hz per config).
    target_fs = conf["target_fs"]
    normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
    normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
    normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
    normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
    spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
    # Combined effort = mean of thoracic and abdominal effort.
    normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
    rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)
    # Trim all resampled signals to a common whole-second length.
    min_length = min(
        len(normalized_tho_signal), len(normalized_abd_signal),
        len(normalized_flowp_signal), len(normalized_flowt_signal),
        len(spo2_data_filt), len(normalized_effort_signal), len(rri),
    )
    min_length = min_length - min_length % target_fs  # keep a whole number of seconds
    normalized_tho_signal = normalized_tho_signal[:min_length]
    normalized_abd_signal = normalized_abd_signal[:min_length]
    normalized_flowp_signal = normalized_flowp_signal[:min_length]
    normalized_flowt_signal = normalized_flowt_signal[:min_length]
    spo2_data_filt = spo2_data_filt[:min_length]
    normalized_effort_signal = normalized_effort_signal[:min_length]
    rri = rri[:min_length]
    tho_second = min_length / target_fs
    # Event masks are per-second; truncate them to the common duration too.
    for i in event_mask.keys():
        event_mask[i] = event_mask[i][:int(tho_second)]
    # Fill short SpO2 dropouts; long gaps become part of the disable mask.
    spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
                                                                    spo2_fs=target_fs,
                                                                    max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"],
                                                                    min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"])
    # Diagnostic plot with the raw (unfilled) SpO2.
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
        show=show
    )
    # Same plot again with the anomaly-filled SpO2 for comparison.
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt_fill,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
        show=show
    )
    # Downsample the SpO2 disable mask to 1 Hz to match the per-second labels.
    spo2_disable_mask = spo2_disable_mask[::target_fs]
    min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
    if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
        print(f"Warning: Data length mismatch! Truncating to {min_len}.")
    event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]
    # Rebuild the enable/disable segment lists from the merged disable mask.
    event_list = {
        "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
        "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
    spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"])
    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
    # Copy the processed mask CSV into the Labels output folder.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)
    # Copy the synced SA label CSV (if present) into the Labels output folder.
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")
    # Save the processed signals and the extracted segment lists.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
    # Replace psg_data with the processed signals; the "name" fields use
    # underscores instead of spaces for npz-friendly keys.
    psg_data = {
        "Effort Tho": {
            "name": "Effort_Tho",
            "data": normalized_tho_signal,
            "fs": target_fs,
            "length": len(normalized_tho_signal),
            "second": len(normalized_tho_signal) / target_fs
        },
        "Effort Abd": {
            "name": "Effort_Abd",
            "data": normalized_abd_signal,
            "fs": target_fs,
            "length": len(normalized_abd_signal),
            "second": len(normalized_abd_signal) / target_fs
        },
        "Effort": {
            "name": "Effort",
            "data": normalized_effort_signal,
            "fs": target_fs,
            "length": len(normalized_effort_signal),
            "second": len(normalized_effort_signal) / target_fs
        },
        "Flow P": {
            "name": "Flow_P",
            "data": normalized_flowp_signal,
            "fs": target_fs,
            "length": len(normalized_flowp_signal),
            "second": len(normalized_flowp_signal) / target_fs
        },
        "Flow T": {
            "name": "Flow_T",
            "data": normalized_flowt_signal,
            "fs": target_fs,
            "length": len(normalized_flowt_signal),
            "second": len(normalized_flowt_signal) / target_fs
        },
        "SpO2": {
            "name": "SpO2",
            "data": spo2_data_filt_fill,
            "fs": target_fs,
            "length": len(spo2_data_filt_fill),
            "second": len(spo2_data_filt_fill) / target_fs
        },
        "HR": {
            "name": "HR",
            "data": psg_data["HR"]["data"],
            "fs": psg_data["HR"]["fs"],
            "length": psg_data["HR"]["length"],
            "second": psg_data["HR"]["second"]
        },
        "RRI": {
            "name": "RRI",
            "data": rri,
            "fs": target_fs,
            "length": len(rri),
            "second": len(rri) / target_fs
        },
        "5_class": {
            "name": "Stage",
            "data": psg_data["5_class"]["data"],
            "fs": psg_data["5_class"]["fs"],
            "length": psg_data["5_class"]["length"],
            "second": psg_data["5_class"]["second"]
        }
    }
    np.savez_compressed(save_signal_path, **psg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")
    if draw_segment:
        total_len = len(segment_list) + len(disable_segment_list)
        if verbose:
            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=segment_list,
            save_path=visual_path / f"{samp_id}" / "enable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )
        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=disable_segment_list,
            save_path=visual_path / f"{samp_id}" / "disable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )
    # Explicitly drop the large arrays before forcing a GC pass to cap peak
    # memory when many samples are processed in one long-lived worker process.
    try:
        del psg_data
        del normalized_tho_signal, normalized_abd_signal
        del normalized_flowp_signal, normalized_flowt_signal
        del normalized_effort_signal
        del spo2_data_filt, spo2_data_filt_fill
        del rri
        del event_mask, event_list
        del segment_list, disable_segment_list
    except NameError:
        # BUGFIX: was a bare `except: pass` that could swallow any error;
        # `del` on an existing local can only fail with NameError.
        pass
    # Force garbage collection.
    gc.collect()
def multiprocess_entry(_progress, task_id, _id):
    """Pool worker entry point: build one sample, reporting progress via the
    shared dict ``_progress`` under ``task_id``."""
    build_HYS_dataset_segment(
        samp_id=_id,
        show=False,
        draw_segment=True,
        verbose=False,
        multi_p=_progress,
        multi_task_id=task_id,
    )
def _init_pool_worker(yaml_path=DEFAULT_YAML_PATH):
    """Worker initializer: mute SIGINT in children and load the runtime config.

    Letting only the main process respond to Ctrl+C avoids parent and children
    handling the interrupt simultaneously, which can hang shutdown.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    initialize_runtime(yaml_path)
def _shutdown_executor(executor, wait, cancel_futures=False):
if executor is None:
return
try:
executor.shutdown(wait=wait, cancel_futures=cancel_futures)
except TypeError:
executor.shutdown(wait=wait)
def multiprocess_with_tqdm(args_list, n_processes):
    """Process every sample id in *args_list* in a spawn-based ProcessPoolExecutor,
    rendering a per-task progress bar plus an overall bar with ``rich``.

    Workers write their progress into a Manager-shared dict keyed by rich
    task id; the main loop polls futures and mirrors that dict into the bars.
    Ctrl+C cancels pending futures and exits with status 130.
    """
    from concurrent.futures import ProcessPoolExecutor
    from rich import progress
    yaml_path = str(DEFAULT_YAML_PATH)
    with progress.Progress(
        "[progress.description]{task.description}",
        progress.BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        progress.MofNCompleteColumn(),
        progress.TimeRemainingColumn(),
        progress.TimeElapsedColumn(),
        refresh_per_second=1,  # bit slower updates
        transient=False
    ) as progress:  # NOTE: shadows the imported `rich.progress` module from here on
        futures = []
        with multiprocessing.Manager() as manager:
            # Shared dict: workers publish {"progress": n, "total": m, "desc": s}.
            _progress = manager.dict()
            overall_progress_task = progress.add_task("[green]All jobs progress:")
            executor = ProcessPoolExecutor(
                max_workers=n_processes,
                mp_context=multiprocessing.get_context("spawn"),
                initializer=_init_pool_worker,
                initargs=(yaml_path,),
            )
            try:
                for i_args in range(len(args_list)):
                    args = args_list[i_args]
                    task_id = progress.add_task(f"task {i_args}", visible=True)
                    futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
                # monitor the progress:
                while (n_finished := sum([future.done() for future in futures])) < len(
                    futures
                ):
                    progress.update(
                        overall_progress_task, completed=n_finished, total=len(futures)
                    )
                    for task_id, update_data in _progress.items():
                        desc = update_data.get("desc", "")
                        # update the progress bar for this task:
                        progress.update(
                            task_id,
                            completed=update_data.get("progress", 0),
                            total=update_data.get("total", 0),
                            description=desc
                        )
                    time.sleep(0.2)
                # raise any errors:
                for future in futures:
                    future.result()
            except KeyboardInterrupt:
                print("\nKeyboardInterrupt received, cancelling pending jobs...", flush=True)
                for future in futures:
                    future.cancel()
                _shutdown_executor(executor, wait=False, cancel_futures=True)
                # Cleared so the `finally` clause does not shut down twice.
                executor = None
                raise SystemExit(130)
            finally:
                _shutdown_executor(executor, wait=True)
def multiprocess_with_pool(args_list, n_processes):
    """Process samples with a multiprocessing Pool whose workers restart after a
    fixed number of tasks (``maxtasksperchild``), bounding per-worker memory.

    Results are polled asynchronously; per-sample failures are printed and
    counted, not re-raised. Ctrl+C terminates the pool and exits with 130.
    """
    if not args_list:
        return
    ctx = multiprocessing.get_context("spawn")
    yaml_path = str(DEFAULT_YAML_PATH)
    pool = ctx.Pool(
        processes=n_processes,
        maxtasksperchild=2,
        initializer=_init_pool_worker,
        initargs=(yaml_path,),
    )
    pending_results = {}
    completed = 0
    try:
        for samp_id in args_list:
            # positional args: (samp_id, show, draw_segment, verbose, multi_p, multi_task_id)
            pending_results[samp_id] = pool.apply_async(
                build_HYS_dataset_segment,
                args=(samp_id, False, True, False, None, None)
            )
        pool.close()
        # Poll until every async result has been collected.
        while pending_results:
            finished_ids = []
            for samp_id, result in pending_results.items():
                if not result.ready():
                    continue
                try:
                    result.get()
                    completed += 1
                    print(f"Completed {completed}/{len(args_list)}: {samp_id}", flush=True)
                except Exception as e:
                    # Per-sample failures are reported but do not abort the batch.
                    completed += 1
                    print(f"Error processing {samp_id}: {e}", flush=True)
                finished_ids.append(samp_id)
            for samp_id in finished_ids:
                pending_results.pop(samp_id, None)
            if pending_results:
                time.sleep(0.5)
        pool.join()
    except KeyboardInterrupt:
        print("\nKeyboardInterrupt received, terminating worker processes...", flush=True)
        pool.terminate()
        pool.join()
        raise SystemExit(130)
    except Exception:
        pool.terminate()
        pool.join()
        raise
if __name__ == '__main__':
    yaml_path = DEFAULT_YAML_PATH
    initialize_runtime(yaml_path)
    print(select_ids)
    valid_select_ids, skipped_select_ids = filter_valid_psg_samples(select_ids, verbose=True)
    print(f"Valid PSG samples: {len(valid_select_ids)} / {len(select_ids)}", flush=True)
    # NOTE(review): all processing entry points below are commented out, so the
    # script currently only validates the sample list — confirm this is intended.
    # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
    # multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
    # multiprocess_with_pool(args_list=valid_select_ids, n_processes=8)