feat: 添加数据处理功能,重构信号读取与处理逻辑,更新配置文件以支持新样本ID

This commit is contained in:
marques 2026-03-30 09:37:14 +08:00
parent a42a482e1c
commit 3adcf00abb
5 changed files with 329 additions and 102 deletions

4
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,4 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda"
}

View File

@ -1,25 +1,163 @@
import multiprocessing
import signal
import sys
import time
from pathlib import Path
import os
import numpy as np
from utils import N2Chn
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
from utils import N2Chn
import signal_method
import draw_tools
import shutil
import gc
DEFAULT_YAML_PATH = project_root_path / "dataset_config/HYS_PSG_config.yaml"
conf = None
select_ids = None
root_path = None
mask_path = None
save_path = None
visual_path = None
dataset_config = None
org_signal_root_path = None
psg_signal_root_path = None
save_processed_signal_path = None
save_segment_list_path = None
save_processed_label_path = None
def get_missing_psg_channels(samp_id, channel_number=None):
    """Return the names of required PSG channels with no matching .txt file.

    Args:
        samp_id: Sample identifier, used as the sub-directory name under the
            module-level ``psg_signal_root_path``.
        channel_number: Iterable of channel ids (keys of ``N2Chn``) to check.
            Defaults to channels 1 through 8.

    Returns:
        A list of missing channel names (empty when every channel is present).
        If the sample directory itself does not exist, a single-element list
        carrying a "PSG dir missing" message is returned instead.
    """
    ensure_runtime_initialized()
    wanted = list(range(1, 9)) if channel_number is None else channel_number
    sample_dir = psg_signal_root_path / f"{samp_id}"
    if not sample_dir.exists():
        return [f"PSG dir missing: {sample_dir}"]
    return [
        N2Chn[ch_id]
        for ch_id in wanted
        if not any(sample_dir.glob(f"{N2Chn[ch_id]}*.txt"))
    ]
def filter_valid_psg_samples(sample_ids, verbose=True):
    """Partition sample ids into those with complete PSG inputs and the rest.

    Args:
        sample_ids: Iterable of sample identifiers to check.
        verbose: When True, print one line per skipped sample plus a summary.

    Returns:
        Tuple ``(valid_ids, skipped_ids)`` where ``valid_ids`` is a list of
        usable sample ids and ``skipped_ids`` is a list of
        ``(samp_id, missing_channels)`` pairs.
    """
    valid_ids, skipped_ids = [], []
    for sid in sample_ids:
        missing = get_missing_psg_channels(sid)
        if not missing:
            valid_ids.append(sid)
            continue
        skipped_ids.append((sid, missing))
        if verbose:
            print(
                f"Skipping sample {sid}: missing PSG channels {missing}",
                flush=True,
            )
    if verbose and skipped_ids:
        print(
            f"Filtered out {len(skipped_ids)} sample(s) with incomplete PSG inputs.",
            flush=True,
        )
    return valid_ids, skipped_ids
