feat: 添加数据处理功能,重构信号读取与处理逻辑,更新配置文件以支持新样本ID

This commit is contained in:
marques 2026-03-30 09:37:14 +08:00
parent a42a482e1c
commit 3adcf00abb
5 changed files with 329 additions and 102 deletions

4
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,4 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda"
}

View File

@ -1,25 +1,163 @@
import multiprocessing import multiprocessing
import signal
import sys import sys
import time
from pathlib import Path from pathlib import Path
import os import os
import numpy as np import numpy as np
from utils import N2Chn
os.environ['DISPLAY'] = "localhost:10.0" os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent project_root_path = Path(__file__).resolve().parent.parent
import utils import utils
from utils import N2Chn
import signal_method import signal_method
import draw_tools import draw_tools
import shutil import shutil
import gc import gc
DEFAULT_YAML_PATH = project_root_path / "dataset_config/HYS_PSG_config.yaml"
conf = None
select_ids = None
root_path = None
mask_path = None
save_path = None
visual_path = None
dataset_config = None
org_signal_root_path = None
psg_signal_root_path = None
save_processed_signal_path = None
save_segment_list_path = None
save_processed_label_path = None
def get_missing_psg_channels(samp_id, channel_number=None):
    """Return the PSG channel names that have no matching .txt file for *samp_id*.

    Args:
        samp_id: Sample identifier; its PSG files live under
            ``psg_signal_root_path / samp_id``.
        channel_number: Iterable of channel ids (keys of ``N2Chn``).
            Defaults to channels 1-8.

    Returns:
        A list of missing channel names; empty when the sample is complete.
        If the sample directory itself is absent, a single descriptive
        entry is returned instead of per-channel names.
    """
    ensure_runtime_initialized()
    channel_ids = [1, 2, 3, 4, 5, 6, 7, 8] if channel_number is None else channel_number
    sample_dir = psg_signal_root_path / f"{samp_id}"
    if not sample_dir.exists():
        # Whole directory missing — report it as one pseudo-entry.
        return [f"PSG dir missing: {sample_dir}"]
    return [
        N2Chn[ch_id]
        for ch_id in channel_ids
        if not any(sample_dir.glob(f"{N2Chn[ch_id]}*.txt"))
    ]
def filter_valid_psg_samples(sample_ids, verbose=True):
    """Partition *sample_ids* into PSG-complete and PSG-incomplete samples.

    Args:
        sample_ids: Iterable of sample identifiers to check.
        verbose: When True, print a line per skipped sample and a summary.

    Returns:
        Tuple ``(valid_ids, skipped_ids)`` where ``valid_ids`` is a list of
        usable ids and ``skipped_ids`` is a list of
        ``(samp_id, missing_channels)`` pairs.
    """
    valid_ids, skipped_ids = [], []
    for samp_id in sample_ids:
        missing = get_missing_psg_channels(samp_id)
        if not missing:
            valid_ids.append(samp_id)
        else:
            skipped_ids.append((samp_id, missing))
            if verbose:
                print(
                    f"Skipping sample {samp_id}: missing PSG channels {missing}",
                    flush=True,
                )
    if skipped_ids and verbose:
        print(
            f"Filtered out {len(skipped_ids)} sample(s) with incomplete PSG inputs.",
            flush=True,
        )
    return valid_ids, skipped_ids
