DataPrepare/dataset_tools/shhs_annotations_check.py

100 lines
3.8 KiB
Python

import argparse
import sys
from collections import Counter
from pathlib import Path

from lxml import etree
from tqdm import tqdm


# Default annotation folder (SHHS visit 2). Pass another folder on the command
# line to override it, e.g. ".../annotations-events-nsrr/shhs1" for visit 1.
DEFAULT_TARGET_DIR = (
    "/mnt/disk_wd/marques_dataset/shhs/polysomnography/"
    "annotations-events-nsrr/shhs2"
)


def _child_text(element, tag, default="N/A"):
    """Return the stripped text of the first direct child ``tag`` of ``element``.

    Returns ``default`` when the child is missing or its text is empty or
    whitespace-only, so blank labels never show up as empty table cells.
    """
    text = element.findtext(tag)
    if text is None:
        return default
    return text.strip() or default


def _count_events(xml_files):
    """Count (EventType, EventConcept) pairs over every ScoredEvent in ``xml_files``.

    Files that fail to parse are reported on stderr and skipped, so one bad
    file does not abort the whole scan.

    Returns:
        A ``(Counter, int)`` tuple: the pair counts and the number of skipped files.
    """
    stats_counter = Counter()
    failed = 0
    for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"):
        try:
            root = etree.parse(str(xml_file)).getroot()
        except etree.XMLSyntaxError:
            failed += 1
            # tqdm.write prints above the progress bar instead of garbling it.
            tqdm.write(f"[警告] 文件格式错误,跳过: {xml_file.name}", file=sys.stderr)
            continue
        except Exception as e:  # deliberate best effort: report and keep scanning
            failed += 1
            tqdm.write(f"[错误] 处理文件 {xml_file.name} 时出错: {e}", file=sys.stderr)
            continue
        # SHHS XML is typically PSGAnnotation -> ScoredEvents -> ScoredEvent;
        # search at any depth instead of relying on that exact nesting.
        for event in root.iterfind(".//ScoredEvent"):
            key = (_child_text(event, "EventType"), _child_text(event, "EventConcept"))
            stats_counter[key] += 1
    return stats_counter, failed


def _format_table(stats_counter):
    """Render ``stats_counter`` as fixed-width table lines sorted by (EventType, EventConcept)."""
    type_header, concept_header, count_header = "EventType", "EventConcept", "Count"
    count_width = 10
    # Columns grow to fit the longest value but never shrink below their header.
    type_width = max([len(type_header), *(len(t) for t, _ in stats_counter)])
    concept_width = max([len(concept_header), *(len(c) for _, c in stats_counter)])
    # 6 = two " | " separators of 3 characters each.
    line_width = type_width + concept_width + count_width + 6

    def row(first, second, third):
        return f"{first:<{type_width}} | {second:<{concept_width}} | {third:>{count_width}}"

    lines = ["=" * line_width, row(type_header, concept_header, count_header), "-" * line_width]
    # Tuple keys sort lexicographically: EventType A-Z first, then EventConcept A-Z.
    lines.extend(row(t, c, n) for (t, c), n in sorted(stats_counter.items()))
    lines.append("=" * line_width)
    return lines


def main(argv=None):
    """Scan a folder of SHHS annotation XML files and print event-type statistics.

    Args:
        argv: Command-line arguments without the program name; ``None`` lets
            argparse read ``sys.argv[1:]``. The single optional argument is the
            folder to scan, defaulting to ``DEFAULT_TARGET_DIR``.
    """
    parser = argparse.ArgumentParser(
        description="Count (EventType, EventConcept) pairs across the ScoredEvent "
        "nodes of every *.xml file directly inside a folder (non-recursive)."
    )
    parser.add_argument(
        "target_dir",
        nargs="?",
        default=DEFAULT_TARGET_DIR,
        help="folder containing annotation XML files (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    folder_path = Path(args.target_dir)
    if not folder_path.exists():
        print(f"错误: 路径 '{folder_path}' 不存在。", file=sys.stderr)
        return

    # Flat layout: only *.xml directly in the folder, no recursion. Sorted so the
    # processing order (and any warnings) is reproducible across runs.
    xml_files = sorted(folder_path.glob("*.xml"))
    total_files = len(xml_files)
    if total_files == 0:
        print(f"'{folder_path}' 中没有找到 XML 文件。")
        return
    print(f"找到 {total_files} 个 XML 文件,准备开始处理...")

    stats_counter, failed = _count_events(xml_files)

    if stats_counter:
        print("\n" + "\n".join(_format_table(stats_counter)))
    else:
        print("\n未提取到任何事件数据。")
        print("=" * 90)
    print(f"统计完成。共扫描 {total_files} 个文件。")
    if failed:
        print(f"其中 {failed} 个文件因错误被跳过。")
if __name__ == "__main__":
    # Only run the scan when executed as a script; importing has no side effects.
    main()