添加多进程处理功能,重构数据处理逻辑,更新配置文件以支持新参数
This commit is contained in:
parent
097c9cbf0b
commit
92e26425f0
@ -63,7 +63,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
|
|||||||
|
|
||||||
|
|
||||||
# 都调整至100Hz采样率
|
# 都调整至100Hz采样率
|
||||||
target_fs = 100
|
target_fs = conf["target_fs"]
|
||||||
normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
|
normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
|
||||||
normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
|
normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
|
||||||
normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
|
normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
|
||||||
@ -90,8 +90,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
|
|||||||
|
|
||||||
spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
|
spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
|
||||||
spo2_fs=target_fs,
|
spo2_fs=target_fs,
|
||||||
max_fill_duration=30,
|
max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"],
|
||||||
min_gap_duration=10,)
|
min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"])
|
||||||
|
|
||||||
|
|
||||||
draw_tools.draw_psg_signal(
|
draw_tools.draw_psg_signal(
|
||||||
@ -135,7 +135,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
|
|||||||
"EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
|
"EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
|
||||||
"DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
|
"DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
|
||||||
|
|
||||||
spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95)
|
spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"])
|
||||||
|
|
||||||
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
|
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
|
||||||
if verbose:
|
if verbose:
|
||||||
|
|||||||
@ -210,6 +210,31 @@ def multiprocess_with_tqdm(args_list, n_processes):
|
|||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
|
|
||||||
|
def multiprocess_with_pool(args_list, n_processes):
|
||||||
|
"""使用Pool,每个worker处理固定数量任务后重启"""
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
# maxtasksperchild 设置每个worker处理多少任务后重启(释放内存)
|
||||||
|
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
|
||||||
|
results = []
|
||||||
|
for samp_id in args_list:
|
||||||
|
result = pool.apply_async(
|
||||||
|
build_HYS_dataset_segment,
|
||||||
|
args=(samp_id, False, True, False, None, None)
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
try:
|
||||||
|
result.get()
|
||||||
|
print(f"Completed {i + 1}/{len(args_list)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing task {i}: {e}")
|
||||||
|
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -249,4 +274,5 @@ if __name__ == '__main__':
|
|||||||
# print(f"Processing sample ID: {samp_id}")
|
# print(f"Processing sample ID: {samp_id}")
|
||||||
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||||
|
|
||||||
multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
|
# multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
|
||||||
|
multiprocess_with_pool(args_list=select_ids, n_processes=16)
|
||||||
@ -26,6 +26,8 @@ select_ids:
|
|||||||
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
|
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
|
||||||
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
|
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
|
||||||
|
|
||||||
|
target_fs: 100
|
||||||
|
|
||||||
dataset_config:
|
dataset_config:
|
||||||
window_sec: 180
|
window_sec: 180
|
||||||
stride_sec: 60
|
stride_sec: 60
|
||||||
@ -42,6 +44,9 @@ effort_filter:
|
|||||||
high_cut: 0.5
|
high_cut: 0.5
|
||||||
order: 3
|
order: 3
|
||||||
|
|
||||||
|
average_filter:
|
||||||
|
window_size_sec: 20
|
||||||
|
|
||||||
flow:
|
flow:
|
||||||
downsample_fs: 10
|
downsample_fs: 10
|
||||||
|
|
||||||
@ -51,6 +56,11 @@ flow_filter:
|
|||||||
high_cut: 0.5
|
high_cut: 0.5
|
||||||
order: 3
|
order: 3
|
||||||
|
|
||||||
|
spo2_fill__anomaly:
|
||||||
|
max_fill_duration: 30
|
||||||
|
min_gap_duration: 10
|
||||||
|
nan_to_num_value: 95
|
||||||
|
|
||||||
#ecg:
|
#ecg:
|
||||||
# downsample_fs: 100
|
# downsample_fs: 100
|
||||||
#
|
#
|
||||||
|
|||||||
41
dataset_config/SHHS1_config.yaml
Normal file
41
dataset_config/SHHS1_config.yaml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
root_path: /mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1
|
||||||
|
mask_save_path: /mnt/disk_code/marques/dataprepare/output/shhs1
|
||||||
|
|
||||||
|
effort_target_fs: 10
|
||||||
|
ecg_target_fs: 100
|
||||||
|
|
||||||
|
|
||||||
|
dataset_config:
|
||||||
|
window_sec: 180
|
||||||
|
stride_sec: 60
|
||||||
|
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset
|
||||||
|
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset/visualization
|
||||||
|
|
||||||
|
|
||||||
|
effort:
|
||||||
|
downsample_fs: 10
|
||||||
|
|
||||||
|
effort_filter:
|
||||||
|
filter_type: bandpass
|
||||||
|
low_cut: 0.05
|
||||||
|
high_cut: 0.5
|
||||||
|
order: 3
|
||||||
|
|
||||||
|
average_filter:
|
||||||
|
window_size_sec: 20
|
||||||
|
|
||||||
|
flow:
|
||||||
|
downsample_fs: 10
|
||||||
|
|
||||||
|
flow_filter:
|
||||||
|
filter_type: bandpass
|
||||||
|
low_cut: 0.05
|
||||||
|
high_cut: 0.5
|
||||||
|
order: 3
|
||||||
|
|
||||||
|
spo2_fill__anomaly:
|
||||||
|
max_fill_duration: 30
|
||||||
|
min_gap_duration: 10
|
||||||
|
nan_to_num_value: 95
|
||||||
|
|
||||||
|
|
||||||
100
dataset_tools/shhs_annotations_check.py
Normal file
100
dataset_tools/shhs_annotations_check.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from lxml import etree
|
||||||
|
from tqdm import tqdm
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 设定目标文件夹路径,你可以修改这里的路径,或者运行脚本时手动输入
|
||||||
|
# 默认为当前目录 '.'
|
||||||
|
# target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1"
|
||||||
|
target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2"
|
||||||
|
|
||||||
|
folder_path = Path(target_dir)
|
||||||
|
|
||||||
|
if not folder_path.exists():
|
||||||
|
print(f"错误: 路径 '{folder_path}' 不存在。")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. 获取所有 XML 文件 (扁平结构,不递归子目录)
|
||||||
|
xml_files = list(folder_path.glob("*.xml"))
|
||||||
|
total_files = len(xml_files)
|
||||||
|
|
||||||
|
if total_files == 0:
|
||||||
|
print(f"在 '{folder_path}' 中没有找到 XML 文件。")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"找到 {total_files} 个 XML 文件,准备开始处理...")
|
||||||
|
|
||||||
|
# 用于统计 (EventType, EventConcept) 组合的计数器
|
||||||
|
stats_counter = Counter()
|
||||||
|
|
||||||
|
# 2. 遍历文件,使用 tqdm 显示进度条
|
||||||
|
for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"):
|
||||||
|
try:
|
||||||
|
# 使用 lxml 解析
|
||||||
|
tree = etree.parse(str(xml_file))
|
||||||
|
root = tree.getroot()
|
||||||
|
|
||||||
|
# 3. 定位到 ScoredEvent 节点
|
||||||
|
# SHHS XML 结构通常是: PSGAnnotation -> ScoredEvents -> ScoredEvent
|
||||||
|
# 我们直接查找所有的 ScoredEvent 节点
|
||||||
|
events = root.findall(".//ScoredEvent")
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
# 提取 EventType
|
||||||
|
type_node = event.find("EventType")
|
||||||
|
# 处理节点不存在或文本为空的情况
|
||||||
|
e_type = type_node.text.strip() if (type_node is not None and type_node.text) else "N/A"
|
||||||
|
|
||||||
|
# 提取 EventConcept
|
||||||
|
concept_node = event.find("EventConcept")
|
||||||
|
e_concept = concept_node.text.strip() if (concept_node is not None and concept_node.text) else "N/A"
|
||||||
|
|
||||||
|
# 4. 组合并计数
|
||||||
|
# 组合键为元组 (EventType, EventConcept)
|
||||||
|
key = (e_type, e_concept)
|
||||||
|
stats_counter[key] += 1
|
||||||
|
|
||||||
|
except etree.XMLSyntaxError:
|
||||||
|
print(f"\n[警告] 文件格式错误,跳过: {xml_file.name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n[错误] 处理文件 {xml_file.name} 时出错: {e}")
|
||||||
|
|
||||||
|
# 5. 打印结果到终端
|
||||||
|
if stats_counter:
|
||||||
|
# --- 动态计算列宽 ---
|
||||||
|
# 获取所有 EventType 的最大长度,默认长度 9
|
||||||
|
max_type_width = max((len(k[0]) for k in stats_counter.keys()), default=9)
|
||||||
|
max_type_width = max(max_type_width, 9)
|
||||||
|
|
||||||
|
# 获取所有 EventConcept 的最大长度,默认长度 12
|
||||||
|
max_conc_width = max((len(k[1]) for k in stats_counter.keys()), default=12)
|
||||||
|
max_conc_width = max(max_conc_width, 12)
|
||||||
|
|
||||||
|
# 计算表格总宽度
|
||||||
|
total_line_width = max_type_width + max_conc_width + 10 + 6
|
||||||
|
|
||||||
|
print("\n" + "=" * total_line_width)
|
||||||
|
print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}")
|
||||||
|
print("-" * total_line_width)
|
||||||
|
|
||||||
|
# --- 修改处:按名称排序 ---
|
||||||
|
# sorted() 默认会对元组 (EventType, EventConcept) 进行字典序排序
|
||||||
|
# 即先按 EventType A-Z 排序,再按 EventConcept A-Z 排序
|
||||||
|
for (e_type, e_concept), count in sorted(stats_counter.items()):
|
||||||
|
print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}")
|
||||||
|
|
||||||
|
print("=" * total_line_width)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("\n未提取到任何事件数据。")
|
||||||
|
|
||||||
|
print("=" * 90)
|
||||||
|
print(f"统计完成。共扫描 {total_files} 个文件。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
42
event_mask_process/SHHS1_process.py
Normal file
42
event_mask_process/SHHS1_process.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
|
project_root_path = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import draw_tools
|
||||||
|
import utils
|
||||||
|
import numpy as np
|
||||||
|
import signal_method
|
||||||
|
import os
|
||||||
|
import mne
|
||||||
|
from tqdm import tqdm
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 获取分期和事件标签,以及不可用区间
|
||||||
|
|
||||||
|
def process_one_signal(samp_id, show=False):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
yaml_path = project_root_path / "dataset_config/SHHS1_config.yaml"
|
||||||
|
|
||||||
|
conf = utils.load_dataset_conf(yaml_path)
|
||||||
|
|
||||||
|
root_path = Path(conf["root_path"])
|
||||||
|
save_path = Path(conf["mask_save_path"])
|
||||||
|
|
||||||
|
print(f"root_path: {root_path}")
|
||||||
|
print(f"save_path: {save_path}")
|
||||||
|
|
||||||
|
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||||
|
label_root_path = root_path / "Label"
|
||||||
|
|
||||||
62
signal_method/shhs_tools.py
Normal file
62
signal_method/shhs_tools.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
ANNOTATION_MAP = {
|
||||||
|
"Wake|0": 0,
|
||||||
|
"Stage 1 sleep|1": 1,
|
||||||
|
"Stage 2 sleep|2": 2,
|
||||||
|
"Stage 3 sleep|3": 3,
|
||||||
|
"Stage 4 sleep|4": 4,
|
||||||
|
"REM sleep|5": 5,
|
||||||
|
"Unscored|9": 9,
|
||||||
|
"Movement|6": 6
|
||||||
|
}
|
||||||
|
|
||||||
|
SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']
|
||||||
|
|
||||||
|
def parse_sleep_annotations(annotation_path):
|
||||||
|
"""解析睡眠分期注释"""
|
||||||
|
try:
|
||||||
|
tree = ET.parse(annotation_path)
|
||||||
|
root = tree.getroot()
|
||||||
|
events = []
|
||||||
|
for scored_event in root.findall('.//ScoredEvent'):
|
||||||
|
event_type = scored_event.find('EventType').text
|
||||||
|
if event_type != "Stages|Stages":
|
||||||
|
continue
|
||||||
|
description = scored_event.find('EventConcept').text
|
||||||
|
start = float(scored_event.find('Start').text)
|
||||||
|
duration = float(scored_event.find('Duration').text)
|
||||||
|
if description not in ANNOTATION_MAP:
|
||||||
|
continue
|
||||||
|
events.append({
|
||||||
|
'onset': start,
|
||||||
|
'duration': duration,
|
||||||
|
'description': description,
|
||||||
|
'stage': ANNOTATION_MAP[description]
|
||||||
|
})
|
||||||
|
return events
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_osa_events(annotation_path):
|
||||||
|
"""提取睡眠呼吸暂停事件"""
|
||||||
|
try:
|
||||||
|
tree = ET.parse(annotation_path)
|
||||||
|
root = tree.getroot()
|
||||||
|
events = []
|
||||||
|
for scored_event in root.findall('.//ScoredEvent'):
|
||||||
|
event_concept = scored_event.find('EventConcept').text
|
||||||
|
event_type = event_concept.split('|')[0].strip()
|
||||||
|
if event_type in SA_EVENTS:
|
||||||
|
start = float(scored_event.find('Start').text)
|
||||||
|
duration = float(scored_event.find('Duration').text)
|
||||||
|
if duration >= 10:
|
||||||
|
events.append({
|
||||||
|
'start': start,
|
||||||
|
'duration': duration,
|
||||||
|
'end': start + duration,
|
||||||
|
'type': event_type
|
||||||
|
})
|
||||||
|
return events
|
||||||
|
except Exception as e:
|
||||||
|
return []
|
||||||
@ -52,7 +52,7 @@ def psg_effort_filter(conf, effort_data_raw, effort_fs):
|
|||||||
high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
|
high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
|
||||||
sample_rate=effort_fs)
|
sample_rate=effort_fs)
|
||||||
# 移动平均
|
# 移动平均
|
||||||
effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20)
|
effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=conf["average_filter"]["window_size_sec"])
|
||||||
return effort_data_raw, effort_data_2, effort_fs
|
return effort_data_raw, effort_data_2, effort_fs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user