添加多进程处理功能,重构数据处理逻辑,更新配置文件以支持新参数

This commit is contained in:
marques 2026-01-26 14:03:37 +08:00
parent 097c9cbf0b
commit 92e26425f0
8 changed files with 287 additions and 6 deletions

View File

@ -63,7 +63,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
# 都调整至100Hz采样率 # 都调整至100Hz采样率
target_fs = 100 target_fs = conf["target_fs"]
normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs) normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs) normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs) normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
@ -90,8 +90,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt, spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
spo2_fs=target_fs, spo2_fs=target_fs,
max_fill_duration=30, max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"],
min_gap_duration=10,) min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"])
draw_tools.draw_psg_signal( draw_tools.draw_psg_signal(
@ -135,7 +135,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
"EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]), "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
"DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])} "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95) spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"])
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
if verbose: if verbose:

View File

@ -210,6 +210,31 @@ def multiprocess_with_tqdm(args_list, n_processes):
future.result() future.result()
def multiprocess_with_pool(args_list, n_processes, maxtasksperchild=2):
    """Process samples with a multiprocessing Pool whose workers restart periodically.

    Each worker is recycled after ``maxtasksperchild`` tasks so memory held by
    long-running workers is released back to the OS.

    Args:
        args_list: iterable of sample ids, each passed to build_HYS_dataset_segment.
        n_processes: number of worker processes.
        maxtasksperchild: tasks a worker handles before being restarted
            (default 2, the value previously hard-coded here).
    """
    from multiprocessing import Pool

    with Pool(processes=n_processes, maxtasksperchild=maxtasksperchild) as pool:
        # Submit every task up front so workers stay busy, then collect in order.
        async_results = [
            pool.apply_async(
                build_HYS_dataset_segment,
                # Positional call kept from the original:
                # (samp_id, show=False, draw_segment=True, verbose=False, ?, ?)
                # NOTE(review): last two positional args are opaque here — confirm
                # against build_HYS_dataset_segment's signature.
                args=(samp_id, False, True, False, None, None),
            )
            for samp_id in args_list
        ]
        for idx, async_result in enumerate(async_results):
            try:
                async_result.get()
                print(f"Completed {idx + 1}/{len(async_results)}")
            except Exception as exc:
                # One failed sample must not abort the whole batch.
                print(f"Error processing task {idx}: {exc}")
        pool.close()
        pool.join()
if __name__ == '__main__': if __name__ == '__main__':
@ -249,4 +274,5 @@ if __name__ == '__main__':
# print(f"Processing sample ID: {samp_id}") # print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
multiprocess_with_tqdm(args_list=select_ids, n_processes=16) # multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
multiprocess_with_pool(args_list=select_ids, n_processes=16)

View File

@ -26,6 +26,8 @@ select_ids:
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
target_fs: 100
dataset_config: dataset_config:
window_sec: 180 window_sec: 180
stride_sec: 60 stride_sec: 60
@ -42,6 +44,9 @@ effort_filter:
high_cut: 0.5 high_cut: 0.5
order: 3 order: 3
average_filter:
window_size_sec: 20
flow: flow:
downsample_fs: 10 downsample_fs: 10
@ -51,6 +56,11 @@ flow_filter:
high_cut: 0.5 high_cut: 0.5
order: 3 order: 3
spo2_fill__anomaly:
max_fill_duration: 30
min_gap_duration: 10
nan_to_num_value: 95
#ecg: #ecg:
# downsample_fs: 100 # downsample_fs: 100
# #

View File

@ -0,0 +1,41 @@
# SHHS1 preprocessing configuration.
# NOTE(review): nesting reconstructed from a whitespace-mangled diff — verify
# indentation against the loader (utils.load_dataset_conf) before relying on it.
root_path: /mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1
mask_save_path: /mnt/disk_code/marques/dataprepare/output/shhs1
# Target sampling rates after resampling (Hz).
effort_target_fs: 10
ecg_target_fs: 100
dataset_config:
  window_sec: 180  # segment window length, seconds
  stride_sec: 60   # hop between consecutive windows, seconds
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset/visualization
effort:
  downsample_fs: 10
effort_filter:
  filter_type: bandpass
  low_cut: 0.05   # Hz
  high_cut: 0.5   # Hz
  order: 3
average_filter:
  window_size_sec: 20  # moving-average window, seconds
flow:
  downsample_fs: 10
flow_filter:
  filter_type: bandpass
  low_cut: 0.05
  high_cut: 0.5
  order: 3
# NOTE(review): key keeps the double underscore the code looks up
# (conf["spo2_fill__anomaly"]) — do not rename it here alone.
spo2_fill__anomaly:
  max_fill_duration: 30  # presumably seconds — confirm against fill_spo2_anomaly
  min_gap_duration: 10   # presumably seconds — confirm against fill_spo2_anomaly
  nan_to_num_value: 95   # SpO2 value substituted for remaining NaNs

View File

@ -0,0 +1,100 @@
import argparse
from pathlib import Path
from lxml import etree
from tqdm import tqdm
from collections import Counter
def main():
    """Scan a flat folder of SHHS NSRR annotation XMLs and print a count table
    of every (EventType, EventConcept) combination found."""
    # Target folder — edit here (or plug in argparse) to point elsewhere.
    # target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1"
    target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2"
    folder_path = Path(target_dir)
    if not folder_path.exists():
        print(f"错误: 路径 '{folder_path}' 不存在。")
        return

    # Flat scan only — subdirectories are deliberately not recursed into.
    xml_files = list(folder_path.glob("*.xml"))
    total_files = len(xml_files)
    if total_files == 0:
        print(f"'{folder_path}' 中没有找到 XML 文件。")
        return
    print(f"找到 {total_files} 个 XML 文件,准备开始处理...")

    # Counter keyed by the (EventType, EventConcept) tuple.
    stats_counter = Counter()

    for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"):
        try:
            tree = etree.parse(str(xml_file))
            # SHHS layout: PSGAnnotation -> ScoredEvents -> ScoredEvent;
            # search for ScoredEvent nodes anywhere to be safe.
            for event in tree.getroot().findall(".//ScoredEvent"):
                type_node = event.find("EventType")
                e_type = (
                    type_node.text.strip()
                    if (type_node is not None and type_node.text)
                    else "N/A"
                )
                concept_node = event.find("EventConcept")
                e_concept = (
                    concept_node.text.strip()
                    if (concept_node is not None and concept_node.text)
                    else "N/A"
                )
                stats_counter[(e_type, e_concept)] += 1
        except etree.XMLSyntaxError:
            print(f"\n[警告] 文件格式错误,跳过: {xml_file.name}")
        except Exception as e:
            print(f"\n[错误] 处理文件 {xml_file.name} 时出错: {e}")

    if stats_counter:
        # Size each column to its longest value, with floors matching the headers.
        max_type_width = max(max((len(k[0]) for k in stats_counter), default=9), 9)
        max_conc_width = max(max((len(k[1]) for k in stats_counter), default=12), 12)
        total_line_width = max_type_width + max_conc_width + 10 + 6

        print("\n" + "=" * total_line_width)
        print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}")
        print("-" * total_line_width)
        # Tuples sort lexicographically: EventType A-Z first, then EventConcept A-Z.
        for (e_type, e_concept), count in sorted(stats_counter.items()):
            print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}")
        print("=" * total_line_width)
    else:
        print("\n未提取到任何事件数据。")
    print("=" * 90)
    print(f"统计完成。共扫描 {total_files} 个文件。")

