feat: 添加数据处理功能,重构信号读取与处理逻辑,更新配置文件以支持新样本ID
This commit is contained in:
parent
a42a482e1c
commit
3adcf00abb
4
.vscode/settings.json
vendored
Normal file
4
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda"
|
||||
}
|
||||
@ -1,25 +1,163 @@
|
||||
import multiprocessing
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from utils import N2Chn
|
||||
|
||||
os.environ['DISPLAY'] = "localhost:10.0"
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
project_root_path = Path(__file__).resolve().parent.parent
|
||||
|
||||
import utils
|
||||
from utils import N2Chn
|
||||
import signal_method
|
||||
import draw_tools
|
||||
import shutil
|
||||
import gc
|
||||
|
||||
DEFAULT_YAML_PATH = project_root_path / "dataset_config/HYS_PSG_config.yaml"
|
||||
|
||||
conf = None
|
||||
select_ids = None
|
||||
root_path = None
|
||||
mask_path = None
|
||||
save_path = None
|
||||
visual_path = None
|
||||
dataset_config = None
|
||||
org_signal_root_path = None
|
||||
psg_signal_root_path = None
|
||||
save_processed_signal_path = None
|
||||
save_segment_list_path = None
|
||||
save_processed_label_path = None
|
||||
|
||||
|
||||
def get_missing_psg_channels(samp_id, channel_number=None):
    """Return the names of PSG channels whose .txt files are absent for a sample.

    Args:
        samp_id: Sample identifier; its PSG directory is looked up under
            the module-level ``psg_signal_root_path``.
        channel_number: Channel ids to check; defaults to channels 1-8.

    Returns:
        A list of missing channel names (empty when all are present). If the
        sample directory itself does not exist, a single descriptive entry
        ``"PSG dir missing: <path>"`` is returned instead.
    """
    ensure_runtime_initialized()

    channels = [1, 2, 3, 4, 5, 6, 7, 8] if channel_number is None else channel_number

    sample_path = psg_signal_root_path / f"{samp_id}"
    if not sample_path.exists():
        return [f"PSG dir missing: {sample_path}"]

    # A channel counts as present when at least one "<name>*.txt" file exists.
    return [
        N2Chn[ch_id]
        for ch_id in channels
        if not any(sample_path.glob(f"{N2Chn[ch_id]}*.txt"))
    ]
|
||||
|
||||
|
||||
def filter_valid_psg_samples(sample_ids, verbose=True):
    """Split sample ids into those with complete PSG inputs and those without.

    Args:
        sample_ids: Iterable of sample identifiers to check.
        verbose: When True, report each skipped sample and a final summary.

    Returns:
        Tuple ``(valid_ids, skipped_ids)`` where ``skipped_ids`` is a list of
        ``(samp_id, missing_channels)`` pairs.
    """
    valid_ids, skipped_ids = [], []

    for samp_id in sample_ids:
        missing = get_missing_psg_channels(samp_id)
        if not missing:
            valid_ids.append(samp_id)
            continue

        skipped_ids.append((samp_id, missing))
        if verbose:
            print(
                f"Skipping sample {samp_id}: missing PSG channels {missing}",
                flush=True,
            )

    if verbose and skipped_ids:
        print(
            f"Filtered out {len(skipped_ids)} sample(s) with incomplete PSG inputs.",
            flush=True,
        )

    return valid_ids, skipped_ids