def initialize_runtime(yaml_path=DEFAULT_YAML_PATH):
    """Load the dataset YAML config and populate the module-level runtime state.

    As a side effect, creates the output directory tree (visual path plus the
    Signals / Segments_List / Labels sub-directories). Safe to call more than
    once: each call re-reads the config and overwrites the globals, which is
    how freshly spawned worker processes bootstrap themselves.

    Args:
        yaml_path: Path (or string) to the dataset YAML configuration file.
    """
    global conf, select_ids, root_path, mask_path, save_path, visual_path
    global dataset_config, org_signal_root_path, psg_signal_root_path
    global save_processed_signal_path, save_segment_list_path
    global save_processed_label_path

    conf = utils.load_dataset_conf(Path(yaml_path))
    select_ids = conf["select_ids"]
    dataset_config = conf["dataset_config"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(dataset_config["dataset_save_path"])
    visual_path = Path(dataset_config["dataset_visual_path"])

    save_processed_signal_path = save_path / "Signals"
    save_segment_list_path = save_path / "Segments_List"
    save_processed_label_path = save_path / "Labels"
    for out_dir in (
        visual_path,
        save_processed_signal_path,
        save_segment_list_path,
        save_processed_label_path,
    ):
        out_dir.mkdir(parents=True, exist_ok=True)

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
def ensure_runtime_initialized():
    """Lazily bootstrap the module state.

    Runs ``initialize_runtime()`` with the default config when the globals
    have not been populated yet (e.g. inside a fresh worker process).
    """
    not_ready = conf is None or psg_signal_root_path is None
    if not_ready:
        initialize_runtime()
def sanitize_rpeak_indices(samp_id, rpeak_indices, signal_length, verbose=True):
    """Validate, de-duplicate and sort R-peak sample indices for one recording.

    Indices outside ``[0, signal_length - 1]`` are dropped, the survivors are
    de-duplicated and returned in ascending order via ``np.unique``.

    Args:
        samp_id: Sample identifier, used only in diagnostic/error messages.
        rpeak_indices: Array-like of R-peak indices; coerced to an int array.
        signal_length: Length of the reference signal; defines the valid range.
        verbose: When True, print a summary of any dropped indices.

    Returns:
        np.ndarray of unique, in-range, ascending R-peak indices.

    Raises:
        ValueError: If no valid index remains after the bounds check, or if
            fewer than two unique indices remain.
    """
    rpeak_indices = np.asarray(rpeak_indices, dtype=int)
    valid_mask = (rpeak_indices >= 0) & (rpeak_indices < signal_length)
    invalid_count = int((~valid_mask).sum())
    if invalid_count > 0:
        # Fix: the diagnostic print previously ignored the `verbose` flag and
        # fired unconditionally.
        if verbose:
            invalid_indices = rpeak_indices[~valid_mask]
            print(
                f"Sample {samp_id}: dropping {invalid_count} invalid Rpeak index/indices "
                f"outside [0, {signal_length - 1}]. "
                f"min_invalid={invalid_indices.min()}, max_invalid={invalid_indices.max()}",
                flush=True,
            )
        rpeak_indices = rpeak_indices[valid_mask]
    if rpeak_indices.size == 0:
        raise ValueError(f"Sample {samp_id}: no valid Rpeak indices remain after bounds check.")
    rpeak_indices = np.unique(rpeak_indices)
    if rpeak_indices.size < 2:
        raise ValueError(f"Sample {samp_id}: fewer than 2 valid Rpeak indices remain after bounds check.")
    return rpeak_indices
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
ensure_runtime_initialized()
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
psg_data["Rpeak"]["data"] = sanitize_rpeak_indices(
samp_id=samp_id,
rpeak_indices=psg_data["Rpeak"]["data"],
signal_length=psg_data["ECG_Sync"]["length"],
verbose=verbose,
)
total_seconds = min(
psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
@ -283,10 +421,28 @@ def multiprocess_entry(_progress, task_id, _id):
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def _init_pool_worker(yaml_path=DEFAULT_YAML_PATH):
    """Worker-process initializer: ignore SIGINT and load the runtime config.

    SIGINT is ignored in the workers so that Ctrl+C is handled solely by the
    parent process; if parent and children both react to the interrupt,
    shutdown can hang.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    initialize_runtime(yaml_path)
def _shutdown_executor(executor, wait, cancel_futures=False):
if executor is None:
return
try:
executor.shutdown(wait=wait, cancel_futures=cancel_futures)
except TypeError:
executor.shutdown(wait=wait)
def multiprocess_with_tqdm(args_list, n_processes):
from concurrent.futures import ProcessPoolExecutor
from rich import progress
yaml_path = str(DEFAULT_YAML_PATH)
with progress.Progress(
"[progress.description]{task.description}",
progress.BarColumn(),
@ -301,7 +457,13 @@ def multiprocess_with_tqdm(args_list, n_processes):
with multiprocessing.Manager() as manager:
_progress = manager.dict()
overall_progress_task = progress.add_task("[green]All jobs progress:")
with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
executor = ProcessPoolExecutor(
max_workers=n_processes,
mp_context=multiprocessing.get_context("spawn"),
initializer=_init_pool_worker,
initargs=(yaml_path,),
)
try:
for i_args in range(len(args_list)):
args = args_list[i_args]
task_id = progress.add_task(f"task {i_args}", visible=True)
@ -322,69 +484,90 @@ def multiprocess_with_tqdm(args_list, n_processes):
total=update_data.get("total", 0),
description=desc
)
time.sleep(0.2)
# raise any errors:
for future in futures:
future.result()
except KeyboardInterrupt:
print("\nKeyboardInterrupt received, cancelling pending jobs...", flush=True)
for future in futures:
future.cancel()
_shutdown_executor(executor, wait=False, cancel_futures=True)
executor = None
raise SystemExit(130)
finally:
_shutdown_executor(executor, wait=True)
def multiprocess_with_pool(args_list, n_processes):
"""使用Pool每个worker处理固定数量任务后重启"""
from multiprocessing import Pool
if not args_list:
return
# maxtasksperchild 设置每个worker处理多少任务后重启释放内存
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
results = []
ctx = multiprocessing.get_context("spawn")
yaml_path = str(DEFAULT_YAML_PATH)
pool = ctx.Pool(
processes=n_processes,
maxtasksperchild=2,
initializer=_init_pool_worker,
initargs=(yaml_path,),
)
pending_results = {}
completed = 0
try:
for samp_id in args_list:
result = pool.apply_async(
pending_results[samp_id] = pool.apply_async(
build_HYS_dataset_segment,
args=(samp_id, False, True, False, None, None)
)
results.append(result)
# 等待所有任务完成
for i, result in enumerate(results):
try:
result.get()
print(f"Completed {i + 1}/{len(args_list)}")
except Exception as e:
print(f"Error processing task {i}: {e}")
pool.close()
while pending_results:
finished_ids = []
for samp_id, result in pending_results.items():
if not result.ready():
continue
try:
result.get()
completed += 1
print(f"Completed {completed}/{len(args_list)}: {samp_id}", flush=True)
except Exception as e:
completed += 1
print(f"Error processing {samp_id}: {e}", flush=True)
finished_ids.append(samp_id)
for samp_id in finished_ids:
pending_results.pop(samp_id, None)
if pending_results:
time.sleep(0.5)
pool.join()
except KeyboardInterrupt:
print("\nKeyboardInterrupt received, terminating worker processes...", flush=True)
pool.terminate()
pool.join()
raise SystemExit(130)
except Exception:
pool.terminate()
pool.join()
raise
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
save_segment_list_path = save_path / "Segments_List"
save_segment_list_path.mkdir(parents=True, exist_ok=True)
save_processed_label_path = save_path / "Labels"
save_processed_label_path.mkdir(parents=True, exist_ok=True)
# print(f"select_ids: {select_ids}")
# print(f"root_path: {root_path}")
# print(f"save_path: {save_path}")
# print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
yaml_path = DEFAULT_YAML_PATH
initialize_runtime(yaml_path)
print(select_ids)
valid_select_ids, skipped_select_ids = filter_valid_psg_samples(select_ids, verbose=True)
print(f"Valid PSG samples: {len(valid_select_ids)} / {len(select_ids)}", flush=True)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
@ -392,4 +575,4 @@ if __name__ == '__main__':
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
multiprocess_with_pool(args_list=select_ids, n_processes=8)
# multiprocess_with_pool(args_list=valid_select_ids, n_processes=8)

View File

@ -1,27 +1,48 @@
select_ids:
- 54
- 88
- 1000
- 1004
- 1006
- 1009
- 1010
- 1300
- 1302
- 1308
- 1314
- 1354
- 1378
- 220
- 221
- 229
- 282
- 285
- 286
- 54
- 541
- 579
- 582
- 670
- 671
- 683
- 684
- 686
- 703
- 704
- 726
- 735
- 736
- 88
- 893
- 933
- 935
- 950
- 952
- 954
- 955
- 956
- 960
- 961
- 962
- 967
- 1302
- 971
- 972
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG

View File

@ -35,59 +35,77 @@ import os
os.environ['DISPLAY'] = "localhost:10.0"
def resolve_sample_file(sample_dir: Path, prefix: str, suffix=".txt", prefer_tokens=("Sync", "RoughCut")) -> Path:
    """Locate the data file for *prefix* inside a sample directory.

    Globs ``{prefix}*{suffix}`` and, among the matches, prefers files tagged
    with one of ``prefer_tokens`` (checked in order; a tag matches when the
    file name contains ``_<token>_`` or ``_<token>.``). Prints a warning when
    the choice is ambiguous.

    Args:
        sample_dir: Directory to search.
        prefix: File-name prefix, e.g. ``"Effort Tho"``.
        suffix: File-name suffix, default ``".txt"``.
        prefer_tokens: Ordered tags that break ties between candidates.

    Returns:
        Path of the selected file.

    Raises:
        FileNotFoundError: When no file matches the glob pattern.
    """
    candidates = sorted(sample_dir.glob(f"{prefix}*{suffix}"))
    if not candidates:
        if sample_dir.exists():
            available_files = ", ".join(sorted(entry.name for entry in sample_dir.iterdir()))
        else:
            available_files = "<sample dir missing>"
        raise FileNotFoundError(
            f"{prefix} file not found in {sample_dir}. "
            f"searched pattern: {prefix}*{suffix}. available: {available_files}"
        )

    def _tagged(candidate, token):
        # A token counts only when it appears as a full `_`-delimited field.
        return f"_{token}_" in candidate.name or f"_{token}." in candidate.name

    for token in prefer_tokens:
        preferred = [candidate for candidate in candidates if _tagged(candidate, token)]
        if preferred:
            if len(preferred) > 1:
                print(f"Warning!!! multiple preferred files found for {prefix}: {preferred}")
            return preferred[0]

    if len(candidates) > 1:
        print(f"Warning!!! multiple files found for {prefix}: {candidates}")
    return candidates[0]
def get_signal_duration_second(signal_path: Path) -> int:
    """Return the duration of a one-value-per-line signal file in whole seconds.

    The sampling rate is encoded as the last ``_``-separated field of the file
    stem (e.g. ``Effort Tho_Sync_100.txt`` -> 100 Hz); the duration is the
    line count floor-divided by that rate.
    """
    sampling_rate = int(signal_path.stem.rsplit("_", 1)[-1])
    n_lines = 0
    with signal_path.open("r", encoding="utf-8", errors="ignore") as fh:
        for _ in fh:
            n_lines += 1
    return n_lines // sampling_rate
def process_one_signal(samp_id, show=False):
pass
sample_dir = org_signal_root_path / f"{samp_id}"
label_dir = label_root_path / f"{samp_id}"
tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))
tho_signal_path = resolve_sample_file(sample_dir, "Effort Tho")
abd_signal_path = resolve_sample_file(sample_dir, "Effort Abd")
flowp_signal_path = resolve_sample_file(sample_dir, "Flow P")
flowt_signal_path = resolve_sample_file(sample_dir, "Flow T")
spo2_signal_path = resolve_sample_file(sample_dir, "SpO2")
stage_signal_path = resolve_sample_file(sample_dir, "5_class")
label_path = resolve_sample_file(label_dir, "SA Label", suffix=".csv")
if not tho_signal_path:
raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
tho_signal_path = tho_signal_path[0]
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
if not abd_signal_path:
raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
abd_signal_path = abd_signal_path[0]
print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
if not flowp_signal_path:
raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
flowp_signal_path = flowp_signal_path[0]
print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
if not flowt_signal_path:
raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
flowt_signal_path = flowt_signal_path[0]
print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
if not spo2_signal_path:
raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
spo2_signal_path = spo2_signal_path[0]
print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
if not stage_signal_path:
raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
stage_signal_path = stage_signal_path[0]
print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv")
if not label_path:
raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
label_path = list(label_path)[0]
print(f"Processing Label_corrected file: {label_path}")
print(f"Processing Effort Tho signal file: {tho_signal_path}")
print(f"Processing Effort Abd signal file: {abd_signal_path}")
print(f"Processing Flow P signal file: {flowp_signal_path}")
print(f"Processing Flow T signal file: {flowt_signal_path}")
print(f"Processing SpO2 signal file: {spo2_signal_path}")
print(f"Processing 5_class signal file: {stage_signal_path}")
print(f"Processing SA Label file: {label_path}")
#
# # 保存处理后的数据和标签
save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True)
signal_seconds = {
"Effort Tho": get_signal_duration_second(tho_signal_path),
"Effort Abd": get_signal_duration_second(abd_signal_path),
"Flow P": get_signal_duration_second(flowp_signal_path),
"Flow T": get_signal_duration_second(flowt_signal_path),
"SpO2": get_signal_duration_second(spo2_signal_path),
"5_class": get_signal_duration_second(stage_signal_path),
}
common_second = min(signal_seconds.values())
print(f"Sample {samp_id} signal seconds: {signal_seconds}")
print(f"Sample {samp_id} common_second: {common_second}")
# # # 读取信号数据
tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
# abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
# flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
# flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
# spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
@ -117,7 +135,7 @@ def process_one_signal(samp_id, show=False):
# spo2_fs = spo2_fs
label_data = utils.read_raw_psg_label(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
event_mask, score_mask = utils.generate_event_mask(signal_second=common_second, event_df=label_data, use_correct=False, with_score=False)
# event_mask > 0 的部分为1其他为0
score_mask = np.where(event_mask > 0, 1, 0)
@ -128,9 +146,7 @@ def process_one_signal(samp_id, show=False):
# 合并短于120秒的觉醒区间
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
disable_label = wake_mask
disable_label = disable_label[:tho_second]
disable_label = wake_mask[:common_second]
# 复制事件文件 到保存路径
@ -139,7 +155,7 @@ def process_one_signal(samp_id, show=False):
#
# 新建一个dataframe分别是秒数、SA标签
save_dict = {
"Second": np.arange(tho_second),
"Second": np.arange(common_second),
"SA_Label": event_mask,
"SA_Score": score_mask,
"Disable_Label": disable_label,

View File

@ -307,7 +307,10 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verb
ch_path = list(path.glob(f"{ch_name}*.txt"))
if not any(ch_path):
raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")
raise FileNotFoundError(
f"PSG channel '{ch_name}' file not found in {path} "
f"(glob pattern: {ch_name}*.txt)"
)
if len(ch_path) > 1:
print(f"Warning!!! PSG Channel file more than one: {ch_path}")