def initialize_runtime(yaml_path=DEFAULT_YAML_PATH):
    """Load the dataset YAML config and populate all module-level paths.

    Side effects: sets the module globals below and creates the visual /
    Signals / Segments_List / Labels output directories.

    Args:
        yaml_path: Path to the dataset config YAML; defaults to
            ``DEFAULT_YAML_PATH``.
    """
    global conf, select_ids, root_path, mask_path, save_path, visual_path, \
        dataset_config, org_signal_root_path, psg_signal_root_path, \
        save_processed_signal_path, save_segment_list_path, \
        save_processed_label_path

    conf = utils.load_dataset_conf(Path(yaml_path))
    select_ids = conf["select_ids"]
    dataset_config = conf["dataset_config"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(dataset_config["dataset_save_path"])
    visual_path = Path(dataset_config["dataset_visual_path"])

    save_processed_signal_path = save_path / "Signals"
    save_segment_list_path = save_path / "Segments_List"
    save_processed_label_path = save_path / "Labels"
    for out_dir in (
        visual_path,
        save_processed_signal_path,
        save_segment_list_path,
        save_processed_label_path,
    ):
        out_dir.mkdir(parents=True, exist_ok=True)

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
def ensure_runtime_initialized():
    """Lazily initialize module globals from the default YAML if unset."""
    runtime_ready = conf is not None and psg_signal_root_path is not None
    if not runtime_ready:
        initialize_runtime()
def sanitize_rpeak_indices(samp_id, rpeak_indices, signal_length, verbose=True):
    """Validate, bound-check and deduplicate R-peak sample indices.

    Drops indices outside ``[0, signal_length - 1]``, removes duplicates,
    and returns the survivors sorted ascending (``np.unique`` sorts).

    Args:
        samp_id: Sample identifier, used only in messages.
        rpeak_indices: Array-like of integer R-peak positions.
        signal_length: Length of the signal the indices refer to.
        verbose: When True, print a summary of any dropped indices.

    Returns:
        ``np.ndarray`` of unique in-bounds indices, sorted ascending.

    Raises:
        ValueError: If no valid index remains, or fewer than two remain
            (downstream interval logic needs at least a pair).
    """
    rpeak_indices = np.asarray(rpeak_indices, dtype=int)
    valid_mask = (rpeak_indices >= 0) & (rpeak_indices < signal_length)
    invalid_count = int((~valid_mask).sum())
    if invalid_count > 0:
        # BUGFIX: previously printed unconditionally; honor the `verbose` flag.
        if verbose:
            invalid_indices = rpeak_indices[~valid_mask]
            print(
                f"Sample {samp_id}: dropping {invalid_count} invalid Rpeak index/indices "
                f"outside [0, {signal_length - 1}]. "
                f"min_invalid={invalid_indices.min()}, max_invalid={invalid_indices.max()}",
                flush=True,
            )
        rpeak_indices = rpeak_indices[valid_mask]
    if rpeak_indices.size == 0:
        raise ValueError(f"Sample {samp_id}: no valid Rpeak indices remain after bounds check.")
    rpeak_indices = np.unique(rpeak_indices)
    if rpeak_indices.size < 2:
        raise ValueError(f"Sample {samp_id}: fewer than 2 valid Rpeak indices remain after bounds check.")
    return rpeak_indices
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
ensure_runtime_initialized()
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
psg_data["Rpeak"]["data"] = sanitize_rpeak_indices(
samp_id=samp_id,
rpeak_indices=psg_data["Rpeak"]["data"],
signal_length=psg_data["ECG_Sync"]["length"],
verbose=verbose,
)
total_seconds = min( total_seconds = min(
psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak" psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
@ -283,10 +421,28 @@ def multiprocess_entry(_progress, task_id, _id):
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def _init_pool_worker(yaml_path=DEFAULT_YAML_PATH):
    """Worker-process initializer: mute SIGINT, then load the runtime config.

    Ctrl+C is handled by the parent process only; if parent and children
    both react to the interrupt, shutdown can hang.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    initialize_runtime(yaml_path)
def _shutdown_executor(executor, wait, cancel_futures=False):
if executor is None:
return
try:
executor.shutdown(wait=wait, cancel_futures=cancel_futures)
except TypeError:
executor.shutdown(wait=wait)
def multiprocess_with_tqdm(args_list, n_processes): def multiprocess_with_tqdm(args_list, n_processes):
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from rich import progress from rich import progress
yaml_path = str(DEFAULT_YAML_PATH)
with progress.Progress( with progress.Progress(
"[progress.description]{task.description}", "[progress.description]{task.description}",
progress.BarColumn(), progress.BarColumn(),
@ -301,7 +457,13 @@ def multiprocess_with_tqdm(args_list, n_processes):
with multiprocessing.Manager() as manager: with multiprocessing.Manager() as manager:
_progress = manager.dict() _progress = manager.dict()
overall_progress_task = progress.add_task("[green]All jobs progress:") overall_progress_task = progress.add_task("[green]All jobs progress:")
with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor: executor = ProcessPoolExecutor(
max_workers=n_processes,
mp_context=multiprocessing.get_context("spawn"),
initializer=_init_pool_worker,
initargs=(yaml_path,),
)
try:
for i_args in range(len(args_list)): for i_args in range(len(args_list)):
args = args_list[i_args] args = args_list[i_args]
task_id = progress.add_task(f"task {i_args}", visible=True) task_id = progress.add_task(f"task {i_args}", visible=True)
@ -322,69 +484,90 @@ def multiprocess_with_tqdm(args_list, n_processes):
total=update_data.get("total", 0), total=update_data.get("total", 0),
description=desc description=desc
) )
time.sleep(0.2)
# raise any errors: # raise any errors:
for future in futures: for future in futures:
future.result() future.result()
except KeyboardInterrupt:
print("\nKeyboardInterrupt received, cancelling pending jobs...", flush=True)
for future in futures:
future.cancel()
_shutdown_executor(executor, wait=False, cancel_futures=True)
executor = None
raise SystemExit(130)
finally:
_shutdown_executor(executor, wait=True)
def multiprocess_with_pool(args_list, n_processes): def multiprocess_with_pool(args_list, n_processes):
"""使用Pool每个worker处理固定数量任务后重启""" """使用Pool每个worker处理固定数量任务后重启"""
from multiprocessing import Pool if not args_list:
return
# maxtasksperchild 设置每个worker处理多少任务后重启释放内存 ctx = multiprocessing.get_context("spawn")
with Pool(processes=n_processes, maxtasksperchild=2) as pool: yaml_path = str(DEFAULT_YAML_PATH)
results = [] pool = ctx.Pool(
processes=n_processes,
maxtasksperchild=2,
initializer=_init_pool_worker,
initargs=(yaml_path,),
)
pending_results = {}
completed = 0
try:
for samp_id in args_list: for samp_id in args_list:
result = pool.apply_async( pending_results[samp_id] = pool.apply_async(
build_HYS_dataset_segment, build_HYS_dataset_segment,
args=(samp_id, False, True, False, None, None) args=(samp_id, False, True, False, None, None)
) )
results.append(result)
# 等待所有任务完成
for i, result in enumerate(results):
try:
result.get()
print(f"Completed {i + 1}/{len(args_list)}")
except Exception as e:
print(f"Error processing task {i}: {e}")
pool.close() pool.close()
while pending_results:
finished_ids = []
for samp_id, result in pending_results.items():
if not result.ready():
continue
try:
result.get()
completed += 1
print(f"Completed {completed}/{len(args_list)}: {samp_id}", flush=True)
except Exception as e:
completed += 1
print(f"Error processing {samp_id}: {e}", flush=True)
finished_ids.append(samp_id)
for samp_id in finished_ids:
pending_results.pop(samp_id, None)
if pending_results:
time.sleep(0.5)
pool.join() pool.join()
except KeyboardInterrupt:
print("\nKeyboardInterrupt received, terminating worker processes...", flush=True)
pool.terminate()
pool.join()
raise SystemExit(130)
except Exception:
pool.terminate()
pool.join()
raise
if __name__ == '__main__': if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml" yaml_path = DEFAULT_YAML_PATH
initialize_runtime(yaml_path)
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
save_segment_list_path = save_path / "Segments_List"
save_segment_list_path.mkdir(parents=True, exist_ok=True)
save_processed_label_path = save_path / "Labels"
save_processed_label_path.mkdir(parents=True, exist_ok=True)
# print(f"select_ids: {select_ids}")
# print(f"root_path: {root_path}")
# print(f"save_path: {save_path}")
# print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
print(select_ids) print(select_ids)
valid_select_ids, skipped_select_ids = filter_valid_psg_samples(select_ids, verbose=True)
print(f"Valid PSG samples: {len(valid_select_ids)} / {len(select_ids)}", flush=True)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids: # for samp_id in select_ids:
@ -392,4 +575,4 @@ if __name__ == '__main__':
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8) # multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
multiprocess_with_pool(args_list=select_ids, n_processes=8) # multiprocess_with_pool(args_list=valid_select_ids, n_processes=8)

View File

@ -1,27 +1,48 @@
select_ids: select_ids:
- 54 - 1000
- 88 - 1004
- 1006
- 1009
- 1010
- 1300
- 1302
- 1308
- 1314
- 1354
- 1378
- 220 - 220
- 221 - 221
- 229 - 229
- 282 - 282
- 285
- 286 - 286
- 54
- 541 - 541
- 579
- 582 - 582
- 670 - 670
- 671
- 683 - 683
- 684 - 684
- 686
- 703
- 704
- 726
- 735 - 735
- 736
- 88
- 893
- 933 - 933
- 935 - 935
- 950 - 950
- 952 - 952
- 954
- 955
- 956
- 960 - 960
- 961
- 962 - 962
- 967 - 967
- 1302 - 971
- 972
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG

View File

@ -35,59 +35,77 @@ import os
os.environ['DISPLAY'] = "localhost:10.0" os.environ['DISPLAY'] = "localhost:10.0"
def resolve_sample_file(sample_dir: Path, prefix: str, suffix=".txt", prefer_tokens=("Sync", "RoughCut")) -> Path:
    """Locate the data file in *sample_dir* whose name starts with *prefix*.

    Candidates matching ``{prefix}*{suffix}`` are sorted by name; files whose
    name carries a preferred token (``_Sync_``, ``_Sync.``, ``_RoughCut_``, …)
    win, in *prefer_tokens* order. Multiple matches trigger a warning print
    and the first (alphabetical) one is returned.

    Raises:
        FileNotFoundError: If no candidate matches; the message lists the
            directory contents to aid debugging.
    """
    matches = sorted(sample_dir.glob(f"{prefix}*{suffix}"))
    if not matches:
        if sample_dir.exists():
            listing = ", ".join(sorted(entry.name for entry in sample_dir.iterdir()))
        else:
            listing = "<sample dir missing>"
        raise FileNotFoundError(
            f"{prefix} file not found in {sample_dir}. "
            f"searched pattern: {prefix}*{suffix}. available: {listing}"
        )
    for token in prefer_tokens:
        token_hits = [
            m for m in matches
            if f"_{token}_" in m.name or f"_{token}." in m.name
        ]
        if token_hits:
            if len(token_hits) > 1:
                print(f"Warning!!! multiple preferred files found for {prefix}: {token_hits}")
            return token_hits[0]
    if len(matches) > 1:
        print(f"Warning!!! multiple files found for {prefix}: {matches}")
    return matches[0]
def get_signal_duration_second(signal_path: Path) -> int:
    """Return the duration of a text signal file in whole seconds.

    The sampling rate is the trailing ``_<fs>`` token of the file stem
    (e.g. ``Effort Tho_Sync_100.txt`` -> fs=100); duration is the line
    count integer-divided by fs.
    """
    fs = int(signal_path.stem.rsplit("_", 1)[-1])
    with signal_path.open("r", encoding="utf-8", errors="ignore") as fh:
        n_lines = 0
        for _ in fh:
            n_lines += 1
    return n_lines // fs
def process_one_signal(samp_id, show=False): def process_one_signal(samp_id, show=False):
pass sample_dir = org_signal_root_path / f"{samp_id}"
label_dir = label_root_path / f"{samp_id}"
tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt")) tho_signal_path = resolve_sample_file(sample_dir, "Effort Tho")
abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt")) abd_signal_path = resolve_sample_file(sample_dir, "Effort Abd")
flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt")) flowp_signal_path = resolve_sample_file(sample_dir, "Flow P")
flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt")) flowt_signal_path = resolve_sample_file(sample_dir, "Flow T")
spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt")) spo2_signal_path = resolve_sample_file(sample_dir, "SpO2")
stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt")) stage_signal_path = resolve_sample_file(sample_dir, "5_class")
label_path = resolve_sample_file(label_dir, "SA Label", suffix=".csv")
if not tho_signal_path: print(f"Processing Effort Tho signal file: {tho_signal_path}")
raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}") print(f"Processing Effort Abd signal file: {abd_signal_path}")
tho_signal_path = tho_signal_path[0] print(f"Processing Flow P signal file: {flowp_signal_path}")
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}") print(f"Processing Flow T signal file: {flowt_signal_path}")
if not abd_signal_path: print(f"Processing SpO2 signal file: {spo2_signal_path}")
raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}") print(f"Processing 5_class signal file: {stage_signal_path}")
abd_signal_path = abd_signal_path[0] print(f"Processing SA Label file: {label_path}")
print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
if not flowp_signal_path:
raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
flowp_signal_path = flowp_signal_path[0]
print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
if not flowt_signal_path:
raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
flowt_signal_path = flowt_signal_path[0]
print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
if not spo2_signal_path:
raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
spo2_signal_path = spo2_signal_path[0]
print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
if not stage_signal_path:
raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
stage_signal_path = stage_signal_path[0]
print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv")
if not label_path:
raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
label_path = list(label_path)[0]
print(f"Processing Label_corrected file: {label_path}")
# #
# # 保存处理后的数据和标签 # # 保存处理后的数据和标签
save_samp_path = save_path / f"{samp_id}" save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True) save_samp_path.mkdir(parents=True, exist_ok=True)
signal_seconds = {
"Effort Tho": get_signal_duration_second(tho_signal_path),
"Effort Abd": get_signal_duration_second(abd_signal_path),
"Flow P": get_signal_duration_second(flowp_signal_path),
"Flow T": get_signal_duration_second(flowt_signal_path),
"SpO2": get_signal_duration_second(spo2_signal_path),
"5_class": get_signal_duration_second(stage_signal_path),
}
common_second = min(signal_seconds.values())
print(f"Sample {samp_id} signal seconds: {signal_seconds}")
print(f"Sample {samp_id} common_second: {common_second}")
# # # 读取信号数据 # # # 读取信号数据
tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
# abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
# flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
# flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
# spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True) stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
@ -117,7 +135,7 @@ def process_one_signal(samp_id, show=False):
# spo2_fs = spo2_fs # spo2_fs = spo2_fs
label_data = utils.read_raw_psg_label(path=label_path) label_data = utils.read_raw_psg_label(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False) event_mask, score_mask = utils.generate_event_mask(signal_second=common_second, event_df=label_data, use_correct=False, with_score=False)
# event_mask > 0 的部分为1其他为0 # event_mask > 0 的部分为1其他为0
score_mask = np.where(event_mask > 0, 1, 0) score_mask = np.where(event_mask > 0, 1, 0)
@ -128,9 +146,7 @@ def process_one_signal(samp_id, show=False):
# 合并短于120秒的觉醒区间 # 合并短于120秒的觉醒区间
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60) wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
disable_label = wake_mask disable_label = wake_mask[:common_second]
disable_label = disable_label[:tho_second]
# 复制事件文件 到保存路径 # 复制事件文件 到保存路径
@ -139,7 +155,7 @@ def process_one_signal(samp_id, show=False):
# #
# 新建一个dataframe分别是秒数、SA标签 # 新建一个dataframe分别是秒数、SA标签
save_dict = { save_dict = {
"Second": np.arange(tho_second), "Second": np.arange(common_second),
"SA_Label": event_mask, "SA_Label": event_mask,
"SA_Score": score_mask, "SA_Score": score_mask,
"Disable_Label": disable_label, "Disable_Label": disable_label,

View File

@ -307,7 +307,10 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verb
ch_path = list(path.glob(f"{ch_name}*.txt")) ch_path = list(path.glob(f"{ch_name}*.txt"))
if not any(ch_path): if not any(ch_path):
raise FileNotFoundError(f"PSG Channel file not found: {ch_path}") raise FileNotFoundError(
f"PSG channel '{ch_name}' file not found in {path} "
f"(glob pattern: {ch_name}*.txt)"
)
if len(ch_path) > 1: if len(ch_path) > 1:
print(f"Warning!!! PSG Channel file more than one: {ch_path}") print(f"Warning!!! PSG Channel file more than one: {ch_path}")