|
||||
|
||||
|
||||
def initialize_runtime(yaml_path=DEFAULT_YAML_PATH):
    """Load the dataset YAML config and populate the module-level path globals.

    Also creates the visualisation directory and the Signals / Segments_List /
    Labels output directories if they do not exist yet.

    Args:
        yaml_path: Path (str or Path) to the dataset YAML configuration.
    """
    global conf, select_ids, root_path, mask_path, save_path, visual_path
    global dataset_config, org_signal_root_path, psg_signal_root_path
    global save_processed_signal_path, save_segment_list_path
    global save_processed_label_path

    conf = utils.load_dataset_conf(Path(yaml_path))

    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])

    dataset_config = conf["dataset_config"]
    save_path = Path(dataset_config["dataset_save_path"])
    visual_path = Path(dataset_config["dataset_visual_path"])
    visual_path.mkdir(parents=True, exist_ok=True)

    save_processed_signal_path = save_path / "Signals"
    save_segment_list_path = save_path / "Segments_List"
    save_processed_label_path = save_path / "Labels"
    for out_dir in (
        save_processed_signal_path,
        save_segment_list_path,
        save_processed_label_path,
    ):
        out_dir.mkdir(parents=True, exist_ok=True)

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
|
||||
|
||||
|
||||
def ensure_runtime_initialized():
    """Lazily run initialize_runtime() when the module globals are unset.

    Needed because pool workers spawned without the initializer (or code paths
    that never called initialize_runtime) would otherwise see ``None`` globals.
    """
    needs_init = conf is None or psg_signal_root_path is None
    if needs_init:
        initialize_runtime()
|
||||
|
||||
|
||||
def sanitize_rpeak_indices(samp_id, rpeak_indices, signal_length, verbose=True):
    """Validate and clean R-peak sample indices for one recording.

    Drops indices outside ``[0, signal_length - 1]``, removes duplicates, and
    returns the survivors sorted ascending (``np.unique`` sorts).

    Args:
        samp_id: Sample identifier, used only in diagnostic messages.
        rpeak_indices: Iterable of integer R-peak positions.
        signal_length: Length of the reference signal in samples.
        verbose: When True, report how many indices were dropped.

    Returns:
        ``np.ndarray`` of unique, in-bounds indices, sorted ascending.

    Raises:
        ValueError: If no valid index remains, or fewer than 2 unique valid
            indices remain (downstream interval logic needs at least 2).
    """
    rpeak_indices = np.asarray(rpeak_indices, dtype=int)

    valid_mask = (rpeak_indices >= 0) & (rpeak_indices < signal_length)
    invalid_count = int((~valid_mask).sum())
    if invalid_count > 0:
        # BUGFIX: the original printed unconditionally; honor the verbose flag.
        if verbose:
            invalid_indices = rpeak_indices[~valid_mask]
            print(
                f"Sample {samp_id}: dropping {invalid_count} invalid Rpeak index/indices "
                f"outside [0, {signal_length - 1}]. "
                f"min_invalid={invalid_indices.min()}, max_invalid={invalid_indices.max()}",
                flush=True,
            )

    rpeak_indices = rpeak_indices[valid_mask]
    if rpeak_indices.size == 0:
        raise ValueError(f"Sample {samp_id}: no valid Rpeak indices remain after bounds check.")

    rpeak_indices = np.unique(rpeak_indices)
    if rpeak_indices.size < 2:
        raise ValueError(f"Sample {samp_id}: fewer than 2 valid Rpeak indices remain after bounds check.")

    return rpeak_indices