View File

@ -0,0 +1,42 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
import mne
from tqdm import tqdm
import xml.etree.ElementTree as ET
import re
# 获取分期和事件标签,以及不可用区间
def process_one_signal(samp_id, show=False):
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/SHHS1_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
root_path = Path(conf["root_path"])
save_path = Path(conf["mask_save_path"])
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
label_root_path = root_path / "Label"

View File

@ -0,0 +1,62 @@
import xml.etree.ElementTree as ET
# Mapping from NSRR stage annotation strings ("<label>|<code>") to integer stage codes.
ANNOTATION_MAP = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 4,
    "REM sleep|5": 5,
    "Unscored|9": 9,
    "Movement|6": 6,
}

# Sleep-apnea event names (the part of EventConcept before the '|').
SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']


def parse_sleep_annotations(annotation_path):
    """Parse sleep-stage annotations from an NSRR annotation XML file.

    Args:
        annotation_path: path to the annotation XML.

    Returns:
        A list of dicts with keys 'onset', 'duration', 'description', 'stage'
        for every recognized stage event, or None if the file cannot be parsed.
    """
    try:
        tree = ET.parse(annotation_path)
    except (ET.ParseError, OSError):
        # Unreadable / malformed file: preserve the original "None on failure" contract.
        return None

    events = []
    for scored_event in tree.getroot().findall('.//ScoredEvent'):
        # Robustness fix: one malformed event no longer aborts the whole file —
        # the original blanket `except Exception` returned None for the entire
        # recording on any missing node (AttributeError) or bad number.
        type_node = scored_event.find('EventType')
        if type_node is None or type_node.text != "Stages|Stages":
            continue
        concept_node = scored_event.find('EventConcept')
        start_node = scored_event.find('Start')
        duration_node = scored_event.find('Duration')
        if concept_node is None or start_node is None or duration_node is None:
            continue
        description = concept_node.text
        if description not in ANNOTATION_MAP:
            continue
        try:
            start = float(start_node.text)
            duration = float(duration_node.text)
        except (TypeError, ValueError):
            continue
        events.append({
            'onset': start,
            'duration': duration,
            'description': description,
            'stage': ANNOTATION_MAP[description],
        })
    return events


