From 92e26425f0dda737645069814d91d63573dc82e1 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 26 Jan 2026 14:03:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=A4=9A=E8=BF=9B=E7=A8=8B?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6=E4=BB=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=96=B0=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataset_builder/HYS_PSG_dataset.py | 8 +- dataset_builder/HYS_dataset.py | 28 ++++++- dataset_config/HYS_PSG_config.yaml | 10 +++ dataset_config/SHHS1_config.yaml | 41 ++++++++++ dataset_tools/shhs_annotations_check.py | 100 ++++++++++++++++++++++++ event_mask_process/SHHS1_process.py | 42 ++++++++++ signal_method/shhs_tools.py | 62 +++++++++++++++ signal_method/signal_process.py | 2 +- 8 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 dataset_config/SHHS1_config.yaml create mode 100644 dataset_tools/shhs_annotations_check.py create mode 100644 event_mask_process/SHHS1_process.py create mode 100644 signal_method/shhs_tools.py diff --git a/dataset_builder/HYS_PSG_dataset.py b/dataset_builder/HYS_PSG_dataset.py index e061e25..54e8d05 100644 --- a/dataset_builder/HYS_PSG_dataset.py +++ b/dataset_builder/HYS_PSG_dataset.py @@ -63,7 +63,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T # 都调整至100Hz采样率 - target_fs = 100 + target_fs = conf["target_fs"] normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs) normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs) normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs) @@ -90,8 +90,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt, spo2_fs=target_fs, - max_fill_duration=30, - min_gap_duration=10,) + max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"], + min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"]) draw_tools.draw_psg_signal( @@ -135,7 +135,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]), "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])} - spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95) + spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"]) segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) if verbose: diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index f944539..c5edd4c 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -210,6 +210,31 @@ def multiprocess_with_tqdm(args_list, n_processes): future.result() +def multiprocess_with_pool(args_list, n_processes): + """使用Pool,每个worker处理固定数量任务后重启""" + from multiprocessing import Pool + + # maxtasksperchild 设置每个worker处理多少任务后重启(释放内存) + with Pool(processes=n_processes, maxtasksperchild=2) as pool: + results = [] + for samp_id in args_list: + result = pool.apply_async( + build_HYS_dataset_segment, + args=(samp_id, False, True, False, None, None) + ) + results.append(result) + + # 等待所有任务完成 + for i, result in enumerate(results): + try: + result.get() + print(f"Completed {i + 1}/{len(args_list)}") + except Exception as e: + print(f"Error processing task {i}: {e}") + + pool.close() + pool.join() + if __name__ == '__main__': @@ -249,4 +274,5 @@ if __name__ == '__main__': # print(f"Processing sample ID: {samp_id}") # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) - multiprocess_with_tqdm(args_list=select_ids, n_processes=16) \ No newline at end of file + # multiprocess_with_tqdm(args_list=select_ids, n_processes=16) + multiprocess_with_pool(args_list=select_ids, n_processes=16) \ No newline at end of file diff --git a/dataset_config/HYS_PSG_config.yaml b/dataset_config/HYS_PSG_config.yaml index 63233df..6496a72 100644 --- a/dataset_config/HYS_PSG_config.yaml +++ b/dataset_config/HYS_PSG_config.yaml @@ -26,6 +26,8 @@ select_ids: root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG +target_fs: 100 + dataset_config: window_sec: 180 stride_sec: 60 @@ -42,6 +44,9 @@ effort_filter: high_cut: 0.5 order: 3 +average_filter: + window_size_sec: 20 + flow: downsample_fs: 10 @@ -51,6 +56,11 @@ flow_filter: high_cut: 0.5 order: 3 +spo2_fill__anomaly: + max_fill_duration: 30 + min_gap_duration: 10 + nan_to_num_value: 95 + #ecg: # downsample_fs: 100 # diff --git a/dataset_config/SHHS1_config.yaml b/dataset_config/SHHS1_config.yaml new file mode 100644 index 0000000..0b57be4 --- /dev/null +++ b/dataset_config/SHHS1_config.yaml @@ -0,0 +1,41 @@ +root_path: /mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1 +mask_save_path: /mnt/disk_code/marques/dataprepare/output/shhs1 + +effort_target_fs: 10 +ecg_target_fs: 100 + + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset/visualization + + +effort: + downsample_fs: 10 + +effort_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +average_filter: + window_size_sec: 20 + +flow: + downsample_fs: 10 + +flow_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +spo2_fill__anomaly: + max_fill_duration: 30 + min_gap_duration: 10 + nan_to_num_value: 95 + + diff --git a/dataset_tools/shhs_annotations_check.py b/dataset_tools/shhs_annotations_check.py new file mode 100644 index 0000000..40061c6 --- /dev/null +++ b/dataset_tools/shhs_annotations_check.py @@ -0,0 +1,100 @@ + +import argparse +from pathlib import Path +from lxml import etree +from tqdm import tqdm +from collections import Counter + + +def main(): + # 设定目标文件夹路径,你可以修改这里的路径,或者运行脚本时手动输入 + # 默认为当前目录 '.' + # target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1" + target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2" + + folder_path = Path(target_dir) + + if not folder_path.exists(): + print(f"错误: 路径 '{folder_path}' 不存在。") + return + + # 1. 获取所有 XML 文件 (扁平结构,不递归子目录) + xml_files = list(folder_path.glob("*.xml")) + total_files = len(xml_files) + + if total_files == 0: + print(f"在 '{folder_path}' 中没有找到 XML 文件。") + return + + print(f"找到 {total_files} 个 XML 文件,准备开始处理...") + + # 用于统计 (EventType, EventConcept) 组合的计数器 + stats_counter = Counter() + + # 2. 遍历文件,使用 tqdm 显示进度条 + for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"): + try: + # 使用 lxml 解析 + tree = etree.parse(str(xml_file)) + root = tree.getroot() + + # 3. 定位到 ScoredEvent 节点 + # SHHS XML 结构通常是: PSGAnnotation -> ScoredEvents -> ScoredEvent + # 我们直接查找所有的 ScoredEvent 节点 + events = root.findall(".//ScoredEvent") + + for event in events: + # 提取 EventType + type_node = event.find("EventType") + # 处理节点不存在或文本为空的情况 + e_type = type_node.text.strip() if (type_node is not None and type_node.text) else "N/A" + + # 提取 EventConcept + concept_node = event.find("EventConcept") + e_concept = concept_node.text.strip() if (concept_node is not None and concept_node.text) else "N/A" + + # 4. 组合并计数 + # 组合键为元组 (EventType, EventConcept) + key = (e_type, e_concept) + stats_counter[key] += 1 + + except etree.XMLSyntaxError: + print(f"\n[警告] 文件格式错误,跳过: {xml_file.name}") + except Exception as e: + print(f"\n[错误] 处理文件 {xml_file.name} 时出错: {e}") + + # 5. 打印结果到终端 + if stats_counter: + # --- 动态计算列宽 --- + # 获取所有 EventType 的最大长度,默认长度 9 + max_type_width = max((len(k[0]) for k in stats_counter.keys()), default=9) + max_type_width = max(max_type_width, 9) + + # 获取所有 EventConcept 的最大长度,默认长度 12 + max_conc_width = max((len(k[1]) for k in stats_counter.keys()), default=12) + max_conc_width = max(max_conc_width, 12) + + # 计算表格总宽度 + total_line_width = max_type_width + max_conc_width + 10 + 6 + + print("\n" + "=" * total_line_width) + print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}") + print("-" * total_line_width) + + # --- 修改处:按名称排序 --- + # sorted() 默认会对元组 (EventType, EventConcept) 进行字典序排序 + # 即先按 EventType A-Z 排序,再按 EventConcept A-Z 排序 + for (e_type, e_concept), count in sorted(stats_counter.items()): + print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}") + + print("=" * total_line_width) + + else: + print("\n未提取到任何事件数据。") + + print("=" * 90) + print(f"统计完成。共扫描 {total_files} 个文件。") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/event_mask_process/SHHS1_process.py b/event_mask_process/SHHS1_process.py new file mode 100644 index 0000000..a1fdae6 --- /dev/null +++ b/event_mask_process/SHHS1_process.py @@ -0,0 +1,42 @@ +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os +import mne +from tqdm import tqdm +import xml.etree.ElementTree as ET +import re + +# 获取分期和事件标签,以及不可用区间 + +def process_one_signal(samp_id, show=False): + + + + + + + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/SHHS1_config.yaml" + + conf = utils.load_dataset_conf(yaml_path) + + root_path = Path(conf["root_path"]) + save_path = Path(conf["mask_save_path"]) + + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + label_root_path = root_path / "Label" + diff --git a/signal_method/shhs_tools.py b/signal_method/shhs_tools.py new file mode 100644 index 0000000..fc4c653 --- /dev/null +++ b/signal_method/shhs_tools.py @@ -0,0 +1,62 @@ +import xml.etree.ElementTree as ET +ANNOTATION_MAP = { + "Wake|0": 0, + "Stage 1 sleep|1": 1, + "Stage 2 sleep|2": 2, + "Stage 3 sleep|3": 3, + "Stage 4 sleep|4": 4, + "REM sleep|5": 5, + "Unscored|9": 9, + "Movement|6": 6 +} + +SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea'] + +def parse_sleep_annotations(annotation_path): + """解析睡眠分期注释""" + try: + tree = ET.parse(annotation_path) + root = tree.getroot() + events = [] + for scored_event in root.findall('.//ScoredEvent'): + event_type = scored_event.find('EventType').text + if event_type != "Stages|Stages": + continue + description = scored_event.find('EventConcept').text + start = float(scored_event.find('Start').text) + duration = float(scored_event.find('Duration').text) + if description not in ANNOTATION_MAP: + continue + events.append({ + 'onset': start, + 'duration': duration, + 'description': description, + 'stage': ANNOTATION_MAP[description] + }) + return events + except Exception as e: + return None + + +def extract_osa_events(annotation_path): + """提取睡眠呼吸暂停事件""" + try: + tree = ET.parse(annotation_path) + root = tree.getroot() + events = [] + for scored_event in root.findall('.//ScoredEvent'): + event_concept = scored_event.find('EventConcept').text + event_type = event_concept.split('|')[0].strip() + if event_type in SA_EVENTS: + start = float(scored_event.find('Start').text) + duration = float(scored_event.find('Duration').text) + if duration >= 10: + events.append({ + 'start': start, + 'duration': duration, + 'end': start + duration, + 'type': event_type + }) + return events + except Exception as e: + return [] \ No newline at end of file diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py index a303ca4..168de8f 100644 --- a/signal_method/signal_process.py +++ b/signal_method/signal_process.py @@ -52,7 +52,7 @@ def psg_effort_filter(conf, effort_data_raw, effort_fs): high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"], sample_rate=effort_fs) # 移动平均 - effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20) + effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=conf["average_filter"]["window_size_sec"]) return effort_data_raw, effort_data_2, effort_fs