|
||||
|
||||
|
||||
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
|
||||
ensure_runtime_initialized()
|
||||
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
|
||||
psg_data["Rpeak"]["data"] = sanitize_rpeak_indices(
|
||||
samp_id=samp_id,
|
||||
rpeak_indices=psg_data["Rpeak"]["data"],
|
||||
signal_length=psg_data["ECG_Sync"]["length"],
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
total_seconds = min(
|
||||
psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
|
||||
@ -283,10 +421,28 @@ def multiprocess_entry(_progress, task_id, _id):
|
||||
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
|
||||
|
||||
|
||||
def _init_pool_worker(yaml_path=DEFAULT_YAML_PATH):
    """Pool-worker initializer: ignore SIGINT and load config in the child.

    Letting only the parent process handle Ctrl+C avoids parent and children
    racing on the interrupt, which can hang shutdown. Each spawned worker also
    needs its own copy of the module globals, hence initialize_runtime().
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    initialize_runtime(yaml_path)
|
||||
|
||||
|
||||
def _shutdown_executor(executor, wait, cancel_futures=False):
|
||||
if executor is None:
|
||||
return
|
||||
|
||||
try:
|
||||
executor.shutdown(wait=wait, cancel_futures=cancel_futures)
|
||||
except TypeError:
|
||||
executor.shutdown(wait=wait)
|
||||
|
||||
|
||||
def multiprocess_with_tqdm(args_list, n_processes):
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from rich import progress
|
||||
|
||||
yaml_path = str(DEFAULT_YAML_PATH)
|
||||
|
||||
with progress.Progress(
|
||||
"[progress.description]{task.description}",
|
||||
progress.BarColumn(),
|
||||
@ -301,7 +457,13 @@ def multiprocess_with_tqdm(args_list, n_processes):
|
||||
with multiprocessing.Manager() as manager:
|
||||
_progress = manager.dict()
|
||||
overall_progress_task = progress.add_task("[green]All jobs progress:")
|
||||
with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
|
||||
executor = ProcessPoolExecutor(
|
||||
max_workers=n_processes,
|
||||
mp_context=multiprocessing.get_context("spawn"),
|
||||
initializer=_init_pool_worker,
|
||||
initargs=(yaml_path,),
|
||||
)
|
||||
try:
|
||||
for i_args in range(len(args_list)):
|
||||
args = args_list[i_args]
|
||||
task_id = progress.add_task(f"task {i_args}", visible=True)
|
||||
@ -322,69 +484,90 @@ def multiprocess_with_tqdm(args_list, n_processes):
|
||||
total=update_data.get("total", 0),
|
||||
description=desc
|
||||
)
|
||||
time.sleep(0.2)
|
||||
|
||||
# raise any errors:
|
||||
for future in futures:
|
||||
future.result()
|
||||
except KeyboardInterrupt:
|
||||
print("\nKeyboardInterrupt received, cancelling pending jobs...", flush=True)
|
||||
for future in futures:
|
||||
future.cancel()
|
||||
_shutdown_executor(executor, wait=False, cancel_futures=True)
|
||||
executor = None
|
||||
raise SystemExit(130)
|
||||
finally:
|
||||
_shutdown_executor(executor, wait=True)
|
||||
|
||||
|
||||
def multiprocess_with_pool(args_list, n_processes):
|
||||
"""使用Pool,每个worker处理固定数量任务后重启"""
|
||||
from multiprocessing import Pool
|
||||
if not args_list:
|
||||
return
|
||||
|
||||
# maxtasksperchild 设置每个worker处理多少任务后重启(释放内存)
|
||||
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
|
||||
results = []
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
yaml_path = str(DEFAULT_YAML_PATH)
|
||||
pool = ctx.Pool(
|
||||
processes=n_processes,
|
||||
maxtasksperchild=2,
|
||||
initializer=_init_pool_worker,
|
||||
initargs=(yaml_path,),
|
||||
)
|
||||
|
||||
pending_results = {}
|
||||
completed = 0
|
||||
|
||||
try:
|
||||
for samp_id in args_list:
|
||||
result = pool.apply_async(
|
||||
pending_results[samp_id] = pool.apply_async(
|
||||
build_HYS_dataset_segment,
|
||||
args=(samp_id, False, True, False, None, None)
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# 等待所有任务完成
|
||||
for i, result in enumerate(results):
|
||||
try:
|
||||
result.get()
|
||||
print(f"Completed {i + 1}/{len(args_list)}")
|
||||
except Exception as e:
|
||||
print(f"Error processing task {i}: {e}")
|
||||
|
||||
pool.close()
|
||||
|
||||
while pending_results:
|
||||
finished_ids = []
|
||||
for samp_id, result in pending_results.items():
|
||||
if not result.ready():
|
||||
continue
|
||||
|
||||
try:
|
||||
result.get()
|
||||
completed += 1
|
||||
print(f"Completed {completed}/{len(args_list)}: {samp_id}", flush=True)
|
||||
except Exception as e:
|
||||
completed += 1
|
||||
print(f"Error processing {samp_id}: {e}", flush=True)
|
||||
|
||||
finished_ids.append(samp_id)
|
||||
|
||||
for samp_id in finished_ids:
|
||||
pending_results.pop(samp_id, None)
|
||||
|
||||
if pending_results:
|
||||
time.sleep(0.5)
|
||||
|
||||
pool.join()
|
||||
except KeyboardInterrupt:
|
||||
print("\nKeyboardInterrupt received, terminating worker processes...", flush=True)
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
raise SystemExit(130)
|
||||
except Exception:
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
|
||||
|
||||
conf = utils.load_dataset_conf(yaml_path)
|
||||
select_ids = conf["select_ids"]
|
||||
root_path = Path(conf["root_path"])
|
||||
mask_path = Path(conf["mask_save_path"])
|
||||
save_path = Path(conf["dataset_config"]["dataset_save_path"])
|
||||
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
|
||||
dataset_config = conf["dataset_config"]
|
||||
|
||||
visual_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_processed_signal_path = save_path / "Signals"
|
||||
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_segment_list_path = save_path / "Segments_List"
|
||||
save_segment_list_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_processed_label_path = save_path / "Labels"
|
||||
save_processed_label_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# print(f"select_ids: {select_ids}")
|
||||
# print(f"root_path: {root_path}")
|
||||
# print(f"save_path: {save_path}")
|
||||
# print(f"visual_path: {visual_path}")
|
||||
|
||||
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||
psg_signal_root_path = root_path / "PSG_Aligned"
|
||||
yaml_path = DEFAULT_YAML_PATH
|
||||
initialize_runtime(yaml_path)
|
||||
print(select_ids)
|
||||
|
||||
valid_select_ids, skipped_select_ids = filter_valid_psg_samples(select_ids, verbose=True)
|
||||
print(f"Valid PSG samples: {len(valid_select_ids)} / {len(select_ids)}", flush=True)
|
||||
|
||||
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
|
||||
|
||||
# for samp_id in select_ids:
|
||||
@ -392,4 +575,4 @@ if __name__ == '__main__':
|
||||
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||
|
||||
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
|
||||
multiprocess_with_pool(args_list=select_ids, n_processes=8)
|
||||
# multiprocess_with_pool(args_list=valid_select_ids, n_processes=8)
|
||||
|
||||
@ -1,27 +1,48 @@
|
||||
select_ids:
|
||||
- 54
|
||||
- 88
|
||||
- 1000
|
||||
- 1004
|
||||
- 1006
|
||||
- 1009
|
||||
- 1010
|
||||
- 1300
|
||||
- 1302
|
||||
- 1308
|
||||
- 1314
|
||||
- 1354
|
||||
- 1378
|
||||
- 220
|
||||
- 221
|
||||
- 229
|
||||
- 282
|
||||
- 285
|
||||
- 286
|
||||
- 54
|
||||
- 541
|
||||
- 579
|
||||
- 582
|
||||
- 670
|
||||
- 671
|
||||
- 683
|
||||
- 684
|
||||
- 686
|
||||
- 703
|
||||
- 704
|
||||
- 726
|
||||
- 735
|
||||
- 736
|
||||
- 88
|
||||
- 893
|
||||
- 933
|
||||
- 935
|
||||
- 950
|
||||
- 952
|
||||
- 954
|
||||
- 955
|
||||
- 956
|
||||
- 960
|
||||
- 961
|
||||
- 962
|
||||
- 967
|
||||
- 1302
|
||||
- 971
|
||||
- 972
|
||||
|
||||
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
|
||||
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
|
||||
|
||||
@ -35,59 +35,77 @@ import os
|
||||
os.environ['DISPLAY'] = "localhost:10.0"
|
||||
|
||||
|
||||
def resolve_sample_file(sample_dir: Path, prefix: str, suffix=".txt", prefer_tokens=("Sync", "RoughCut")) -> Path:
    """Locate the data file for *prefix* inside *sample_dir*.

    Candidates are files matching ``{prefix}*{suffix}``. Files whose name
    carries a preferred token (as ``_Token_`` or ``_Token.``) win, with the
    tokens tried in order; otherwise the alphabetically first candidate is
    returned. Multiple matches at the winning tier trigger a warning print.

    Raises:
        FileNotFoundError: When nothing matches; the message lists the
            directory's contents to ease debugging.
    """
    candidates = sorted(sample_dir.glob(f"{prefix}*{suffix}"))
    if not candidates:
        available_files = (
            ", ".join(sorted(path.name for path in sample_dir.iterdir()))
            if sample_dir.exists()
            else "<sample dir missing>"
        )
        raise FileNotFoundError(
            f"{prefix} file not found in {sample_dir}. "
            f"searched pattern: {prefix}*{suffix}. available: {available_files}"
        )

    def _carries(candidate, token):
        # Token must appear as a full "_Token_" or trailing "_Token." segment.
        return f"_{token}_" in candidate.name or f"_{token}." in candidate.name

    for token in prefer_tokens:
        preferred = [candidate for candidate in candidates if _carries(candidate, token)]
        if preferred:
            if len(preferred) > 1:
                print(f"Warning!!! multiple preferred files found for {prefix}: {preferred}")
            return preferred[0]

    if len(candidates) > 1:
        print(f"Warning!!! multiple files found for {prefix}: {candidates}")
    return candidates[0]
|
||||
|
||||
|
||||
def get_signal_duration_second(signal_path: Path) -> int:
    """Derive a signal file's duration in whole seconds.

    The sampling rate is encoded as the last ``_``-separated token of the
    file stem (e.g. ``SpO2_Sync_100.txt`` -> 100 Hz); the duration is the
    file's line count divided by that rate, truncated to an int.
    """
    sampling_rate = int(signal_path.stem.rsplit("_", 1)[-1])
    # Count lines without loading the whole file into memory.
    with signal_path.open("r", encoding="utf-8", errors="ignore") as fh:
        line_count = sum(1 for _ in fh)
    return line_count // sampling_rate
|
||||
|
||||
|
||||
def process_one_signal(samp_id, show=False):
|
||||
pass
|
||||
sample_dir = org_signal_root_path / f"{samp_id}"
|
||||
label_dir = label_root_path / f"{samp_id}"
|
||||
|
||||
tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
|
||||
abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
|
||||
flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
|
||||
flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
|
||||
spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
|
||||
stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))
|
||||
tho_signal_path = resolve_sample_file(sample_dir, "Effort Tho")
|
||||
abd_signal_path = resolve_sample_file(sample_dir, "Effort Abd")
|
||||
flowp_signal_path = resolve_sample_file(sample_dir, "Flow P")
|
||||
flowt_signal_path = resolve_sample_file(sample_dir, "Flow T")
|
||||
spo2_signal_path = resolve_sample_file(sample_dir, "SpO2")
|
||||
stage_signal_path = resolve_sample_file(sample_dir, "5_class")
|
||||
label_path = resolve_sample_file(label_dir, "SA Label", suffix=".csv")
|
||||
|
||||
if not tho_signal_path:
|
||||
raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
|
||||
tho_signal_path = tho_signal_path[0]
|
||||
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
|
||||
if not abd_signal_path:
|
||||
raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
|
||||
abd_signal_path = abd_signal_path[0]
|
||||
print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
|
||||
if not flowp_signal_path:
|
||||
raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
|
||||
flowp_signal_path = flowp_signal_path[0]
|
||||
print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
|
||||
if not flowt_signal_path:
|
||||
raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
|
||||
flowt_signal_path = flowt_signal_path[0]
|
||||
print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
|
||||
if not spo2_signal_path:
|
||||
raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
|
||||
spo2_signal_path = spo2_signal_path[0]
|
||||
print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
|
||||
if not stage_signal_path:
|
||||
raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
|
||||
stage_signal_path = stage_signal_path[0]
|
||||
print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
|
||||
|
||||
|
||||
|
||||
label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv")
|
||||
if not label_path:
|
||||
raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
|
||||
label_path = list(label_path)[0]
|
||||
print(f"Processing Label_corrected file: {label_path}")
|
||||
print(f"Processing Effort Tho signal file: {tho_signal_path}")
|
||||
print(f"Processing Effort Abd signal file: {abd_signal_path}")
|
||||
print(f"Processing Flow P signal file: {flowp_signal_path}")
|
||||
print(f"Processing Flow T signal file: {flowt_signal_path}")
|
||||
print(f"Processing SpO2 signal file: {spo2_signal_path}")
|
||||
print(f"Processing 5_class signal file: {stage_signal_path}")
|
||||
print(f"Processing SA Label file: {label_path}")
|
||||
#
|
||||
# # 保存处理后的数据和标签
|
||||
save_samp_path = save_path / f"{samp_id}"
|
||||
save_samp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
signal_seconds = {
|
||||
"Effort Tho": get_signal_duration_second(tho_signal_path),
|
||||
"Effort Abd": get_signal_duration_second(abd_signal_path),
|
||||
"Flow P": get_signal_duration_second(flowp_signal_path),
|
||||
"Flow T": get_signal_duration_second(flowt_signal_path),
|
||||
"SpO2": get_signal_duration_second(spo2_signal_path),
|
||||
"5_class": get_signal_duration_second(stage_signal_path),
|
||||
}
|
||||
common_second = min(signal_seconds.values())
|
||||
print(f"Sample {samp_id} signal seconds: {signal_seconds}")
|
||||
print(f"Sample {samp_id} common_second: {common_second}")
|
||||
|
||||
# # # 读取信号数据
|
||||
tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
|
||||
# abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
|
||||
# flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
|
||||
# flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
|
||||
# spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
|
||||
stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
|
||||
|
||||
|
||||
@ -117,7 +135,7 @@ def process_one_signal(samp_id, show=False):
|
||||
# spo2_fs = spo2_fs
|
||||
|
||||
label_data = utils.read_raw_psg_label(path=label_path)
|
||||
event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
|
||||
event_mask, score_mask = utils.generate_event_mask(signal_second=common_second, event_df=label_data, use_correct=False, with_score=False)
|
||||
# event_mask > 0 的部分为1,其他为0
|
||||
score_mask = np.where(event_mask > 0, 1, 0)
|
||||
|
||||
@ -128,9 +146,7 @@ def process_one_signal(samp_id, show=False):
|
||||
# 合并短于120秒的觉醒区间
|
||||
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
|
||||
|
||||
disable_label = wake_mask
|
||||
|
||||
disable_label = disable_label[:tho_second]
|
||||
disable_label = wake_mask[:common_second]
|
||||
|
||||
|
||||
# 复制事件文件 到保存路径
|
||||
@ -139,7 +155,7 @@ def process_one_signal(samp_id, show=False):
|
||||
#
|
||||
# 新建一个dataframe,分别是秒数、SA标签,
|
||||
save_dict = {
|
||||
"Second": np.arange(tho_second),
|
||||
"Second": np.arange(common_second),
|
||||
"SA_Label": event_mask,
|
||||
"SA_Score": score_mask,
|
||||
"Disable_Label": disable_label,
|
||||
@ -180,4 +196,4 @@ if __name__ == '__main__':
|
||||
print(f"Processing sample ID: {samp_id}")
|
||||
process_one_signal(samp_id, show=False)
|
||||
print(f"Finished processing sample ID: {samp_id}\n\n")
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -307,7 +307,10 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verb
|
||||
ch_path = list(path.glob(f"{ch_name}*.txt"))
|
||||
|
||||
if not any(ch_path):
|
||||
raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")
|
||||
raise FileNotFoundError(
|
||||
f"PSG channel '{ch_name}' file not found in {path} "
|
||||
f"(glob pattern: {ch_name}*.txt)"
|
||||
)
|
||||
|
||||
if len(ch_path) > 1:
|
||||
print(f"Warning!!! PSG Channel file more than one: {ch_path}")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user