def extract_osa_events(annotation_path):
    """Extract sleep-apnea respiratory events from an NSRR annotation XML file.

    Keeps only events whose EventConcept prefix (text before '|') is listed in
    SA_EVENTS and whose duration is at least 10 seconds.

    Args:
        annotation_path: path to the annotation XML.

    Returns:
        A list of dicts with keys 'start', 'duration', 'end', 'type';
        an empty list if the file cannot be parsed.
    """
    try:
        tree = ET.parse(annotation_path)
    except (ET.ParseError, OSError):
        # Preserve the original best-effort contract: failure -> empty list.
        return []

    events = []
    for scored_event in tree.getroot().findall('.//ScoredEvent'):
        # Robustness fix: skip individual malformed events instead of silently
        # discarding every event in the file (original behavior on any
        # AttributeError/ValueError inside the loop).
        concept_node = scored_event.find('EventConcept')
        if concept_node is None or not concept_node.text:
            continue
        event_type = concept_node.text.split('|')[0].strip()
        if event_type not in SA_EVENTS:
            continue
        start_node = scored_event.find('Start')
        duration_node = scored_event.find('Duration')
        if start_node is None or duration_node is None:
            continue
        try:
            start = float(start_node.text)
            duration = float(duration_node.text)
        except (TypeError, ValueError):
            continue
        # Keep only events of at least 10 s (standard scoring minimum here).
        if duration >= 10:
            events.append({
                'start': start,
                'duration': duration,
                'end': start + duration,
                'type': event_type,
            })
    return events

View File

@ -52,7 +52,7 @@ def psg_effort_filter(conf, effort_data_raw, effort_fs):
high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"], high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
sample_rate=effort_fs) sample_rate=effort_fs)
# 移动平均 # 移动平均
effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20) effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=conf["average_filter"]["window_size_sec"])
return effort_data_raw, effort_data_2, effort_fs return effort_data_raw, effort_data_2, effort_fs