commit 8ee598090668efba97359f1bff09295c2f148d9d Author: marques Date: Tue Mar 24 21:15:05 2026 +0800 feat: Add utility functions for signal processing and event mapping - Created a new module `utils/__init__.py` to consolidate utility imports. - Added `event_map.py` for mapping apnea event types to numerical values and colors. - Implemented various filtering functions in `filter_func.py`, including Butterworth, Bessel, downsampling, and notch filters. - Developed `operation_tools.py` for dataset configuration loading, event mask generation, and signal processing utilities. - Introduced `split_method.py` for segmenting data based on movement and amplitude criteria. - Added `statistics_metrics.py` for calculating amplitude metrics and generating confusion matrices. - Included a new Excel file for additional data storage. diff --git a/AGENT_BACKGROUND_CN.md b/AGENT_BACKGROUND_CN.md new file mode 100644 index 0000000..b7608f3 --- /dev/null +++ b/AGENT_BACKGROUND_CN.md @@ -0,0 +1,470 @@ +# 项目背景文档(供 Agent 快速上手) + +## 1. 文档目的 + +这份文档用于帮助新进入仓库的 agent 快速建立项目心智模型,重点回答下面几个问题: + +- 这个仓库在做什么 +- 核心处理链路是什么 +- 代码从哪里开始看 +- 输入数据长什么样,输出产物长什么样 +- 哪些模块已经可用,哪些模块还只是占位或半成品 + +--- + +## 2. 一句话理解项目 + +`DataPrepare` 是一个睡眠呼吸暂停数据预处理仓库,主要用于把原始的 BCG / PSG 信号与呼吸暂停标签整理成可训练的数据集。 + +当前主线可以概括为两步: + +1. 在 `event_mask_process/` 中把原始标注和规则检测结果整理成逐秒标签掩码。 +2. 在 `dataset_builder/` 中根据这些掩码切出固定长度窗口,并将信号和窗口索引保存为 `npz` 数据集。 + +--- + +## 3. 仓库总览 + +| 目录 | 作用 | 备注 | +| --- | --- | --- | +| `dataset_config/` | 各数据集的 YAML 配置 | 包含路径、采样率、阈值、切窗参数 | +| `event_mask_process/` | 生成逐秒标签掩码 | 主入口之一 | +| `dataset_builder/` | 将信号切成窗口并保存数据集 | 主入口之二 | +| `signal_method/` | 信号预处理、规则检测、特征计算 | 规则逻辑核心 | +| `utils/` | 读文件、标签转换、切窗、滤波、通用工具 | 底层支撑 | +| `draw_tools/` | 全夜信号图、分段图、统计图 | 调试与可视化 | +| `dataset_tools/` | 辅助脚本 | 包括配对拷贝、SHHS 标注检查 | +| `output/` | 中间结果样例 | 仓库内已有若干处理结果示例 | + +--- + +## 4. 项目主线流程 + +### 4.1 HYS(BCG 主线) + +对应文件: + +- `event_mask_process/HYS_process.py` +- `dataset_builder/HYS_dataset.py` + +处理过程: + +1. 
从 `root_path/OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt` 读取原始 BCG 同步信号。
数据与目录约定 + +### 5.1 外部数据目录 + +这个仓库本身不包含原始数据,真正的数据目录由 YAML 中的绝对路径指定,例如: + +- `/mnt/disk_wd/marques_dataset/DataCombine2023/HYS` +- `/mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1` + +因此在新环境运行时,第一件事通常不是改代码,而是先改 `dataset_config/*.yaml` 里的绝对路径。 + +### 5.2 HYS 数据目录约定 + +代码默认外部目录至少包含: + +- `OrgBCG_Aligned//OrgBCG_Sync_.txt` +- `PSG_Aligned//...` +- `Label//SA Label_corrected.csv` + +### 5.3 HYS_PSG 数据目录约定 + +代码默认 `PSG_Aligned//` 内文件命名符合下面的模式: + +- `Rpeak*.txt` +- `ECG_Sync*.txt` +- `Effort Tho*.txt` +- `Effort Abd*.txt` +- `Flow P*.txt` +- `Flow T*.txt` +- `SpO2*.txt` +- `5_class*.txt` +- `SA Label_Sync.csv` + +### 5.4 处理中间结果目录 + +仓库内 `output/` 保存的是“中间标签与图像结果”,不是最终训练数据集。 +最终数据集通常会写到 YAML 中 `dataset_save_path` 指向的外部目录。 + +--- + +## 6. 核心入口脚本 + +### 6.1 主入口 + +| 场景 | 脚本 | +| --- | --- | +| HYS 逐秒标签生成 | `event_mask_process/HYS_process.py` | +| HYS 数据集构建 | `dataset_builder/HYS_dataset.py` | +| HYS_PSG 逐秒标签生成 | `event_mask_process/HYS_PSG_process.py` | +| HYS_PSG 数据集构建 | `dataset_builder/HYS_PSG_dataset.py` | + +### 6.2 辅助脚本 + +| 脚本 | 作用 | +| --- | --- | +| `dataset_tools/resp_pair_copy.py` | 将 HYS/ZD5Y 的 BCG 与 PSG 原始文件拷贝为配对数据集 | +| `dataset_tools/shhs_annotations_check.py` | 统计 SHHS XML 标注中的事件类型组合 | +| `event_mask_process/SHHS1_process.py` | SHHS1 处理入口占位,目前未实现 | + +### 6.3 入口脚本的共同特点 + +- 大多数脚本没有命令行参数接口。 +- 配置文件路径通常直接写在 `if __name__ == '__main__':` 中。 +- 运行逻辑依赖 YAML 中的绝对路径。 + +所以如果 agent 要“跑脚本”,通常需要先确认: + +1. 当前机器是否存在对应外部数据目录。 +2. YAML 路径是否匹配当前环境。 + +--- + +## 7. 
关键模块说明 + +### 7.1 `utils/` + +#### `utils/HYS_FileReader.py` + +负责各种输入文件读取,是仓库最重要的底层模块之一: + +- `read_signal_txt`:读取单通道 txt,并根据文件名中的采样率推断 `fs` +- `read_label_csv`:读取人工修正的 HYS 标签 +- `read_raw_psg_label`:读取 PSG 原始同步标签 +- `read_disable_excel`:读取 `排除区间.xlsx` +- `read_mask_execl`:读取处理后的逐秒标签 CSV,并生成事件片段列表 +- `read_psg_channel`:按通道名读取 PSG 文件夹里的多通道数据 + +#### `utils/operation_tools.py` + +负责标签转换、片段提取和通用处理: + +- `load_dataset_conf`:读取 YAML +- `generate_event_mask`:把事件表转换成逐秒 `SA_Label` +- `generate_disable_mask`:把 Excel 中的排除区间转换成逐秒 `Disable_Label` +- `event_mask_2_list`:把 0/1 掩码转为 `[start, end]` 列表 +- `merge_short_gaps` / `remove_short_durations`:对逐秒掩码做时长后处理 +- `fill_spo2_anomaly`:修补 SpO2 异常段 + +#### `utils/split_method.py` + +真正的切窗规则在这里: + +- 默认按 `window_sec` / `stride_sec` 滑窗 +- 只在 `EnableSegment` 内生成可用窗口 +- 如果一个窗口中 `Resp_Movement_Label | Resp_LowAmp_Label` 超过窗口时长的 2/3,则该窗口转入 `disable_segment_list` + +#### `utils/filter_func.py` + +提供滤波和采样率处理: + +- Butterworth / Bessel +- 陷波 +- 整数倍降采样 +- 自动升降采样 +- 移动平均去趋势 + +### 7.2 `signal_method/` + +#### `signal_method/signal_process.py` + +负责信号预处理: + +- `signal_filter_split`:把原始 OrgBCG 信号拆成呼吸分量和 BCG 分量 +- `psg_effort_filter`:处理 PSG 努力带 / 流量信号 +- `rpeak2hr` / `rpeak2rri_interpolation`:由 R 峰生成 HR / RRI + +#### `signal_method/rule_base_event.py` + +规则检测主逻辑: + +- `detect_movement`:基于滑窗标准差和局部幅值比较检测体动 +- `movement_revise`:对体动掩码做二次修正 +- `detect_low_amplitude_signal`:检测低幅值 +- `position_based_sleep_recognition_v2/v3`:根据体动前后幅值变化标记姿势/幅值变化段 + +#### `signal_method/normalize_method.py` + +`normalize_resp_signal_by_segment` 会按“可用片段”做分段 z-score 标准化。 +HYS 中通常按 `Resp_AmpChange_Label` 的反向片段来归一化,目的是减少整夜幅值漂移带来的影响。 + +### 7.3 `draw_tools/` + +主要用于人工检查处理质量: + +- `draw_signal_with_mask`:画 HYS 全夜原始信号 + 规则掩码 +- `draw_psg_signal`:画 HYS_PSG 全夜 PSG 信号 +- `draw_psg_label` / `draw_psg_bcg_label`:按窗口导出分段图 + +### 7.4 `dataset_builder/` + +核心职责是把“整夜记录 + 逐秒掩码”转换成训练样本: + +- 保存处理后的多通道信号到 `npz` +- 保存窗口起止列表到 `npz` +- 把标签 CSV 一并拷贝到数据集目录 + +--- + +## 8. 
标签、通道和编码约定 + +### 8.1 呼吸暂停事件编码 + +定义于 `utils/event_map.py`: + +| 事件 | 编码 | +| --- | --- | +| `Hypopnea` | 1 | +| `Central apnea` | 2 | +| `Obstructive apnea` | 3 | +| `Mixed apnea` | 4 | + +### 8.2 PSG 通道编号映射 + +同样定义于 `utils/event_map.py`: + +| 编号 | 通道名 | +| --- | --- | +| 1 | `Rpeak` | +| 2 | `ECG_Sync` | +| 3 | `Effort Tho` | +| 4 | `Effort Abd` | +| 5 | `Flow P` | +| 6 | `Flow T` | +| 7 | `SpO2` | +| 8 | `5_class` | + +### 8.3 睡眠分期编码 + +`5_class` 在读取时会转成整数: + +| 分期 | 编码 | +| --- | --- | +| `N3` | 1 | +| `N2` | 2 | +| `N1` | 3 | +| `R` | 4 | +| `W` | 5 | + +### 8.4 逐秒标签 CSV 字段 + +`Processed_Labels.csv` 的核心字段为: + +- `Second` +- `SA_Label` +- `SA_Score` +- `Disable_Label` +- `Resp_LowAmp_Label` +- `Resp_Movement_Label` +- `Resp_AmpChange_Label` +- `BCG_LowAmp_Label` +- `BCG_Movement_Label` +- `BCG_AmpChange_Label` + +样例可见仓库内: + +- `output/HYS/220/220_Processed_Labels.csv` +- `output/HYS_PSG/220/220_Processed_Labels.csv` + +--- + +## 9. 主要配置文件 + +### 9.1 HYS + +`dataset_config/HYS_config.yaml` 主要控制: + +- 样本 ID 列表 +- 原始数据根目录 +- 中间标签保存目录 +- 呼吸 / BCG 的滤波和降采样参数 +- 低幅值、体动、幅值变化检测阈值 +- 数据集窗口长度和步长 + +### 9.2 HYS_PSG + +`dataset_config/HYS_PSG_config.yaml` 主要控制: + +- 样本 ID 列表 +- 目标统一采样率 `target_fs` +- 努力带 / 流量滤波参数 +- SpO2 异常填补参数 +- 数据集输出路径 + +### 9.3 其他配置 + +- `dataset_config/ZD5Y_config.yaml`:另一套 BCG 规则配置 +- `dataset_config/SHHS1_config.yaml`:SHHS1 预留配置 +- `dataset_config/RESP_PAIR_HYS_config.yaml`:HYS 配对原始数据拷贝 +- `dataset_config/RESP_PAIR_ZD5Y_config.yaml`:ZD5Y 配对原始数据拷贝 + +--- + +## 10. 
输出产物说明 + +### 10.1 中间输出 + +典型位置: + +- `output/HYS//_Processed_Labels.csv` +- `output/HYS//_Signal_Plots.png` +- `output/HYS_PSG//_Processed_Labels.csv` +- `output/HYS_PSG//_Signal_Plots.png` +- `output/HYS_PSG//_Signal_Plots_fill.png` + +### 10.2 最终数据集输出 + +由 `dataset_builder/*` 保存到 YAML 指定的外部目录,结构通常是: + +- `Signals/` +- `Segments_List/` +- `Labels/` + +HYS 的 `Signals/*.npz` 里主要保存: + +- `bcg_signal_notch` +- `bcg_signal` +- `resp_signal` + +HYS_PSG 的 `Signals/*.npz` 里主要保存: + +- `Effort Tho` +- `Effort Abd` +- `Effort` +- `Flow P` +- `Flow T` +- `SpO2` +- `HR` +- `RRI` +- `5_class` + +`Segments_List/*.npz` 中主要保存: + +- `segment_list` +- `disable_segment_list` + +--- + +## 11. 推荐阅读顺序 + +如果 agent 是第一次接触这个仓库,建议按下面顺序阅读: + +1. `dataset_config/HYS_config.yaml` +2. `event_mask_process/HYS_process.py` +3. `utils/operation_tools.py` +4. `utils/split_method.py` +5. `dataset_builder/HYS_dataset.py` +6. `signal_method/signal_process.py` +7. `signal_method/rule_base_event.py` +8. `utils/HYS_FileReader.py` + +如果要看 PSG 主线,再继续读: + +1. `dataset_config/HYS_PSG_config.yaml` +2. `event_mask_process/HYS_PSG_process.py` +3. `dataset_builder/HYS_PSG_dataset.py` +4. `draw_tools/draw_label.py` + +--- + +## 12. 常见修改入口 + +如果你要改不同类型的问题,可以优先从这些文件入手: + +| 需求 | 优先看哪里 | +| --- | --- | +| 调整阈值、采样率、窗口长度 | `dataset_config/*.yaml` | +| 修改逐秒标签生成逻辑 | `event_mask_process/*.py` | +| 修改体动 / 低幅值 / 幅值变化规则 | `signal_method/rule_base_event.py` | +| 修改滤波与重采样 | `signal_method/signal_process.py`、`utils/filter_func.py` | +| 修改切窗规则 | `utils/split_method.py` | +| 修改输入文件解析 | `utils/HYS_FileReader.py` | +| 修改事件编码或通道映射 | `utils/event_map.py` | +| 修改图像输出样式 | `draw_tools/*.py` | + +--- + +## 13. 当前实现状态与注意事项 + +下面这些点对 agent 很重要: + +1. 仓库没有依赖清单文件(如 `requirements.txt` / `pyproject.toml`),依赖需要从源码导入中反推,当前至少涉及 `numpy`、`pandas`、`scipy`、`matplotlib`、`seaborn`、`yaml`、`tqdm`、`rich`、`polars`、`lxml`、`mne`。 +2. 大部分脚本依赖绝对路径,迁移环境时优先修改 YAML。 +3. `event_mask_process/SHHS1_process.py` 目前基本为空,占位多于实现。 +4. 
`event_mask_process/HYS_PSG_process.py` 当前是“简化版标签生成”,核心只用了 SA 标签和睡眠分期,`Resp_*` / `BCG_*` 掩码还没有真正实现。 +5. `dataset_builder/HYS_dataset.py` 里虽然保留了分段可视化入口,但实际绘图调用被注释掉了,因此默认不会导出分段图。 +6. `utils/signal_process.py` 和 `utils/filter_func.py` 有部分重复实现;当前实际被 `utils/__init__.py` 导出并广泛使用的是 `filter_func.py` 中的版本。 +7. `README.md` 目前非常简略,真正的项目逻辑主要还是要靠源码理解。 + +--- + +## 14. 对 Agent 最有价值的结论 + +如果只能记住几件事,请记住下面这些: + +1. 这是一个“先做逐秒掩码,再做固定窗口切片”的数据准备仓库。 +2. HYS 主线比 HYS_PSG 更完整,规则检测主要服务于 HYS。 +3. 切窗逻辑集中在 `utils/split_method.py`,不是分散在各个 builder 里。 +4. 原始数据不在仓库里,仓库只是代码和部分中间结果样例。 +5. 大多数运行问题都不是代码逻辑错,而是路径、文件命名、采样率和外部数据结构不匹配。 + diff --git a/README.md b/README.md new file mode 100644 index 0000000..ae8277a --- /dev/null +++ b/README.md @@ -0,0 +1,191 @@ +# DataPrepare + +`DataPrepare` 是一个面向睡眠呼吸暂停数据的预处理仓库,用于把原始 BCG / PSG 信号与事件标签整理成可训练的数据集。 + +当前仓库的主线是: + +1. 生成逐秒标签掩码 +2. 按固定窗口切分数据集 +3. 输出可视化结果用于人工检查 + +更详细的项目背景说明见 [AGENT_BACKGROUND_CN.md](./AGENT_BACKGROUND_CN.md)。 + +## 项目在做什么 + +仓库主要服务两类数据: + +- `HYS`:以 `OrgBCG` 为核心,结合人工修正的呼吸暂停标签,提取呼吸分量、BCG 分量并生成逐秒可用性掩码 +- `HYS_PSG`:以 PSG 多通道信号为核心,整理胸腹带、流量、SpO2、RRI、睡眠分期与同步呼吸暂停标签 + +两条主线都采用同一个设计: + +1. 先把整夜记录整理成逐秒标签 +2. 再根据标签切出固定长度窗口 + +## 仓库结构 + +| 目录 | 作用 | +| --- | --- | +| `dataset_config/` | 数据集配置文件,包含路径、采样率、阈值、切窗参数 | +| `event_mask_process/` | 逐秒标签掩码生成脚本 | +| `dataset_builder/` | 数据集切片与保存脚本 | +| `signal_method/` | 信号预处理、规则检测、特征计算 | +| `utils/` | 文件读取、标签转换、切窗、滤波等通用工具 | +| `draw_tools/` | 全夜图、分段图和统计图绘制 | +| `dataset_tools/` | 辅助脚本,如配对数据拷贝、SHHS 标注检查 | +| `output/` | 仓库内的中间结果样例 | + +## 核心流程 + +### HYS + +对应脚本: + +- `event_mask_process/HYS_process.py` +- `dataset_builder/HYS_dataset.py` + +流程概要: + +1. 读取 `OrgBCG_Aligned//OrgBCG_Sync_*.txt` +2. 读取 `Label//SA Label_corrected.csv` +3. 对原始信号做陷波、呼吸分量提取和 BCG 分量提取 +4. 检测低幅值、体动、幅值变化,并结合 `排除区间.xlsx` 生成逐秒标签 +5. 保存 `Processed_Labels.csv` +6. 根据 `window_sec` 与 `stride_sec` 切分训练窗口并保存 `npz` + +### HYS_PSG + +对应脚本: + +- `event_mask_process/HYS_PSG_process.py` +- `dataset_builder/HYS_PSG_dataset.py` + +流程概要: + +1. 
读取 `PSG_Aligned/<id>/` 下的多通道信号
`signal_method/rule_base_event.py` + +## 当前状态与注意事项 + +- 仓库目前没有依赖清单文件,常见依赖包括 `numpy`、`pandas`、`scipy`、`matplotlib`、`seaborn`、`yaml`、`tqdm`、`rich`、`polars`、`lxml`、`mne` +- 大多数脚本没有命令行参数接口,配置文件路径直接写在脚本 `__main__` 中 +- `event_mask_process/SHHS1_process.py` 目前基本还是占位 +- `event_mask_process/HYS_PSG_process.py` 当前实现偏简化,`Resp_*` / `BCG_*` 掩码尚未真正展开 +- `output/` 里的文件更适合拿来理解格式与结果,不代表完整数据集 + +## 相关文档 + +- 详细项目背景:[AGENT_BACKGROUND_CN.md](./AGENT_BACKGROUND_CN.md) diff --git a/dataset_builder/HYS_PSG_dataset.py b/dataset_builder/HYS_PSG_dataset.py new file mode 100644 index 0000000..54e8d05 --- /dev/null +++ b/dataset_builder/HYS_PSG_dataset.py @@ -0,0 +1,395 @@ +import multiprocessing +import sys +from pathlib import Path + +import os +import numpy as np + +from utils import N2Chn + +os.environ['DISPLAY'] = "localhost:10.0" + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import utils +import signal_method +import draw_tools +import shutil +import gc + +def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) + + total_seconds = min( + psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak" + ) + for i in N2Chn.values(): + if i == "Rpeak": + continue + length = int(total_seconds * psg_data[i]["fs"]) + psg_data[i]["data"] = psg_data[i]["data"][:length] + psg_data[i]["length"] = length + psg_data[i]["second"] = total_seconds + + psg_data["HR"] = { + "name": "HR", + "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], + psg_data["Rpeak"]["fs"]), + "fs": psg_data["ECG_Sync"]["fs"], + "length": psg_data["ECG_Sync"]["length"], + "second": psg_data["ECG_Sync"]["second"] + } + # 预处理与滤波 + tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, 
effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"]) + abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"]) + flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"]) + flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"]) + + rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100) + + + mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") + if verbose: + print(f"mask_excel_path: {mask_excel_path}") + + event_mask, event_list = utils.read_mask_execl(mask_excel_path) + + enable_list = [[0, psg_data["Effort Tho"]["second"]]] + normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list) + normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list) + normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list) + normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list) + + + # 都调整至100Hz采样率 + target_fs = conf["target_fs"] + normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs) + normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs) + normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs) + normalized_flowt_signal = 
utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs) + spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs) + normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2 + rri = utils.adjust_sample_rate(rri, rri_fs, target_fs) + + # 调整至相同长度 + min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal), len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal) + ,len(rri)) + min_length = min_length - min_length % target_fs # 保证是整数秒 + normalized_tho_signal = normalized_tho_signal[:min_length] + normalized_abd_signal = normalized_abd_signal[:min_length] + normalized_flowp_signal = normalized_flowp_signal[:min_length] + normalized_flowt_signal = normalized_flowt_signal[:min_length] + spo2_data_filt = spo2_data_filt[:min_length] + normalized_effort_signal = normalized_effort_signal[:min_length] + rri = rri[:min_length] + + tho_second = min_length / target_fs + for i in event_mask.keys(): + event_mask[i] = event_mask[i][:int(tho_second)] + + spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt, + spo2_fs=target_fs, + max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"], + min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"]) + + + draw_tools.draw_psg_signal( + samp_id=samp_id, + tho_signal=normalized_tho_signal, + abd_signal=normalized_abd_signal, + flowp_signal=normalized_flowp_signal, + flowt_signal=normalized_flowt_signal, + spo2_signal=spo2_data_filt, + effort_signal=normalized_effort_signal, + rri_signal = rri, + fs=target_fs, + event_mask=event_mask["SA_Label"], + save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png", + show=show + ) + + draw_tools.draw_psg_signal( + samp_id=samp_id, + tho_signal=normalized_tho_signal, + abd_signal=normalized_abd_signal, + flowp_signal=normalized_flowp_signal, + flowt_signal=normalized_flowt_signal, + 
spo2_signal=spo2_data_filt_fill, + effort_signal=normalized_effort_signal, + rri_signal = rri, + fs=target_fs, + event_mask=event_mask["SA_Label"], + save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png", + show=show + ) + + spo2_disable_mask = spo2_disable_mask[::target_fs] + min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask)) + + if len(event_mask["Disable_Label"]) != len(spo2_disable_mask): + print(f"Warning: Data length mismatch! Truncating to {min_len}.") + event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len] + + event_list = { + "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]), + "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])} + + spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"]) + + segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) + if verbose: + print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") + + + # 复制mask到processed_Labels文件夹 + save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv" + shutil.copyfile(mask_excel_path, save_mask_excel_path) + + # 复制SA Label_corrected.csv到processed_Labels文件夹 + sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv") + if sa_label_corrected_path.exists(): + save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" + shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) + else: + if verbose: + print(f"Warning: {sa_label_corrected_path} does not exist.") + + # 保存处理后的信号和截取的片段列表 + save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" + save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz" + + # psg_data更新为处理后的信号 + # 用下划线替换键里面的空格 + psg_data = { + "Effort Tho": { + "name": "Effort_Tho", + 
"data": normalized_tho_signal, + "fs": target_fs, + "length": len(normalized_tho_signal), + "second": len(normalized_tho_signal) / target_fs + }, + "Effort Abd": { + "name": "Effort_Abd", + "data": normalized_abd_signal, + "fs": target_fs, + "length": len(normalized_abd_signal), + "second": len(normalized_abd_signal) / target_fs + }, + "Effort": { + "name": "Effort", + "data": normalized_effort_signal, + "fs": target_fs, + "length": len(normalized_effort_signal), + "second": len(normalized_effort_signal) / target_fs + }, + "Flow P": { + "name": "Flow_P", + "data": normalized_flowp_signal, + "fs": target_fs, + "length": len(normalized_flowp_signal), + "second": len(normalized_flowp_signal) / target_fs + }, + "Flow T": { + "name": "Flow_T", + "data": normalized_flowt_signal, + "fs": target_fs, + "length": len(normalized_flowt_signal), + "second": len(normalized_flowt_signal) / target_fs + }, + "SpO2": { + "name": "SpO2", + "data": spo2_data_filt_fill, + "fs": target_fs, + "length": len(spo2_data_filt_fill), + "second": len(spo2_data_filt_fill) / target_fs + }, + "HR": { + "name": "HR", + "data": psg_data["HR"]["data"], + "fs": psg_data["HR"]["fs"], + "length": psg_data["HR"]["length"], + "second": psg_data["HR"]["second"] + }, + "RRI": { + "name": "RRI", + "data": rri, + "fs": target_fs, + "length": len(rri), + "second": len(rri) / target_fs + }, + "5_class": { + "name": "Stage", + "data": psg_data["5_class"]["data"], + "fs": psg_data["5_class"]["fs"], + "length": psg_data["5_class"]["length"], + "second": psg_data["5_class"]["second"] + } + } + + np.savez_compressed(save_signal_path, **psg_data) + np.savez_compressed(save_segment_path, + segment_list=segment_list, + disable_segment_list=disable_segment_list) + if verbose: + print(f"Saved processed signals to: {save_signal_path}") + print(f"Saved segment list to: {save_segment_path}") + + if draw_segment: + total_len = len(segment_list) + len(disable_segment_list) + if verbose: + print(f"Drawing segments for sample 
ID {samp_id}, total segments (enable + disable): {total_len}") + + + draw_tools.draw_psg_label( + psg_data=psg_data, + psg_label=event_mask["SA_Label"], + segment_list=segment_list, + save_path=visual_path / f"{samp_id}" / "enable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + draw_tools.draw_psg_label( + psg_data=psg_data, + psg_label=event_mask["SA_Label"], + segment_list=disable_segment_list, + save_path=visual_path / f"{samp_id}" / "disable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + # 显式删除大型对象 + try: + del psg_data + del normalized_tho_signal, normalized_abd_signal + del normalized_flowp_signal, normalized_flowt_signal + del normalized_effort_signal + del spo2_data_filt, spo2_data_filt_fill + del rri + del event_mask, event_list + del segment_list, disable_segment_list + except: + pass + + # 强制垃圾回收 + gc.collect() + + + +def multiprocess_entry(_progress, task_id, _id): + build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) + + +def multiprocess_with_tqdm(args_list, n_processes): + from concurrent.futures import ProcessPoolExecutor + from rich import progress + + with progress.Progress( + "[progress.description]{task.description}", + progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + progress.MofNCompleteColumn(), + progress.TimeRemainingColumn(), + progress.TimeElapsedColumn(), + refresh_per_second=1, # bit slower updates + transient=False + ) as progress: + futures = [] + with multiprocessing.Manager() as manager: + _progress = manager.dict() + overall_progress_task = progress.add_task("[green]All jobs progress:") + with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor: + for i_args in range(len(args_list)): + args = args_list[i_args] + task_id = progress.add_task(f"task {i_args}", visible=True) + 
futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args])) + # monitor the progress: + while (n_finished := sum([future.done() for future in futures])) < len( + futures + ): + progress.update( + overall_progress_task, completed=n_finished, total=len(futures) + ) + for task_id, update_data in _progress.items(): + desc = update_data.get("desc", "") + # update the progress bar for this task: + progress.update( + task_id, + completed=update_data.get("progress", 0), + total=update_data.get("total", 0), + description=desc + ) + + # raise any errors: + for future in futures: + future.result() + + +def multiprocess_with_pool(args_list, n_processes): + """使用Pool,每个worker处理固定数量任务后重启""" + from multiprocessing import Pool + + # maxtasksperchild 设置每个worker处理多少任务后重启(释放内存) + with Pool(processes=n_processes, maxtasksperchild=2) as pool: + results = [] + for samp_id in args_list: + result = pool.apply_async( + build_HYS_dataset_segment, + args=(samp_id, False, True, False, None, None) + ) + results.append(result) + + # 等待所有任务完成 + for i, result in enumerate(results): + try: + result.get() + print(f"Completed {i + 1}/{len(args_list)}") + except Exception as e: + print(f"Error processing task {i}: {e}") + + pool.close() + pool.join() + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml" + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + mask_path = Path(conf["mask_save_path"]) + save_path = Path(conf["dataset_config"]["dataset_save_path"]) + visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) + dataset_config = conf["dataset_config"] + + visual_path.mkdir(parents=True, exist_ok=True) + + save_processed_signal_path = save_path / "Signals" + save_processed_signal_path.mkdir(parents=True, exist_ok=True) + + save_segment_list_path = save_path / "Segments_List" + save_segment_list_path.mkdir(parents=True, exist_ok=True) + + 
save_processed_label_path = save_path / "Labels" + save_processed_label_path.mkdir(parents=True, exist_ok=True) + + # print(f"select_ids: {select_ids}") + # print(f"root_path: {root_path}") + # print(f"save_path: {save_path}") + # print(f"visual_path: {visual_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + psg_signal_root_path = root_path / "PSG_Aligned" + print(select_ids) + + # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) + + # multiprocess_with_tqdm(args_list=select_ids, n_processes=8) + multiprocess_with_pool(args_list=select_ids, n_processes=8) \ No newline at end of file diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py new file mode 100644 index 0000000..a99c854 --- /dev/null +++ b/dataset_builder/HYS_dataset.py @@ -0,0 +1,233 @@ +import multiprocessing +import sys +from pathlib import Path + +import os +import numpy as np + +os.environ['DISPLAY'] = "localhost:10.0" + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import utils +import signal_method +import draw_tools +import shutil + + +def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + if verbose: + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") + if verbose: + print(f"mask_excel_path: {mask_excel_path}") + + event_mask, event_list = utils.read_mask_execl(mask_excel_path) + + bcg_signal_raw, signal_length, signal_fs, 
signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose) + + bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose) + normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"]) + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100) + bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs, + target_fs=100) + signal_fs = 100 + + if bcg_fs == 1000: + bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100) + bcg_fs = 100 + + # draw_tools.draw_signal_with_mask(samp_id=samp_id, + # signal_data=resp_signal, + # signal_fs=resp_fs, + # resp_data=normalized_resp_signal, + # resp_fs=resp_fs, + # bcg_data=bcg_signal, + # bcg_fs=bcg_fs, + # signal_disable_mask=event_mask["Disable_Label"], + # resp_low_amp_mask=event_mask["Resp_LowAmp_Label"], + # resp_movement_mask=event_mask["Resp_Movement_Label"], + # resp_change_mask=event_mask["Resp_AmpChange_Label"], + # resp_sa_mask=event_mask["SA_Label"], + # bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"], + # bcg_movement_mask=event_mask["BCG_Movement_Label"], + # bcg_change_mask=event_mask["BCG_AmpChange_Label"], + # show=show, + # save_path=None) + + segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) + if verbose: + print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") + + + # 复制mask到processed_Labels文件夹 + save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv" + shutil.copyfile(mask_excel_path, save_mask_excel_path) + + # 复制SA Label_corrected.csv到processed_Labels文件夹 + sa_label_corrected_path 
= Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv") + if sa_label_corrected_path.exists(): + save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" + shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) + else: + if verbose: + print(f"Warning: {sa_label_corrected_path} does not exist.") + + # 保存处理后的信号和截取的片段列表 + save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" + save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz" + + bcg_data = { + "bcg_signal_notch": { + "name": "BCG_Signal_Notch", + "data": bcg_signal_notch, + "fs": signal_fs, + "length": len(bcg_signal_notch), + "second": len(bcg_signal_notch) // signal_fs + }, + "bcg_signal":{ + "name": "BCG_Signal_Raw", + "data": bcg_signal, + "fs": bcg_fs, + "length": len(bcg_signal), + "second": len(bcg_signal) // bcg_fs + }, + "resp_signal": { + "name": "Resp_Signal", + "data": normalized_resp_signal, + "fs": resp_fs, + "length": len(normalized_resp_signal), + "second": len(normalized_resp_signal) // resp_fs + } + } + + np.savez_compressed(save_signal_path, **bcg_data) + np.savez_compressed(save_segment_path, + segment_list=segment_list, + disable_segment_list=disable_segment_list) + if verbose: + print(f"Saved processed signals to: {save_signal_path}") + print(f"Saved segment list to: {save_segment_path}") + + if draw_segment: + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) + psg_data["HR"] = { + "name": "HR", + "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]), + "fs": psg_data["ECG_Sync"]["fs"], + "length": psg_data["ECG_Sync"]["length"], + "second": psg_data["ECG_Sync"]["second"] + } + + + psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose) + psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, 
def multiprocess_entry(_progress, task_id, _id):
    """Worker entry point: process one sample id, reporting progress through
    the shared progress dict (_progress) under key task_id."""
    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False,
                              multi_p=_progress, multi_task_id=task_id)


def multiprocess_with_pool(args_list, n_processes):
    """Run build_HYS_dataset_segment over args_list with a process pool.

    Each worker is recycled after a fixed number of tasks (maxtasksperchild=2)
    so memory accumulated per task (figures, large arrays) is released.

    Args:
        args_list: iterable of sample ids to process.
        n_processes: number of worker processes.
    """
    from multiprocessing import Pool

    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        # Positional args mirror build_HYS_dataset_segment(samp_id, show,
        # draw_segment, verbose, multi_p, multi_task_id).
        results = [
            pool.apply_async(build_HYS_dataset_segment,
                             args=(samp_id, False, True, False, None, None))
            for samp_id in args_list
        ]

        # Wait for completion; report per-sample failures (including which
        # sample id failed) instead of aborting the whole batch.
        for i, (samp_id, result) in enumerate(zip(args_list, results)):
            try:
                result.get()
                print(f"Completed {i + 1}/{len(args_list)}")
            except Exception as e:
                print(f"Error processing task {i} (sample {samp_id}): {e}")

        # Graceful shutdown before the context manager's terminate() runs.
        pool.close()
        pool.join()
Path(conf["dataset_config"]["dataset_save_path"]) + visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) + dataset_config = conf["dataset_config"] + + visual_path.mkdir(parents=True, exist_ok=True) + + save_processed_signal_path = save_path / "Signals" + save_processed_signal_path.mkdir(parents=True, exist_ok=True) + + save_segment_list_path = save_path / "Segments_List" + save_segment_list_path.mkdir(parents=True, exist_ok=True) + + save_processed_label_path = save_path / "Labels" + save_processed_label_path.mkdir(parents=True, exist_ok=True) + + # print(f"select_ids: {select_ids}") + # print(f"root_path: {root_path}") + # print(f"save_path: {save_path}") + # print(f"visual_path: {visual_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + psg_signal_root_path = root_path / "PSG_Aligned" + print(select_ids) + + # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) + + # multiprocess_with_tqdm(args_list=select_ids, n_processes=16) + multiprocess_with_pool(args_list=select_ids, n_processes=16) \ No newline at end of file diff --git a/dataset_builder/__init__.py b/dataset_builder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset_config/HYS_PSG_config.yaml b/dataset_config/HYS_PSG_config.yaml new file mode 100644 index 0000000..6496a72 --- /dev/null +++ b/dataset_config/HYS_PSG_config.yaml @@ -0,0 +1,139 @@ +select_ids: + - 54 + - 88 + - 220 + - 221 + - 229 + - 282 + - 286 + - 541 + - 579 + - 582 + - 670 + - 671 + - 683 + - 684 + - 735 + - 933 + - 935 + - 950 + - 952 + - 960 + - 962 + - 967 + - 1302 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG + +target_fs: 100 + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: 
/mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization + + +effort: + downsample_fs: 10 + +effort_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +average_filter: + window_size_sec: 20 + +flow: + downsample_fs: 10 + +flow_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +spo2_fill__anomaly: + max_fill_duration: 30 + min_gap_duration: 10 + nan_to_num_value: 95 + +#ecg: +# downsample_fs: 100 +# +#ecg_filter: +# filter_type: bandpass +# low_cut: 0.5 +# high_cut: 40 +# order: 5 + + +#resp: +# downsample_fs_1: None +# downsample_fs_2: 10 +# +#resp_filter: +# filter_type: bandpass +# low_cut: 0.05 +# high_cut: 0.5 +# order: 3 +# +#resp_low_amp: +# window_size_sec: 30 +# stride_sec: +# amplitude_threshold: 3 +# merge_gap_sec: 60 +# min_duration_sec: 60 +# +#resp_movement: +# window_size_sec: 20 +# stride_sec: 1 +# std_median_multiplier: 4 +# compare_intervals_sec: +# - 60 +# - 120 +## - 180 +# interval_multiplier: 3 +# merge_gap_sec: 30 +# min_duration_sec: 1 +# +#resp_movement_revise: +# up_interval_multiplier: 3 +# down_interval_multiplier: 2 +# compare_intervals_sec: 30 +# merge_gap_sec: 10 +# min_duration_sec: 1 +# +#resp_amp_change: +# mav_calc_window_sec: 4 +# threshold_amplitude: 0.25 +# threshold_energy: 0.4 +# +# +#bcg: +# downsample_fs: 100 +# +#bcg_filter: +# filter_type: bandpass +# low_cut: 1 +# high_cut: 10 +# order: 10 +# +#bcg_low_amp: +# window_size_sec: 1 +# stride_sec: +# amplitude_threshold: 8 +# merge_gap_sec: 20 +# min_duration_sec: 3 +# +# +#bcg_movement: +# window_size_sec: 2 +# stride_sec: +# merge_gap_sec: 20 +# min_duration_sec: 4 + + diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml new file mode 100644 index 0000000..43544ea --- /dev/null +++ b/dataset_config/HYS_config.yaml @@ -0,0 +1,86 @@ +select_ids: + - 1302 + - 286 + - 950 + - 220 + - 229 + - 541 + - 582 + - 670 
+ - 684 + - 960 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS + +resp: + downsample_fs_1: 100 + downsample_fs_2: 10 + +resp_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.6 + order: 3 + +resp_low_amp: + window_size_sec: 30 + stride_sec: + amplitude_threshold: 3 + merge_gap_sec: 60 + min_duration_sec: 60 + +resp_movement: + window_size_sec: 20 + stride_sec: 1 + std_median_multiplier: 4 + compare_intervals_sec: + - 60 + - 120 +# - 180 + interval_multiplier: 3 + merge_gap_sec: 30 + min_duration_sec: 1 + +resp_movement_revise: + up_interval_multiplier: 3 + down_interval_multiplier: 2 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 + +resp_amp_change: + mav_calc_window_sec: 4 + threshold_amplitude: 0.25 + threshold_energy: 0.4 + + +bcg: + downsample_fs: 100 + +bcg_filter: + filter_type: bandpass + low_cut: 1 + high_cut: 10 + order: 10 + +bcg_low_amp: + window_size_sec: 1 + stride_sec: + amplitude_threshold: 8 + merge_gap_sec: 20 + min_duration_sec: 3 + + +bcg_movement: + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization diff --git a/dataset_config/RESP_PAIR_HYS_config.yaml b/dataset_config/RESP_PAIR_HYS_config.yaml new file mode 100644 index 0000000..f9bef0f --- /dev/null +++ b/dataset_config/RESP_PAIR_HYS_config.yaml @@ -0,0 +1,56 @@ +select_ids: + - 1000 + - 1004 + - 1006 + - 1009 + - 1010 + - 1300 + - 1301 + - 1302 + - 1308 + - 1314 + - 1354 + - 1374 + - 1378 + - 1478 + - 220 + - 221 + - 229 + - 282 + - 285 + - 286 + - 54 + - 541 + - 579 + - 582 + - 670 + - 671 + - 683 + - 684 + - 686 + - 703 + - 704 + - 726 + - 735 + - 736 + - 88 + - 893 + - 933 + - 935 + - 939 + - 950 + - 952 + - 954 + - 955 + - 956 + - 960 + - 
961 + - 962 + - 967 + - 969 + - 971 + - 972 + + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +pair_file_path: /mnt/disk_wd/marques_dataset/Resp_Pair_Dataset/HYS/Raw diff --git a/dataset_config/RESP_PAIR_ZD5Y_config.yaml b/dataset_config/RESP_PAIR_ZD5Y_config.yaml new file mode 100644 index 0000000..aaf0d16 --- /dev/null +++ b/dataset_config/RESP_PAIR_ZD5Y_config.yaml @@ -0,0 +1,32 @@ +select_ids: + - 3103 + - 3105 + - 3106 + - 3107 + - 3108 + - 3110 + - 3203 + - 3204 + - 3205 + - 3209 + - 3211 + - 3301 + - 3303 + - 3304 + - 3307 + - 3308 + - 3309 + - 3403 + - 3405 + - 3406 + - 3407 + - 3408 + - 3504 + + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y +pair_file_path: /mnt/disk_wd/marques_dataset/Resp_Pair_Dataset/ZD5Y/Raw + + + + diff --git a/dataset_config/SHHS1_config.yaml b/dataset_config/SHHS1_config.yaml new file mode 100644 index 0000000..0b57be4 --- /dev/null +++ b/dataset_config/SHHS1_config.yaml @@ -0,0 +1,41 @@ +root_path: /mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1 +mask_save_path: /mnt/disk_code/marques/dataprepare/output/shhs1 + +effort_target_fs: 10 +ecg_target_fs: 100 + + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset/visualization + + +effort: + downsample_fs: 10 + +effort_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +average_filter: + window_size_sec: 20 + +flow: + downsample_fs: 10 + +flow_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +spo2_fill__anomaly: + max_fill_duration: 30 + min_gap_duration: 10 + nan_to_num_value: 95 + + diff --git a/dataset_config/ZD5Y_config.yaml b/dataset_config/ZD5Y_config.yaml new file mode 100644 index 0000000..ff479fd --- /dev/null +++ b/dataset_config/ZD5Y_config.yaml @@ -0,0 +1,88 @@ +select_ids: + - 3103 + - 3105 + - 3106 + - 3107 + - 3108 + - 3110 + - 3203 
+ - 3204 + - 3205 + - 3207 + - 3208 + - 3209 + - 3212 + - 3301 + - 3303 + - 3307 + - 3403 + - 3504 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y +save_path: /mnt/disk_code/marques/dataprepare/output/ZD5Y + +resp: + downsample_fs_1: 100 + downsample_fs_2: 10 + +resp_filter: + filter_type: bandpass + low_cut: 0.01 + high_cut: 0.7 + order: 3 + +resp_low_amp: + window_size_sec: 30 + stride_sec: + amplitude_threshold: 3 + merge_gap_sec: 60 + min_duration_sec: 60 + +resp_movement: + window_size_sec: 20 + stride_sec: 1 + std_median_multiplier: 5 + compare_intervals_sec: + - 60 + - 120 +# - 180 + interval_multiplier: 3.5 + merge_gap_sec: 30 + min_duration_sec: 1 + +resp_movement_revise: + up_interval_multiplier: 3 + down_interval_multiplier: 2 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 + +resp_amp_change: + mav_calc_window_sec: 1 + threshold_amplitude: 0.25 + threshold_energy: 0.4 + + +bcg: + downsample_fs: 100 + +bcg_filter: + filter_type: bandpass + low_cut: 1 + high_cut: 10 + order: 10 + +bcg_low_amp: + window_size_sec: 1 + stride_sec: + amplitude_threshold: 8 + merge_gap_sec: 20 + min_duration_sec: 3 + + +bcg_movement: + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + diff --git a/dataset_tools/resp_pair_copy.py b/dataset_tools/resp_pair_copy.py new file mode 100644 index 0000000..2014a2f --- /dev/null +++ b/dataset_tools/resp_pair_copy.py @@ -0,0 +1,67 @@ +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent +import utils +import shutil + +def copy_one_resp_pair(one_id): + sync_type = "Sync" + + org_bcg_file_path = sync_bcg_path / f"{one_id}" + dest_bcg_file_path = pair_file_path / f"{one_id}" + dest_bcg_file_path.mkdir(parents=True, exist_ok=True) + if not list(org_bcg_file_path.glob("OrgBCG_Sync_*.txt")): + if not list(org_bcg_file_path.glob("OrgBCG_RoughCut_*.txt")): + print(f"No OrgBCG 
files found for ID {one_id}.") + return + else: + sync_type = "RoughCut" + print(f"Using RoughCut files for ID {one_id}.") + + + for file in org_bcg_file_path.glob(f"OrgBCG_{sync_type}_*.txt"): + shutil.copyfile(file, dest_bcg_file_path / f"{one_id}_{file.name}".replace("_RoughCut", "").replace("_Sync", "")) + psg_file_path = sync_psg_path / f"{one_id}" + dest_psg_file_path = pair_file_path / f"{one_id}" + dest_psg_file_path.mkdir(parents=True, exist_ok=True) + + # 检查上面的文件是否存在 + psg_file_patterns = [ + f"5_class_{sync_type}_*.txt", + f"Effort Abd_{sync_type}_*.txt", + f"Effort Tho_{sync_type}_*.txt", + f"Flow P_{sync_type}_*.txt", + f"Flow T_{sync_type}_*.txt", + f"SA Label_Sync.csv", + f"SpO2_{sync_type}_*.txt" + ] + for pattern in psg_file_patterns: + if not list(psg_file_path.glob(pattern)): + print(f"No PSG files found for ID {one_id} with pattern {pattern}.") + return + for pattern in psg_file_patterns: + for file in psg_file_path.glob(pattern): + shutil.copyfile(file, dest_psg_file_path / f"{one_id}_{file.name.replace('_RoughCut', '').replace('_Sync', '')}") + + + + + + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/RESP_PAIR_ZD5Y_config.yaml" + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + sync_bcg_path = root_path / "OrgBCG_Aligned" + sync_psg_path = root_path / "PSG_Aligned" + pair_file_path = Path(conf["pair_file_path"]) + + # copy_one_resp_pair(961) + + for samp_id in select_ids: + print(f"Processing {samp_id}...") + copy_one_resp_pair(samp_id) \ No newline at end of file diff --git a/dataset_tools/shhs_annotations_check.py b/dataset_tools/shhs_annotations_check.py new file mode 100644 index 0000000..40061c6 --- /dev/null +++ b/dataset_tools/shhs_annotations_check.py @@ -0,0 +1,100 @@ + +import argparse +from pathlib import Path +from lxml import etree +from tqdm import tqdm +from collections import Counter + + +def main(): + # 
设定目标文件夹路径,你可以修改这里的路径,或者运行脚本时手动输入 + # 默认为当前目录 '.' + # target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1" + target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2" + + folder_path = Path(target_dir) + + if not folder_path.exists(): + print(f"错误: 路径 '{folder_path}' 不存在。") + return + + # 1. 获取所有 XML 文件 (扁平结构,不递归子目录) + xml_files = list(folder_path.glob("*.xml")) + total_files = len(xml_files) + + if total_files == 0: + print(f"在 '{folder_path}' 中没有找到 XML 文件。") + return + + print(f"找到 {total_files} 个 XML 文件,准备开始处理...") + + # 用于统计 (EventType, EventConcept) 组合的计数器 + stats_counter = Counter() + + # 2. 遍历文件,使用 tqdm 显示进度条 + for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"): + try: + # 使用 lxml 解析 + tree = etree.parse(str(xml_file)) + root = tree.getroot() + + # 3. 定位到 ScoredEvent 节点 + # SHHS XML 结构通常是: PSGAnnotation -> ScoredEvents -> ScoredEvent + # 我们直接查找所有的 ScoredEvent 节点 + events = root.findall(".//ScoredEvent") + + for event in events: + # 提取 EventType + type_node = event.find("EventType") + # 处理节点不存在或文本为空的情况 + e_type = type_node.text.strip() if (type_node is not None and type_node.text) else "N/A" + + # 提取 EventConcept + concept_node = event.find("EventConcept") + e_concept = concept_node.text.strip() if (concept_node is not None and concept_node.text) else "N/A" + + # 4. 组合并计数 + # 组合键为元组 (EventType, EventConcept) + key = (e_type, e_concept) + stats_counter[key] += 1 + + except etree.XMLSyntaxError: + print(f"\n[警告] 文件格式错误,跳过: {xml_file.name}") + except Exception as e: + print(f"\n[错误] 处理文件 {xml_file.name} 时出错: {e}") + + # 5. 
打印结果到终端 + if stats_counter: + # --- 动态计算列宽 --- + # 获取所有 EventType 的最大长度,默认长度 9 + max_type_width = max((len(k[0]) for k in stats_counter.keys()), default=9) + max_type_width = max(max_type_width, 9) + + # 获取所有 EventConcept 的最大长度,默认长度 12 + max_conc_width = max((len(k[1]) for k in stats_counter.keys()), default=12) + max_conc_width = max(max_conc_width, 12) + + # 计算表格总宽度 + total_line_width = max_type_width + max_conc_width + 10 + 6 + + print("\n" + "=" * total_line_width) + print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}") + print("-" * total_line_width) + + # --- 修改处:按名称排序 --- + # sorted() 默认会对元组 (EventType, EventConcept) 进行字典序排序 + # 即先按 EventType A-Z 排序,再按 EventConcept A-Z 排序 + for (e_type, e_concept), count in sorted(stats_counter.items()): + print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}") + + print("=" * total_line_width) + + else: + print("\n未提取到任何事件数据。") + + print("=" * 90) + print(f"统计完成。共扫描 {total_files} 个文件。") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py new file mode 100644 index 0000000..ab3b0ae --- /dev/null +++ b/draw_tools/__init__.py @@ -0,0 +1,2 @@ +from .draw_statics import draw_signal_with_mask, draw_psg_signal +from .draw_label import draw_psg_bcg_label,draw_psg_label \ No newline at end of file diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py new file mode 100644 index 0000000..29daa5d --- /dev/null +++ b/draw_tools/draw_label.py @@ -0,0 +1,337 @@ +from matplotlib.axes import Axes +from matplotlib.gridspec import GridSpec +from matplotlib.colors import PowerNorm +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import seaborn as sns +import numpy as np +from tqdm.rich import tqdm +import utils +import gc +# 添加with_prediction参数 + +psg_bcg_chn_name2ax = { + "SpO2": 0, + "Flow T": 1, + "Flow P": 2, + "Effort Tho": 3, + "Effort Abd": 4, + "HR": 5, + 
"resp": 6, + "bcg": 7, + "Stage": 8, + "resp_twinx": 9, + "bcg_twinx": 10, +} + +psg_chn_name2ax = { + "SpO2": 0, + "Flow T": 1, + "Flow P": 2, + "Effort Tho": 3, + "Effort Abd": 4, + "Effort": 5, + "HR": 6, + "RRI": 7, + "Stage": 8, +} + + +resp_chn_name2ax = { + "resp": 0, + "bcg": 1, +} + + +def create_psg_bcg_figure(): + fig = plt.figure(figsize=(12, 8), dpi=200) + gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(9): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + + axes[psg_bcg_chn_name2ax["SpO2"]].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100)) + axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Flow T"]].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Flow P"]].grid(True) + # axes[2].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True) + # axes[3].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True) + # axes[4].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["HR"]].grid(True) + axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["resp"]].grid(True) + axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx()) + + axes[psg_bcg_chn_name2ax["bcg"]].grid(True) + # axes[5].xaxis.set_major_formatter(Params.FORMATTER) + 
def create_psg_figure():
    """Build the 9-row PSG overview figure.

    Returns (fig, axes) where axes are indexed via psg_chn_name2ax. All rows
    share equal height; every signal row gets a grid and white x tick labels
    (effectively hiding them), SpO2 is clamped to the 85-100 range, and the
    bottom Stage row keeps visible ticks.
    """
    fig = plt.figure(figsize=(12, 8), dpi=200)
    gs = GridSpec(9, 1, height_ratios=[1] * 9)
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(gs[row]) for row in range(9)]

    # Every non-Stage channel: grid on, x tick labels hidden (drawn white).
    hidden_tick_channels = ("SpO2", "Flow T", "Flow P", "Effort Tho",
                            "Effort Abd", "Effort", "HR", "RRI")
    for chn_name in hidden_tick_channels:
        chn_ax = axes[psg_chn_name2ax[chn_name]]
        chn_ax.grid(True)
        chn_ax.tick_params(axis='x', colors="white")

    # SpO2 default display range; callers may widen it for deep desaturations.
    axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))

    # Stage row keeps its x tick labels (it is the bottom axis).
    axes[psg_chn_name2ax["Stage"]].grid(True)

    return fig, axes
GridSpec(2, 1, height_ratios=[3, 2]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(2): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + + axes[0].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[0].tick_params(axis='x', colors="white") + + axes[1].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[1].tick_params(axis='x', colors="white") + + return fig, axes + + +def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None, + event_codes: list[int] = None, multi_labels=None, ax2: Axes = None): + signal_fs = signal_data["fs"] + chn_signal = signal_data["data"] + time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs) + ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black', + label=signal_data["name"]) + if event_mask is not None: + if multi_labels is None and event_codes is not None: + for event_code in event_codes: + mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code + y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float) + np.place(y, y == 0, np.nan) + ax.plot(time_axis, y, color=utils.ColorCycle[event_code]) + elif multi_labels == "resp" and event_codes is not None: + ax.set_ylim(-6, 6) + # 建立第二个y轴坐标 + ax2.cla() + ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, + color='blue', alpha=0.8, label='Low Amplitude Mask') + ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, + color='orange', alpha=0.8, label='Movement Mask') + ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3, + color='green', alpha=0.8, label='Amplitude Change Mask') + for event_code in event_codes: + sa_mask = 
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage trace for seconds [segment_start, segment_end) on ax.

    stage_data is a dict with "data" (per-sample stage codes), "fs" and "name".
    Stage codes 1..5 are labelled N3/N2/N1/REM/Awake on the y axis.
    """
    fs = stage_data["fs"]
    segment = stage_data["data"][segment_start * fs:segment_end * fs]
    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * fs)

    ax.plot(time_axis, segment, color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)
def score_mask2alpha(score_mask):
    """Map SA score codes to plotting alpha values.

    Score 1 -> 0.9, 2 -> 0.6, 3 -> 0.1; every other value (including 0 and
    negatives) maps to 0. Returns a float array shaped like score_mask.
    """
    alpha_mask = np.zeros_like(score_mask, dtype=float)
    for score, alpha in ((1, 0.9), (2, 0.6), (3, 0.1)):
        alpha_mask[score_mask == score] = alpha
    return alpha_mask
psg_data["Effort Abd"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], + ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], + ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]]) + + + if save_path is not None: + fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png") + # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + + if multi_p is not None: + multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"} + + plt.close(fig) + plt.close('all') + gc.collect() + + +def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True, + multi_p=None, multi_task_id=None): + if save_path is not None: + save_path.mkdir(parents=True, exist_ok=True) + + + if multi_p is None: + # 遍历psg_data中所有数据的长度 + for i in range(len(psg_data.keys())): + chn_name = list(psg_data.keys())[i] + print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}") + # psg_label的长度 + print(f"psg_label length: {len(psg_label)}") + + fig, axes = create_psg_figure() + for i, (segment_start, segment_end) in enumerate(segment_list): + for ax in axes: + ax.cla() + + plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end) + plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 
3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end) + + if save_path is not None: + fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png") + # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + + if multi_p is not None: + multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"} + plt.close(fig) + plt.close('all') + gc.collect() + + diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py new file mode 100644 index 0000000..06e39aa --- /dev/null +++ b/draw_tools/draw_statics.py @@ -0,0 +1,371 @@ +from matplotlib.axes import Axes +from matplotlib.gridspec import GridSpec +from matplotlib.colors import PowerNorm +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import seaborn as sns +import numpy as np + + +def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent, + valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''): + # 创建用于热图注释的文本矩阵 + text_matrix = np.empty((len(amp_labels), len(time_labels)), dtype=object) + percent_matrix = np.zeros((len(amp_labels), len(time_labels))) + + # 填充文本矩阵和百分比矩阵 + for i 
in range(len(amp_labels)): + for j in range(len(time_labels)): + val = confusion_matrix.iloc[i, j] + segment_count = segment_count_matrix[i, j] + percent = confusion_matrix_percent.iloc[i, j] + text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)" + percent_matrix[i, j] = percent + + # 绘制热图,调整颜色条位置 + sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='', + xticklabels=time_labels, yticklabels=amp_labels, + cmap='YlGnBu', ax=ax, vmin=0, vmax=100, + norm=PowerNorm(gamma=0.5, vmin=0, vmax=100), + cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15}, + # annot_kws={'fontsize': 12} + ) + + # 添加行统计(右侧) + row_sums = confusion_matrix['总计'] + row_percents = confusion_matrix_percent['总计'] + ax.text(len(time_labels) + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center') + for i, (val, perc) in enumerate(zip(row_sums, row_percents)): + ax.text(len(time_labels) + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)", + ha='center', va='center') + + # 添加列统计(底部) + col_sums = segment_count_matrix.sum(axis=0) + col_percents = confusion_matrix.sum(axis=0) / total_duration * 100 + ax.text(-1, len(amp_labels) + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0) + for j, (val, perc) in enumerate(zip(col_sums, col_percents)): + ax.text(j + 0.5, len(amp_labels) + 0.5, f"[{int(val)}]\n({perc:.2f}%)", + ha='center', va='center', rotation=0) + + # 将x轴坐标移到顶部 + ax.xaxis.set_label_position('top') + ax.xaxis.tick_top() + + # 设置标题和标签 + ax.set_title('幅值-时长统计矩阵', pad=40) + ax.set_xlabel('持续时间区间 (秒)', labelpad=10) + ax.set_ylabel('幅值区间') + + # 设置坐标轴标签水平显示 + ax.set_xticklabels(time_labels, rotation=0) + ax.set_yticklabels(amp_labels, rotation=0) + + # 调整颜色条位置 + cbar = sns_heatmap.collections[0].colorbar + cbar.ax.yaxis.set_label_position('right') + + # 添加图例说明 + ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)", + ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5)) + + # 总计 + # 总片段数 + total_segments = segment_count_matrix.sum() + # 
有效总市场占比 + total_percent = valid_signal_length / total_duration * 100 + ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)", + ha='center', va='center') + + + +def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values, + movement_position_list, low_amp_position_list, signal_second_length, aml_list=None): + # 绘制信号线 + ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7) + + # 添加红色和蓝色的axvspan区域 + for start, end in movement_position_list: + if start < len(original_times) and end < len(original_times): + ax.axvspan(start, end, color='red', alpha=0.3) + + for start, end in low_amp_position_list: + if start < len(original_times) and end < len(original_times): + ax.axvspan(start, end, color='blue', alpha=0.2) + + # 如果存在AML列表,绘制水平线 + if aml_list is not None: + color_map = ['red', 'orange', 'green'] + for i, aml in enumerate(aml_list): + ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed', linewidth=2, alpha=0.5, label=f'{aml} aml') + + + ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values, color='blue', linewidth=2, alpha=0.4, label='2sMAV') + + # 设置Y轴范围 + ax.set_ylim((-2000, 2000)) + + # 创建表示不同颜色区域的图例 + red_patch = mpatches.Patch(color='red', alpha=0.2, label='Movement Area') + blue_patch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area') + + # 添加新的图例项,并保留原来的图例项 + handles, labels = ax.get_legend_handles_labels() # 获取原有图例 + ax.legend(handles=[red_patch, blue_patch] + handles, + labels=['Movement Area', 'Low Amplitude Area'] + labels, + loc='upper right', + bbox_to_anchor=(1, 1.4), + framealpha=0.5) + + # 设置标题和标签 + ax.set_title(f'{signal_name} Signal') + ax.set_ylabel('Amplitude') + ax.set_xlabel('Time (s)') + + # 启用网格 + ax.grid(True, linestyle='--', alpha=0.7) + + +def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal, + 
bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list, + resp_movement_position_list, resp_low_amp_position_list, + bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info, + signal_info, show=False, save_path=None): + + # 创建图像 + fig = plt.figure(figsize=(18, 10)) + + gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5) + + signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate + bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal)) + resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal)) + # 子图 1:原始信号 + ax1 = fig.add_subplot(gs[0]) + draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal, + no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values, + movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list, + signal_second_length=signal_second_length, aml_list=[200, 500, 1000]) + + ax2 = fig.add_subplot(gs[1]) + param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent', + 'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels'] + params = dict(zip(param_names, bcg_statistic_info)) + params['ax'] = ax2 + draw_ax_confusion_matrix(**params) + + ax3 = fig.add_subplot(gs[2], sharex=ax1) + draw_ax_amp(ax=ax3, signal_name='RSEP', original_times=resp_origin_times, origin_signal=resp_origin_signal, + no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values, + movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list, + signal_second_length=signal_second_length, aml_list=[100, 300, 500]) + ax4 = fig.add_subplot(gs[3]) + params = dict(zip(param_names, resp_statistic_info)) + params['ax'] = ax4 + draw_ax_confusion_matrix(**params) + + # 全局标题 + fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95) + + if save_path is 
not None: + # 保存图像 + plt.savefig(save_path, dpi=300) + + if show: + plt.show() + + plt.close() + + +def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs, + signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask, + resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask, show=False, save_path=None + ): + # 第一行绘制去工频噪声的原始信号,右侧为不可用区间标记,左侧为信号幅值纵坐标 + # 第二行绘制呼吸分量,右侧低幅值、高幅值、幅值变换标记、SA标签,左侧为呼吸幅值纵坐标 + # 第三行绘制心冲击分量,右侧为低幅值、高幅值、幅值变换标记、,左侧为心冲击幅值纵坐标 + # mask为None,则生成全Nan掩码 + def _none_to_nan_mask(mask, ref): + if mask is None: + return np.full_like(ref, np.nan) + else: + # 将mask中的0替换为nan,其他的保持 + mask = np.where(mask == 0, np.nan, mask) + return mask + + signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data) + resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data) + resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data) + resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data) + resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data) + bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data) + bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data) + bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data) + + + fig = plt.figure(figsize=(18, 10)) + ax0 = fig.add_subplot(3, 1, 1) + ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + ax0.set_title(f'Sample {samp_id} - Raw Signal Data') + ax0.set_ylabel('Amplitude') + # ax0.set_xticklabels([]) + + ax0_twin = ax0.twinx() + ax0_twin.plot(np.linspace(0, len(signal_disable_mask), len(signal_disable_mask)), signal_disable_mask, + color='red', alpha=0.5) + ax0_twin.autoscale(enable=False, axis='y', tight=True) + ax0_twin.set_ylim((-2, 2)) + ax0_twin.set_ylabel('Disable Mask') + ax0_twin.set_yticks([0, 1]) + ax0_twin.set_yticklabels(['Enabled', 'Disabled']) + ax0_twin.grid(False) + ax0_twin.legend(['Disable Mask'], loc='upper right') 
+ + + ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='gray', alpha=0.5) + resp_data_no_movement = resp_data.copy() + resp_data_no_movement[resp_movement_mask.repeat(int(resp_fs)) == 1] = np.nan + ax1.plot(np.linspace(0, len(resp_data_no_movement) // resp_fs, len(resp_data_no_movement)), resp_data_no_movement, color='orange') + ax1.set_ylabel('Amplitude') + # ax1.set_xticklabels([]) + ax1_twin = ax1.twinx() + ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1, + color='blue', alpha=0.5, label='Low Amplitude Mask') + ax1_twin.plot(np.linspace(0, len(resp_movement_mask), len(resp_movement_mask)), resp_movement_mask*-2, + color='red', alpha=0.5, label='Movement Mask') + ax1_twin.plot(np.linspace(0, len(resp_change_mask), len(resp_change_mask)), resp_change_mask*-3, + color='green', alpha=0.5, label='Amplitude Change Mask') + ax1_twin.plot(np.linspace(0, len(resp_sa_mask), len(resp_sa_mask)), resp_sa_mask, + color='purple', alpha=0.5, label='SA Mask') + ax1_twin.autoscale(enable=False, axis='y', tight=True) + ax1_twin.set_ylim((-4, 5)) + # ax1_twin.set_ylabel('Resp Masks') + # ax1_twin.set_yticks([0, 1]) + # ax1_twin.set_yticklabels(['No', 'Yes']) + ax1_twin.grid(False) + + ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right') + ax1.set_title(f'Sample {samp_id} - Respiration Component') + + + + ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green') + ax2.set_ylabel('Amplitude') + ax2.set_xlabel('Time (s)') + ax2_twin = ax2.twinx() + ax2_twin.plot(np.linspace(0, len(bcg_low_amp_mask), len(bcg_low_amp_mask)), bcg_low_amp_mask*-1, + color='blue', alpha=0.5, label='Low Amplitude Mask') + ax2_twin.plot(np.linspace(0, len(bcg_movement_mask), len(bcg_movement_mask)), bcg_movement_mask*-2, + color='red', 
alpha=0.5, label='Movement Mask') + ax2_twin.plot(np.linspace(0, len(bcg_change_mask), len(bcg_change_mask)), bcg_change_mask*-3, + color='green', alpha=0.5, label='Amplitude Change Mask') + ax2_twin.autoscale(enable=False, axis='y', tight=True) + ax2_twin.set_ylim((-4, 2)) + ax2_twin.set_ylabel('BCG Masks') + # ax2_twin.set_yticks([0, 1]) + # ax2_twin.set_yticklabels(['No', 'Yes']) + ax2_twin.grid(False) + ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right') + # ax2.set_title(f'Sample {samp_id} - BCG Component') + + ax0_twin._lim_lock = False + ax1_twin._lim_lock = False + ax2_twin._lim_lock = False + + def on_lims_change(event_ax): + if getattr(event_ax, '_lim_lock', False): + return + try: + event_ax._lim_lock = True + + if event_ax == ax0_twin: + # 重新锁定 ax1 的 Y 轴范围 + ax0_twin.set_ylim(-2, 2) + elif event_ax == ax1_twin: + ax1_twin.set_ylim(-4, 5) + elif event_ax == ax2_twin: + ax2_twin.set_ylim(-4, 2) + + finally: + event_ax._lim_lock = False + + + ax0_twin.callbacks.connect('ylim_changed', on_lims_change) + ax1_twin.callbacks.connect('ylim_changed', on_lims_change) + ax2_twin.callbacks.connect('ylim_changed', on_lims_change) + plt.tight_layout() + + if save_path is not None: + plt.savefig(save_path, dpi=300) + if show: + plt.show() + + +def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, spo2_signal, effort_signal, rri_signal, event_mask, fs, + show=False, save_path=None): + sa_mask = event_mask.repeat(fs) + fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True) + time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal)) + axs[0].plot(time_axis, tho_signal, label='THO', color='black') + axs[0].set_title(f'Sample {samp_id} - PSG Signal Data') + axs[0].set_ylabel('THO Amplitude') + axs[0].legend(loc='upper right') + + ax0_twin = axs[0].twinx() + ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax0_twin.autoscale(enable=False, axis='y', 
tight=True) + ax0_twin.set_ylim((-4, 5)) + + axs[1].plot(time_axis, abd_signal, label='ABD', color='black') + axs[1].set_ylabel('ABD Amplitude') + axs[1].legend(loc='upper right') + + ax1_twin = axs[1].twinx() + ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax1_twin.autoscale(enable=False, axis='y', tight=True) + ax1_twin.set_ylim((-4, 5)) + + axs[2].plot(time_axis, effort_signal, label='EFFO', color='black') + axs[2].set_ylabel('EFFO Amplitude') + axs[2].legend(loc='upper right') + + ax2_twin = axs[2].twinx() + ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax2_twin.autoscale(enable=False, axis='y', tight=True) + ax2_twin.set_ylim((-4, 5)) + + axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black') + axs[3].set_ylabel('FLOWP Amplitude') + axs[3].legend(loc='upper right') + + ax3_twin = axs[3].twinx() + ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax3_twin.autoscale(enable=False, axis='y', tight=True) + ax3_twin.set_ylim((-4, 5)) + + axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black') + axs[4].set_ylabel('FLOWT Amplitude') + axs[4].legend(loc='upper right') + + ax4_twin = axs[4].twinx() + ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax4_twin.autoscale(enable=False, axis='y', tight=True) + ax4_twin.set_ylim((-4, 5)) + + axs[5].plot(time_axis, rri_signal, label='RRI', color='black') + axs[5].set_ylabel('RRI Amplitude') + axs[5].legend(loc='upper right') + + + axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black') + axs[6].set_ylabel('SPO2 Amplitude') + axs[6].set_xlabel('Time (s)') + axs[6].legend(loc='upper right') + + + if save_path is not None: + plt.savefig(save_path, dpi=300) + if show: + plt.show() + diff --git a/event_mask_process/HYS_PSG_process.py b/event_mask_process/HYS_PSG_process.py new file mode 100644 index 0000000..bd04f63 --- /dev/null +++ b/event_mask_process/HYS_PSG_process.py 
@@ -0,0 +1,183 @@ +""" +本脚本完成对呼研所数据的处理,包含以下功能: +1. 数据读取与预处理 + 从传入路径中,进行数据和标签的读取,并进行初步的预处理 + 预处理包括为数据进行滤波、去噪等操作 +2. 数据清洗与异常值处理 +3. 输出清晰后的统计信息 +4. 数据保存 + 将处理后的数据保存到指定路径,便于后续使用 + 主要是保存切分后的数据位置和标签 +5. 可视化 + 提供数据处理前后的可视化对比,帮助理解数据变化 + 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 + + + +# 低幅值区间规则标定与剔除 +# 高幅值连续体动规则标定与剔除 +# 手动标定不可用区间提剔除 +""" +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os + + +os.environ['DISPLAY'] = "localhost:10.0" + + +def process_one_signal(samp_id, show=False): + pass + + tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt")) + abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt")) + flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt")) + flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt")) + spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt")) + stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt")) + + if not tho_signal_path: + raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}") + tho_signal_path = tho_signal_path[0] + print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}") + if not abd_signal_path: + raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}") + abd_signal_path = abd_signal_path[0] + print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}") + if not flowp_signal_path: + raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}") + flowp_signal_path = flowp_signal_path[0] + print(f"Processing Flow P_Sync signal file: {flowp_signal_path}") + if not flowt_signal_path: + raise FileNotFoundError(f"Flow T_Sync file not found for 
sample ID: {samp_id}") + flowt_signal_path = flowt_signal_path[0] + print(f"Processing Flow T_Sync signal file: {flowt_signal_path}") + if not spo2_signal_path: + raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}") + spo2_signal_path = spo2_signal_path[0] + print(f"Processing SpO2_Sync signal file: {spo2_signal_path}") + if not stage_signal_path: + raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}") + stage_signal_path = stage_signal_path[0] + print(f"Processing 5_class_Sync signal file: {stage_signal_path}") + + + + label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv") + if not label_path: + raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") + label_path = list(label_path)[0] + print(f"Processing Label_corrected file: {label_path}") + # + # # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + save_samp_path.mkdir(parents=True, exist_ok=True) + + # # # 读取信号数据 + tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True) + # abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True) + # flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True) + # flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True) + # spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True) + stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True) + + + # + # # 预处理与滤波 + # tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs) + # abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs) + # 
flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs) + # flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs) + + # 降采样 + # old_tho_fs = tho_fs + # tho_fs = conf["effort"]["downsample_fs"] + # tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs) + # old_abd_fs = abd_fs + # abd_fs = conf["effort"]["downsample_fs"] + # abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs) + # old_flowp_fs = flowp_fs + # flowp_fs = conf["effort"]["downsample_fs"] + # flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs) + # old_flowt_fs = flowt_fs + # flowt_fs = conf["effort"]["downsample_fs"] + # flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs) + + # spo2不降采样 + # spo2_data_filt = spo2_data_raw + # spo2_fs = spo2_fs + + label_data = utils.read_raw_psg_label(path=label_path) + event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False) + # event_mask > 0 的部分为1,其他为0 + score_mask = np.where(event_mask > 0, 1, 0) + + # 根据睡眠分期生成不可用区间 + wake_mask = utils.get_wake_mask(stage_data_raw) + # 剔除短于60秒的觉醒区间 + wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60) + # 合并短于120秒的觉醒区间 + wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60) + + disable_label = wake_mask + + disable_label = disable_label[:tho_second] + + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}_" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + # + # 
新建一个dataframe,分别是秒数、SA标签, + save_dict = { + "Second": np.arange(tho_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": disable_label, + "Resp_LowAmp_Label": np.zeros_like(event_mask), + "Resp_Movement_Label": np.zeros_like(event_mask), + "Resp_AmpChange_Label": np.zeros_like(event_mask), + "BCG_LowAmp_Label": np.zeros_like(event_mask), + "BCG_Movement_Label": np.zeros_like(event_mask), + "BCG_AmpChange_Label": np.zeros_like(event_mask) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml" + # disable_df_path = project_root_path / "排除区间.xlsx" + # + conf = utils.load_dataset_conf(yaml_path) + + root_path = Path(conf["root_path"]) + save_path = Path(conf["mask_save_path"]) + select_ids = conf["select_ids"] + # + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + # + org_signal_root_path = root_path / "PSG_Aligned" + label_root_path = root_path / "PSG_Aligned" + # + # all_samp_disable_df = utils.read_disable_excel(disable_df_path) + # + # process_one_signal(select_ids[0], show=True) + # # + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + process_one_signal(samp_id, show=False) + print(f"Finished processing sample ID: {samp_id}\n\n") + pass \ No newline at end of file diff --git a/event_mask_process/HYS_process.py b/event_mask_process/HYS_process.py new file mode 100644 index 0000000..45fa76d --- /dev/null +++ b/event_mask_process/HYS_process.py @@ -0,0 +1,246 @@ +""" +本脚本完成对呼研所数据的处理,包含以下功能: +1. 数据读取与预处理 + 从传入路径中,进行数据和标签的读取,并进行初步的预处理 + 预处理包括为数据进行滤波、去噪等操作 +2. 数据清洗与异常值处理 +3. 输出清晰后的统计信息 +4. 数据保存 + 将处理后的数据保存到指定路径,便于后续使用 + 主要是保存切分后的数据位置和标签 +5. 
可视化 + 提供数据处理前后的可视化对比,帮助理解数据变化 + 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 + + + +# 低幅值区间规则标定与剔除 +# 高幅值连续体动规则标定与剔除 +# 手动标定不可用区间提剔除 +""" +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os + + +os.environ['DISPLAY'] = "localhost:10.0" + + +def process_one_signal(samp_id, show=False): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv") + if not label_path: + raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") + label_path = list(label_path)[0] + print(f"Processing Label_corrected file: {label_path}") + + # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + save_samp_path.mkdir(parents=True, exist_ok=True) + + signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True) + + signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs) + + # 降采样 + old_resp_fs = resp_fs + resp_fs = conf["resp"]["downsample_fs_2"] + resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs) + bcg_fs = conf["bcg"]["downsample_fs"] + bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) + + label_data = utils.read_label_csv(path=label_path) + event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + + manual_disable_mask = 
utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ + all_samp_disable_df["id"] == samp_id]) + print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") + + # 分析Resp的低幅值区间 + resp_low_amp_conf = conf.get("resp_low_amp", None) + if resp_low_amp_conf is not None: + resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=resp_data, + sampling_rate=resp_fs, + **resp_low_amp_conf + ) + print( + f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") + else: + resp_low_amp_mask, resp_low_amp_position_list = None, None + print("resp_low_amp_mask is None") + + # 分析Resp的高幅值伪迹区间 + resp_movement_conf = conf.get("resp_movement", None) + if resp_movement_conf is not None: + raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( + signal_data=resp_data, + sampling_rate=resp_fs, + **resp_movement_conf + ) + print( + f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + else: + resp_movement_mask, resp_movement_position_list = None, None + print("resp_movement_mask is None") + + resp_movement_revise_conf = conf.get("resp_movement_revise", None) + if resp_movement_mask is not None and resp_movement_revise_conf is not None: + resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, + sampling_rate=resp_fs, + **resp_movement_revise_conf, + verbose=False + ) + print( + f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: 
{len(resp_movement_position_list)}") + else: + print("resp_movement_mask revise is skipped") + + # 分析Resp的幅值突变区间 + resp_amp_change_conf = conf.get("resp_amp_change", None) + if resp_amp_change_conf is not None and resp_movement_mask is not None: + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, + sampling_rate=resp_fs, + **resp_amp_change_conf, + verbose=False) + print( + f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") + else: + resp_amp_change_mask = None + print("amp_change_mask is None") + + # 分析Bcg的低幅值区间 + bcg_low_amp_conf = conf.get("bcg_low_amp", None) + if bcg_low_amp_conf is not None: + bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=bcg_data, + sampling_rate=bcg_fs, + **bcg_low_amp_conf + ) + print( + f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") + else: + bcg_low_amp_mask, bcg_low_amp_position_list = None, None + print("bcg_low_amp_mask is None") + + # 分析Bcg的高幅值伪迹区间 + bcg_movement_conf = conf.get("bcg_movement", None) + if bcg_movement_conf is not None: + raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( + signal_data=bcg_data, + sampling_rate=bcg_fs, + **bcg_movement_conf + ) + print( + f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") + else: + bcg_movement_mask = None + print("bcg_movement_mask is None") + + # 分析Bcg的幅值突变区间 + if bcg_movement_mask is not None: + bcg_amp_change_mask, bcg_amp_change_list = 
signal_method.position_based_sleep_recognition_v2( + signal_data=bcg_data, + movement_mask=bcg_movement_mask, + sampling_rate=bcg_fs) + print( + f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") + else: + bcg_amp_change_mask = None + print("bcg_amp_change_mask is None") + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, + target_fs=100) + signal_fs = 100 + + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=event_mask, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask, + show=show, + save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}_" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + + # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 + save_dict = { + "Second": np.arange(signal_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": manual_disable_mask, + "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, + dtype=int), + "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, + 
dtype=int), + "BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, + dtype=int), + "BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, + dtype=int) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_config.yaml" + disable_df_path = project_root_path / "排除区间.xlsx" + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + save_path = Path(conf["mask_save_path"]) + + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + label_root_path = root_path / "Label" + + all_samp_disable_df = utils.read_disable_excel(disable_df_path) + + # process_one_signal(select_ids[6], show=True) + # + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + process_one_signal(samp_id, show=False) + print(f"Finished processing sample ID: {samp_id}\n\n") diff --git a/event_mask_process/SHHS1_process.py b/event_mask_process/SHHS1_process.py new file mode 100644 index 0000000..a1fdae6 --- /dev/null +++ b/event_mask_process/SHHS1_process.py @@ -0,0 +1,42 @@ +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os +import mne +from tqdm import tqdm +import xml.etree.ElementTree as ET +import re + +# 获取分期和事件标签,以及不可用区间 + +def process_one_signal(samp_id, show=False): + + + + + + + + +if __name__ == 
def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
    """Z-score normalize a respiration signal segment by segment.

    Each entry of ``enable_list`` is a ``[start_sec, end_sec]`` interval of
    usable signal.  The mean/std of every interval are computed with body
    movement samples masked out (NaN); the normalization is then applied to
    the *raw* signal from the interval start up to the start of the next
    interval, so the movement gap between two intervals is normalized with
    the preceding interval's statistics.

    Parameters
    ----------
    resp_signal : np.ndarray
        Respiration samples at ``resp_fs`` Hz.
    resp_fs : int
        Sampling rate of ``resp_signal``.
    movement_mask : np.ndarray
        1 Hz binary mask, 1 = body-movement second.
    enable_list : list
        Usable intervals in seconds, ``[[start, end], ...]``, ascending.

    Returns
    -------
    np.ndarray
        Normalized signal; samples never covered by any write stay NaN.

    Raises
    ------
    ValueError
        If an interval's movement-free standard deviation is zero.
    """
    normalized_resp_signal = np.zeros_like(resp_signal)
    normalized_resp_signal[:] = np.nan  # anything never written stays NaN

    # Mask movement seconds so they do not bias the per-segment statistics.
    resp_signal_no_movement = resp_signal.copy()
    resp_signal_no_movement[np.array(movement_mask == 1).repeat(resp_fs)] = np.nan

    for i in range(len(enable_list)):
        enable_start = enable_list[i][0] * resp_fs
        enable_end = enable_list[i][1] * resp_fs
        segment = resp_signal_no_movement[enable_start:enable_end]

        segment_mean = np.nanmean(segment)
        segment_std = np.nanstd(segment)
        if segment_std == 0:
            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")

        if i <= len(enable_list) - 2:
            # Normalize together with the movement gap up to the next interval.
            write_end = enable_list[i + 1][0] * resp_fs
        else:
            # BUGFIX: the last enable interval (and anything after it) was
            # previously never written and stayed NaN, even though the head
            # region before the first interval was handled; normalize it with
            # its own statistics up to the end of the signal.
            write_end = len(resp_signal)

        raw_segment = resp_signal[enable_start:write_end]
        normalized_resp_signal[enable_start:write_end] = (raw_segment - segment_mean) / segment_std

    # If the first enable interval does not start at 0, normalize the leading
    # part with the first interval's statistics.
    if enable_list[0][0] > 0:
        enable_start = enable_list[0][0] * resp_fs
        enable_end = enable_list[0][1] * resp_fs
        segment = resp_signal_no_movement[enable_start:enable_end]

        segment_mean = np.nanmean(segment)
        segment_std = np.nanstd(segment)
        if segment_std == 0:
            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")

        raw_segment = resp_signal[0:enable_start]
        normalized_resp_signal[0:enable_start] = (raw_segment - segment_mean) / segment_std

    return normalized_resp_signal
sampling_rate: int,信号的采样率(Hz) + - window_size_sec: float,分析窗口的时长(秒),默认值为 2 秒 + - stride_sec: float,窗口滑动步长(秒),默认值为None(等于window_size_sec,无重叠) + - std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5 + - compare_intervals_sec: list,用于比较的时间间隔列表(秒),默认为 [30, 60] + - interval_multiplier: float,间隔中位数的上限乘数,默认值为 2.5 + - merge_gap_sec: float,要合并的体动状态之间的最大间隔(秒),默认值为 10 秒 + - min_duration_sec: float,要保留的体动状态的最小持续时间(秒),默认值为 5 秒 + - low_amplitude_periods: numpy array,低幅值期间的掩码(1表示低幅值期间),默认为None + + 返回: + - movement_mask: numpy array,体动状态的掩码(1表示体动,0表示睡眠) + """ + # 计算窗口大小(样本数) + window_samples = int(window_size_sec * sampling_rate) + + # 如果未指定步长,设置为窗口大小(无重叠) + if stride_sec is None: + stride_sec = window_size_sec + + # 计算步长(样本数) + stride_samples = int(stride_sec * sampling_rate) + + # 确保步长至少为1 + stride_samples = max(1, stride_samples) + + # 计算需要的最大填充大小(基于比较间隔) + max_interval_samples = int(max(compare_intervals_sec) * sampling_rate) + + # 应用反射填充以正确处理边界 + # 填充大小为最大比较间隔的一半,以确保边界有足够的上下文 + pad_size = max_interval_samples + padded_signal = np.pad(signal_data, pad_size, mode='reflect') + + # 计算填充后的窗口数量 + num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1) + + # 初始化窗口标准差数组 + window_std = np.zeros(num_windows) + # 计算每个窗口的标准差 + # 分窗计算标准差 + for i in range(num_windows): + start_idx = i * stride_samples + end_idx = min(start_idx + window_samples, len(padded_signal)) + + # 处理窗口,包括可能不完整的最后一个窗口 + window_data = padded_signal[start_idx:end_idx] + if len(window_data) > 0: + window_std[i] = np.std(window_data, ddof=1) + else: + window_std[i] = 0 + + # 计算原始信号对应的窗口索引范围 + # 填充后,原始信号从pad_size开始 + orig_start_window = pad_size // stride_samples + if stride_sec == 1: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + else: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1 + + # 只保留原始信号对应的窗口标准差 + original_window_std = window_std[orig_start_window:orig_end_window] + num_original_windows = len(original_window_std) + + # 
创建时间点数组(秒) + time_points = np.arange(num_original_windows) * stride_sec + + # # 如果提供了低幅值期间的掩码,则在计算全局中位数时排除这些期间 + # if low_amplitude_periods is not None and len(low_amplitude_periods) == num_original_windows: + # valid_std = original_window_std[low_amplitude_periods == 0] + # if len(valid_std) == 0: # 如果所有窗口都在低幅值期间 + # valid_std = original_window_std # 使用全部窗口 + # else: + # valid_std = original_window_std + + valid_std = original_window_std ##20250418新修改 + + # ---------------------- 方法一:基于STD的体动判定 ----------------------# + # 计算所有有效窗口标准差的中位数 + median_std = np.median(valid_std) + + # 当窗口标准差大于中位数的倍数,判定为体动状态 + std_movement = np.where((original_window_std > (median_std * std_median_multiplier)), 1, 0) + + # ------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# + amplitude_movement = np.zeros(num_original_windows, dtype=int) + + # 定义基于时间粒度的比较间隔索引 + compare_intervals_idx = [int(interval // stride_sec) for interval in compare_intervals_sec] + + # 逐窗口判断 + for win_idx in range(num_original_windows): + # 全局索引(在填充后的窗口数组中) + global_win_idx = win_idx + orig_start_window + + # 对每个比较间隔进行检查 + for interval_idx in compare_intervals_idx: + # 确定比较范围的结束索引(在填充后的窗口数组中) + end_idx = min(global_win_idx + interval_idx, len(window_std)) + + # 提取相应时间范围内的标准差值 + if global_win_idx < end_idx: + interval_std = window_std[global_win_idx:end_idx] + + # 计算该间隔的中位数 + interval_median = np.median(interval_std) + + # 计算上下阈值 + upper_threshold = interval_median * interval_multiplier + + # 检查当前窗口是否超出阈值范围,如果超出则直接标记为体动 + if window_std[global_win_idx] > upper_threshold: + amplitude_movement[win_idx] = 1 + break # 一旦确定为体动就不需要继续检查其他间隔 + + # 将两种方法的结果合并:只要其中一种方法判定为体动,最终结果就是体动 + movement_mask = np.logical_or(std_movement, amplitude_movement).astype(int) + raw_movement_mask = movement_mask + + # 如果需要合并间隔小的体动状态 + if merge_gap_sec > 0: + movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec) + + # 如果需要移除短时体动状态 + if min_duration_sec > 0: + movement_mask = remove_short_durations(movement_mask, 
time_points, min_duration_sec) + + # raw_movement_mask, movement_mask恢复对应秒数,而不是点数 + raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + + # 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除 + removed_movement_mask = (raw_movement_mask - movement_mask) > 0 + removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0] + removed_movement_end = np.where(np.diff(np.concatenate([removed_movement_mask, [0]])) == -1)[0] + + for start, end in zip(removed_movement_start, removed_movement_end): + # print(start ,end) + # 计算剔除的体动区域的幅值 + if np.nanmax(signal_data[start * sampling_rate:(end + 1) * sampling_rate]) > median_std * std_median_multiplier: + movement_mask[start:end + 1] = 1 + + # raw体动起止位置 [[start, end], [start, end], ...] + raw_movement_position_list = event_mask_2_list(raw_movement_mask) + + # merge体动起止位置 [[start, end], [start, end], ...] + movement_position_list = event_mask_2_list(movement_mask) + + return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list + + +def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float, + down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec, verbose=False): + """ + 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置精细修正 + + 参数: + - signal_data: numpy array,输入的信号数据 + - sampling_rate: int,信号的采样率(Hz) + - movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠) + + 返回: + - revised_movement_mask: numpy array,修正后的体动掩码 + """ + window_size = sampling_rate + stride_size = sampling_rate + + time_points = np.arange(len(signal_data)) + + compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) + + _, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, + window_second=4, step_second=1, + inner_window_second=4) + + # 
往左右两边取compare_size个点的mav,取平均值 + for start, end in movement_list: + left_points = start - 20 + right_points = end + 20 + + left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask) + right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask) + + left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0 + right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0 + + # if left_value_metrics == 0: + # value_metrics = right_value_metrics + # elif right_value_metrics == 0: + # value_metrics = left_value_metrics + # else: + # value_metrics = np.mean([left_value_metrics, right_value_metrics]) + + if left_value_metrics == 0: + left_value_metrics = right_value_metrics + elif right_value_metrics == 0: + right_value_metrics = left_value_metrics + + if verbose: + print(f"Revising movement from index {start} to {end}, left_metric: {left_value_metrics:.2f}, right_metric: {right_value_metrics:.2f}") + + for i in range(left_points, right_points): + if i < 0 or i >= len(mav): + continue + if i < start: + value_metrics = left_value_metrics + elif i > end: + value_metrics = right_value_metrics + else: + value_metrics = (left_value_metrics + right_value_metrics) / 2 + + if mav[i] > (value_metrics * up_interval_multiplier) and movement_mask[i] == 0: + movement_mask[i] = 1 + if verbose: + print(f"Normal revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * up_interval_multiplier:.2f}") + elif mav[i] < (value_metrics * down_interval_multiplier) and movement_mask[i] == 1: + movement_mask[i] = 0 + if verbose: + print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * down_interval_multiplier:.2f}") + else: + if verbose: + print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_metrics * up_interval_multiplier:.2f}, down_threshold: {value_metrics * down_interval_multiplier:.2f}") + # + # 
@timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
                                amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):
    """Detect low-amplitude stretches of the signal via windowed RMS.

    A window whose RMS is at or below ``amplitude_threshold`` is marked as
    low amplitude; short gaps between detections are merged and too-short
    detections removed before the mask is expanded to one value per second.

    Parameters:
    - signal_data: np.ndarray, input signal
    - sampling_rate: int, sampling rate (Hz)
    - window_size_sec: float, analysis window length in seconds (default 1)
    - stride_sec: float, window stride in seconds (default: window size, no overlap)
    - amplitude_threshold: float, RMS at/below this marks low amplitude (default 50)
    - merge_gap_sec: float, merge detections separated by gaps up to this many seconds (default 10)
    - min_duration_sec: float, drop detections shorter than this many seconds (default 5)

    Returns:
    - low_amplitude_mask: np.ndarray, 1 Hz mask (1 = low amplitude, 0 = normal)
    - low_amplitude_position_list: [[start_sec, end_sec], ...]
    """
    # Window size in samples.
    window_samples = int(window_size_sec * sampling_rate)

    # Default stride = window size (non-overlapping windows).
    if stride_sec is None:
        stride_sec = window_size_sec
    stride_samples = int(stride_sec * sampling_rate)
    # BUGFIX: the floor used to be `max(sampling_rate, stride_samples)`, which
    # silently forced any sub-second stride up to a full second; the documented
    # intent (and sibling detect_movement's behaviour) is a floor of 1 sample.
    stride_samples = max(1, stride_samples)

    # Reflect-pad by half a window so edge windows have full context.
    pad_size = window_samples // 2
    padded_signal = np.pad(signal_data, pad_size, mode='reflect')

    # Number of windows over the padded signal.
    num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)

    rms_values = np.zeros(num_windows)
    for i in range(num_windows):
        start_idx = i * stride_samples
        # BUGFIX: windows were previously sliced from the *unpadded* signal
        # while the window count and index offsets assumed the padded one,
        # shifting every window by `pad_size` samples and leaving the
        # reflect pad entirely unused.
        end_idx = min(start_idx + window_samples, len(padded_signal))
        window_data = padded_signal[start_idx:end_idx]
        if len(window_data) > 0:
            rms_values[i] = np.sqrt(np.mean(window_data ** 2))
        else:
            rms_values[i] = 0

    # Initial mask: 1 where the window RMS is at or below the threshold.
    low_amplitude_mask = np.where(rms_values <= amplitude_threshold, 1, 0)

    # Keep only the windows that correspond to the original (unpadded) signal.
    orig_start_window = pad_size // stride_samples
    if stride_sec == 1:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
    else:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1
    low_amplitude_mask = low_amplitude_mask[orig_start_window:orig_end_window]
    num_original_windows = len(low_amplitude_mask)

    # Window start times in seconds, used by the merge/remove helpers.
    time_points = np.arange(num_original_windows) * stride_sec

    # Merge detections separated by short gaps.
    if merge_gap_sec > 0:
        low_amplitude_mask = merge_short_gaps(low_amplitude_mask, time_points, merge_gap_sec)

    # Remove detections that are too short to be meaningful.
    if min_duration_sec > 0:
        low_amplitude_mask = remove_short_durations(low_amplitude_mask, time_points, min_duration_sec)

    # Expand back to one entry per second of the input signal.
    low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]

    # Start/end positions of each low-amplitude run: [[start, end], ...]
    low_amplitude_position_list = event_mask_2_list(low_amplitude_mask)

    return low_amplitude_mask, low_amplitude_position_list
left_start = start + interval_to_movement * sampling_rate + left_end = left_start + window_size_sec * sampling_rate + right_start = end - interval_to_movement * sampling_rate - window_size_sec * sampling_rate + right_end = end + + # 新的end - start确保为200的整数倍 + if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0: + left_end = left_start + ((left_end - left_start) // (mav_calc_window_sec * sampling_rate)) * ( + mav_calc_window_sec * sampling_rate) + if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0: + right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * ( + mav_calc_window_sec * sampling_rate) + + # 计算每个片段的幅值指标 + left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.mean( + np.min(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + right_mav = np.mean( + np.max(signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.mean( + np.min(signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + segment_left_average_amplitude.append(left_mav) + segment_right_average_amplitude.append(right_mav) + + left_energy = np.sum(np.abs(signal_data[left_start:left_end] ** 2)) + right_energy = np.sum(np.abs(signal_data[right_start:right_end] ** 2)) + segment_left_average_energy.append(left_energy) + segment_right_average_energy.append(right_energy) + + position_changes = [] + position_change_times = [] + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + + for i in range(1, len(segment_left_average_amplitude)): + # 计算幅值指标的变化率 + left_amplitude_change = abs(segment_left_average_amplitude[i] - segment_left_average_amplitude[i - 1]) / max( + segment_left_average_amplitude[i - 1], 1e-6) + right_amplitude_change = abs(segment_right_average_amplitude[i] - 
segment_right_average_amplitude[i - 1]) / max( + segment_right_average_amplitude[i - 1], 1e-6) + + # 计算能量指标的变化率 + left_energy_change = abs(segment_left_average_energy[i] - segment_left_average_energy[i - 1]) / max( + segment_left_average_energy[i - 1], 1e-6) + right_energy_change = abs(segment_right_average_energy[i] - segment_right_average_energy[i - 1]) / max( + segment_right_average_energy[i - 1], 1e-6) + + # 如果左右通道中的任一通道同时满足幅值和能量的变化阈值,则认为存在姿势变化 + left_significant_change = (left_amplitude_change > threshold_amplitude) and ( + left_energy_change > threshold_energy) + right_significant_change = (right_amplitude_change > threshold_amplitude) and ( + right_energy_change > threshold_energy) + + if left_significant_change or right_significant_change: + # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 + position_changes.append(1) + position_change_times.append((movement_start[i - 1], movement_end[i - 1])) + else: + position_changes.append(0) # 0表示不存在姿势变化 + + return position_changes, position_change_times + + +def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100): + """ + + :param signal_data: + :param movement_mask: mask的采样率为1Hz + :param sampling_rate: + :param window_size_sec: + :return: + """ + mav_calc_window_sec = 2 # 计算mav的窗口大小,单位秒 + # 获取有效片段起止位置 + valid_mask = 1 - movement_mask + valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0] + valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0] + + movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] + movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + + segment_average_amplitude = [] + segment_average_energy = [] + + signal_data_no_movement = signal_data.copy() + for start, end in zip(movement_start, movement_end): + signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan + + # from matplotlib import pyplot as plt + # plt.plot(signal_data, alpha=0.3, color='gray') + # 
plt.plot(signal_data_no_movement, color='blue', linewidth=1) + # plt.show() + + for start, end in zip(valid_starts, valid_ends): + start *= sampling_rate + end *= sampling_rate + # 避免过短的片段 + if end - start <= sampling_rate: # 小于1秒的片段不考虑 + continue + + # 新的end - start确保为200的整数倍 + if (end - start) % (mav_calc_window_sec * sampling_rate) != 0: + end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * ( + mav_calc_window_sec * sampling_rate) + + # 计算每个片段的幅值指标 + mav = np.nanmean( + np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.nanmean( + np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + segment_average_amplitude.append(mav) + + energy = np.nansum(np.abs(signal_data_no_movement[start:end] ** 2)) + segment_average_energy.append(energy) + + position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) + position_change_times = [] + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + + for i in range(1, len(segment_average_amplitude)): + # 计算幅值指标的变化率 + amplitude_change = abs(segment_average_amplitude[i] - segment_average_amplitude[i - 1]) / max( + segment_average_amplitude[i - 1], 1e-6) + + # 计算能量指标的变化率 + energy_change = abs(segment_average_energy[i] - segment_average_energy[i - 1]) / max( + segment_average_energy[i - 1], 1e-6) + + significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + + if significant_change: + # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 + position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1 + position_change_times.append((movement_start[i - 1], movement_end[i - 1])) + + return position_changes, position_change_times + + +def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate, mav_calc_window_sec, + threshold_amplitude, threshold_energy, verbose=False): + """ + + :param 
threshold_energy: + :param threshold_amplitude: + :param mav_calc_window_sec: + :param movement_list: + :param signal_data: + :param movement_mask: mask的采样率为1Hz + :param sampling_rate: + :param window_size_sec: + :return: + """ + # 获取有效片段起止位置 + + valid_list = utils.event_mask_2_list(movement_mask, event_true=False) + + segment_average_amplitude = [] + segment_average_energy = [] + + signal_data_no_movement = signal_data.copy() + for start, end in movement_list: + signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan + + # from matplotlib import pyplot as plt + # plt.plot(signal_data, alpha=0.3, color='gray') + # plt.plot(signal_data_no_movement, color='blue', linewidth=1) + # plt.show() + + if len(valid_list) < 2: + return [], [] + + def clac_mav(data_segment): + # 确定data_segment长度为mav_calc_window_sec的整数倍 + if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0: + data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))] + + mav = np.nanstd( + np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0) - + np.nanmin(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + return mav + + def clac_energy(data_segment): + energy = np.nansum(np.abs(data_segment ** 2)) // (len(data_segment) // sampling_rate) + return energy + + def calc_mav_by_quantiles(data_segment): + # 先计算所有的mav值 + if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0: + data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))] + + mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=1) - np.nanmin( + data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=1) + # 计算分位数 + q20 = np.nanpercentile(mav_values, 20) + q80 = np.nanpercentile(mav_values, 80) + + mav_values = mav_values[(mav_values >= q20) & (mav_values <= q80)] + mav = np.nanmean(mav_values) + return mav + + position_changes = 
np.zeros(len(signal_data) // sampling_rate, dtype=int) + position_change_list = [] + + pre_valid_start = valid_list[0][0] * sampling_rate + pre_valid_end = valid_list[0][1] * sampling_rate + + if verbose: + print(f"Total movement segments to analyze: {len(movement_list)}") + print(f"Total valid segments available: {len(valid_list)}") + + for i in range(len(movement_list)): + if verbose: + print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") + + if i + 1 >= len(valid_list): + if verbose: + print("No more valid segments to compare. Ending analysis.") + break + + next_valid_start = valid_list[i + 1][0] * sampling_rate + next_valid_end = valid_list[i + 1][1] * sampling_rate + + movement_start = movement_list[i][0] + movement_end = movement_list[i][1] + + # 避免过短的片段 + if movement_end - movement_start <= 1: # 小于1秒的片段不考虑 + if verbose: + print( + f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") + pre_valid_start = pre_valid_start + pre_valid_end = next_valid_end + continue + + # 计算前后片段的幅值和能量 + # left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) + # right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) + # left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) + # right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + + left_mav = calc_mav_by_quantiles(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_mav = calc_mav_by_quantiles(signal_data_no_movement[next_valid_start:next_valid_end]) + + + # 计算幅值指标的变化率 + amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6) + # # 计算能量指标的变化率 + # energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) + + # significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + significant_change = (amplitude_change > threshold_amplitude) + if significant_change: + # 
# Mapping from SHHS sleep-stage annotation strings to integer stage codes.
ANNOTATION_MAP = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 4,
    "REM sleep|5": 5,
    "Unscored|9": 9,
    "Movement|6": 6
}

# Sleep-apnea event families extracted from the annotation XML.
SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']


def parse_sleep_annotations(annotation_path):
    """Parse sleep-staging events from an SHHS annotation XML file.

    Only ``ScoredEvent`` nodes whose ``EventType`` is ``"Stages|Stages"`` and
    whose ``EventConcept`` is a key of ``ANNOTATION_MAP`` are kept.

    :param annotation_path: path to the NSRR XML annotation file
    :return: list of dicts with ``onset``/``duration``/``description``/``stage``,
             or ``None`` when the file cannot be read or parsed
    """
    try:
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        events = []
        for scored_event in root.findall('.//ScoredEvent'):
            event_type = scored_event.find('EventType').text
            if event_type != "Stages|Stages":
                continue
            description = scored_event.find('EventConcept').text
            start = float(scored_event.find('Start').text)
            duration = float(scored_event.find('Duration').text)
            if description not in ANNOTATION_MAP:
                continue
            events.append({
                'onset': start,
                'duration': duration,
                'description': description,
                'stage': ANNOTATION_MAP[description]
            })
        return events
    # Narrowed from a bare `except Exception`: only the failure modes this
    # function can actually produce (unreadable file, malformed XML, missing
    # or non-numeric fields) map to the "no annotations" result; programming
    # errors are no longer swallowed silently.
    except (OSError, ET.ParseError, AttributeError, TypeError, ValueError):
        return None


def extract_osa_events(annotation_path):
    """Extract sleep-apnea events (types in ``SA_EVENTS``) lasting >= 10 s.

    :param annotation_path: path to the NSRR XML annotation file
    :return: list of dicts with ``start``/``duration``/``end``/``type``;
             an empty list on any read/parse failure
    """
    try:
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        events = []
        for scored_event in root.findall('.//ScoredEvent'):
            event_concept = scored_event.find('EventConcept').text
            event_type = event_concept.split('|')[0].strip()
            if event_type in SA_EVENTS:
                start = float(scored_event.find('Start').text)
                duration = float(scored_event.find('Duration').text)
                # Keep only events of scoreable length (>= 10 seconds).
                if duration >= 10:
                    events.append({
                        'start': start,
                        'duration': duration,
                        'end': start + duration,
                        'type': event_type
                    })
        return events
    # Same narrowed failure modes as parse_sleep_annotations; note the
    # deliberately different failure value ([] vs None) is preserved.
    except (OSError, ET.ParseError, AttributeError, TypeError, ValueError):
        return []
utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + if verbose: + print("Begin plotting signal data...") + + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + + return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs + + +def psg_effort_filter(conf, effort_data_raw, effort_fs): + # 滤波 + effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"], + 
def rpeak2hr(rpeak_indices, signal_length, ecg_fs):
    """Convert R-peak sample indices into a per-sample heart-rate signal.

    Each inter-beat interval is filled with its instantaneous heart rate,
    clamped to [30, 120] bpm.  Samples before the first peak stay 0;
    samples after the last peak repeat the last computed rate.

    :param rpeak_indices: ascending R-peak positions (sample indices)
    :param signal_length: length of the output signal (samples)
    :param ecg_fs: ECG sampling rate (Hz)
    :return: np.ndarray of length ``signal_length`` with heart rate in bpm
    """
    hr_signal = np.zeros(signal_length)
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        if rri == 0:
            continue  # duplicate peak: no interval to fill
        hr = 60 * ecg_fs / rri  # heart rate in bpm
        if hr > 120:
            hr = 120
        elif hr < 30:
            hr = 30
        hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = hr
    # Extend the last computed rate to the end of the signal.
    if len(rpeak_indices) > 1:
        hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
    return hr_signal


def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs):
    """Convert R-peak indices into a per-sample RR-interval signal (samples).

    Each inter-beat gap is filled with its RRI expressed in samples; a
    second pass zeroes out implausible intervals (< 0.3 s or > 2 s).

    :param rpeak_indices: ascending R-peak positions (sample indices)
    :param signal_length: length of the output signal (samples)
    :param ecg_fs: ECG sampling rate (Hz)
    :return: np.ndarray of length ``signal_length`` with RRI in samples
    """
    rri_signal = np.zeros(signal_length)
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri
    # Extend the last RRI value past the final peak.
    if len(rpeak_indices) > 1:
        rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]]

    # Zero out physiologically implausible intervals.
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs:
            rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0

    return rri_signal


def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs):
    """Resample the RR-interval series onto a uniform time grid.

    RRIs outside the plausible (0.3 s, 2.0 s) range are discarded before
    linear interpolation at ``rri_fs`` Hz.

    :param rpeak_indices: ascending R-peak positions (sample indices)
    :param ecg_fs: ECG sampling rate (Hz)
    :param rri_fs: target sampling rate of the uniform RRI series (Hz)
    :return: (uniform RRI array in seconds, rri_fs)
    """
    r_peak_time = np.asarray(rpeak_indices) / ecg_fs
    rri = np.diff(r_peak_time)
    t_rri = r_peak_time[1:]

    mask = (rri > 0.3) & (rri < 2.0)
    rri_clean = rri[mask]
    t_rri_clean = t_rri[mask]

    # BUGFIX: with fewer than two plausible intervals the original code
    # crashed (empty `t_rri_clean[0]` index, or interp1d needing >= 2 points);
    # return an empty series instead.
    if len(rri_clean) < 2:
        return np.array([]), rri_fs

    t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1 / rri_fs)
    f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate")
    rri_uniform = f(t_uniform)

    return rri_uniform, rri_fs
# -- patch metadata: 0000000..7895c3d --- /dev/null +++ b/signal_method/time_metrics.py @@ -0,0 +1,45 @@
from utils.operation_tools import calculate_by_slide_windows
from utils.operation_tools import timing_decorator
import numpy as np


@timing_decorator()
def calc_mav_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2):
    """Sliding-window mean amplitude variation (half the mean peak-to-peak
    range over `inner_window_second` sub-windows of each window).

    Args:
        signal_data: 1-D signal at `sampling_rate` Hz.
        movement_mask: per-second mask (truthy = movement) or None; masked
            seconds become NaN in the first return value.
        low_amp_mask: currently unused; kept for API symmetry with the other
            sliding-window metrics.
        sampling_rate / window_second / step_second / inner_window_second:
            window geometry in Hz / seconds.

    Returns:
        (mav_nan, mav): per-second MAV with NaN at masked seconds, and the
        same series with NaNs linearly interpolated.
    """
    if movement_mask is not None:
        # movement_mask carries one entry per second of signal
        assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
        processed_mask = movement_mask.copy()
    else:
        processed_mask = None

    def mav_func(x):
        inner = x.reshape(-1, inner_window_second * sampling_rate)
        return np.mean(np.nanmax(inner, axis=1) - np.nanmin(inner, axis=1)) / 2

    mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
                                              window_second=window_second, step_second=step_second)
    return mav_nan, mav


@timing_decorator()
def calc_wavefactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    """Sliding-window wave factor RMS/mean(|x|), 2 s window, 1 s step.

    Returns:
        (wavefactor_nan, wavefactor): per-second series (NaN at masked
        seconds; interpolated copy).
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    # BUG FIX: calculate_by_slide_windows indexes calc_mask per OUTPUT SECOND
    # (exactly how calc_mav_by_slide_windows passes it); the previous
    # .repeat(sampling_rate) produced a sample-indexed mask, so the wrong
    # seconds got masked.
    processed_mask = movement_mask.copy()
    wavefactor_nan, wavefactor = calculate_by_slide_windows(lambda x: np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x)),
                                                            signal_data, processed_mask, sampling_rate=sampling_rate,
                                                            window_second=2, step_second=1)
    return wavefactor_nan, wavefactor


@timing_decorator()
def calc_peakfactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    """Sliding-window peak factor max(|x|)/RMS, 2 s window, 1 s step.

    Returns:
        (peakfactor_nan, peakfactor): per-second series (NaN at masked
        seconds; interpolated copy).
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    # BUG FIX: per-second mask, see calc_wavefactor_by_slide_windows.
    processed_mask = movement_mask.copy()
    peakfactor_nan, peakfactor = calculate_by_slide_windows(lambda x: np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2)),
                                                            signal_data, processed_mask, sampling_rate=sampling_rate,
                                                            window_second=2, step_second=1)
    return peakfactor_nan, peakfactor

# -- patch metadata: \ No newline at end of file
# -- diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py new file mode 100644 index 0000000..d2fb03f --- /dev/null +++ b/utils/HYS_FileReader.py @@ -0,0 +1,354 @@
from pathlib import Path
from typing import Union

import utils
from .event_map import N2Chn
import numpy as np
import pandas as pd
from .operation_tools import event_mask_2_list
# Optional Polars fast path for large txt files
try:
    import polars as pl

    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False


def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
    """Read a one-column signal txt file into a numpy array.

    The sampling rate is parsed from the file stem (last "_"-separated
    token, e.g. "OrgBCG_Sync_1000.txt" -> 1000 Hz).

    Args:
        path: path to the txt file; stem must end in "_<fs>".
        dtype: numpy dtype for the samples.
        verbose: print a short summary when True.
        is_peak: True for R-peak index files — no truncation, and
            signal_length / signal_second are returned as None.

    Returns:
        (signal_data_raw, signal_length, signal_fs, signal_second)

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    if HAS_POLARS:
        df = pl.read_csv(path, has_header=False, infer_schema_length=0)
        signal_data_raw = df[:, 0].to_numpy().astype(dtype)
    else:
        df = pd.read_csv(path, header=None, dtype=dtype)
        signal_data_raw = df.iloc[:, 0].to_numpy()

    signal_original_length = len(signal_data_raw)
    # sampling rate encoded in the filename
    signal_fs = int(path.stem.split("_")[-1])
    if is_peak:
        signal_second = None
        signal_length = None
    else:
        signal_second = signal_original_length // signal_fs
        # truncate to a whole number of seconds
        signal_data_raw = signal_data_raw[:signal_second * signal_fs]
        signal_length = len(signal_data_raw)

    if verbose:
        print(f"Signal file read from {path}")
        print(f"signal_fs: {signal_fs}")
        print(f"signal_original_length: {signal_original_length}")
        print(f"signal_after_cut_off_length: {signal_length}")
        print(f"signal_second: {signal_second}")

    return signal_data_raw, signal_length, signal_fs, signal_second


def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
    """Read a reviewed apnea label CSV and optionally print label statistics.

    Labeling conventions (observed columns):
      - isLabeled == 1: reviewed; isLabeled == -1: not yet reviewed.
      - "Event type" non-null: event exported from PSG; null: manually added.
      - score: 1 = clear event, 2 = disturbed event, 3 = non-event (deleted).
      - correct_Start / correct_End / correct_EventsType: reviewed values.

    Returns:
        pd.DataFrame with "Start"/"End" coerced to int.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # the file contains Chinese text, hence the explicit GBK encoding
    df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(df)}")

    # overall counters
    num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3))
    num_psg_events = np.sum(df["Event type"].notna())
    num_manual_events = np.sum(df["Event type"].isna())
    num_deleted = np.sum(df["score"] == 3)
    num_unlabeled = np.sum(df["isLabeled"] == -1)
    score_totals = {s: np.sum(df["score"] == s) for s in (1, 2, 3)}

    # per-type counters (replaces ~40 hand-written np.sum lines; output is
    # byte-identical to the original prints)
    event_types = (("Hyp", "Hypopnea"), ("CSA", "Central apnea"),
                   ("OSA", "Obstructive apnea"), ("MSA", "Mixed apnea"))
    stats = {}
    for abbr, full_name in event_types:
        corrected = df["correct_EventsType"] == full_name
        stats[abbr] = {
            "total": np.sum(corrected & (df["score"] != 3)),
            "psg": np.sum(df["Event type"] == full_name),
            "manual": np.sum(df["Event type"].isna() & corrected),
            "deleted": np.sum((df["score"] == 3) & corrected),
            "unlabeled": np.sum((df["isLabeled"] == -1) & (df["Event type"] == full_name)),
            "scores": {s: np.sum(corrected & (df["score"] == s)) for s in (1, 2, 3)},
        }

    if verbose:
        print("Event Statistics:")
        # one row per type: Total / From PSG / Manual / Deleted / Unlabeled
        print(f"Type {'Total':^8s} / {'From PSG':^8s} / {'Manual':^8s} / {'Deleted':^8s} / {'Unlabeled':^8s}")
        for abbr, _ in event_types:
            s = stats[abbr]
            print(f"{abbr}: {s['total']:^8d} / {s['psg']:^8d} / {s['manual']:^8d} / {s['deleted']:^8d} / {s['unlabeled']:^8d}")
        print(f"All: {num_total:^8d} / {num_psg_events:^8d} / {num_manual_events:^8d} / {num_deleted:^8d} / {num_unlabeled:^8d}")

        print("Score Statistics (only for non-deleted events and manual created events):")
        print(f"Type {'Total':^8s} / {'Score 1':^8s} / {'Score 2':^8s} / {'Score 3':^8s}")
        for abbr, _ in event_types:
            s = stats[abbr]
            print(f"{abbr}: {s['total']:^8d} / {s['scores'][1]:^8d} / {s['scores'][2]:^8d} / {s['scores'][3]:^8d}")
        print(f"All: {num_total:^8d} / {score_totals[1]:^8d} / {score_totals[2]:^8d} / {score_totals[3]:^8d}")

    df["Start"] = df["Start"].astype(int)
    df["End"] = df["End"].astype(int)
    return df


def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame:
    """Read a raw (unreviewed) PSG-exported apnea label CSV.

    Returns:
        pd.DataFrame with "Start"/"End" coerced to int.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # contains Chinese text -> explicit GBK encoding
    df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(df)}")

    num_psg_events = np.sum(df["Event type"].notna())
    num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
    num_psg_csa = np.sum(df["Event type"] == "Central apnea")
    num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
    num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")

    if verbose:
        print("Event Statistics:")
        print(f"Type {'Total':^8s}")
        print(f"Hyp: {num_psg_hyp:^8d} ")
        print(f"CSA: {num_psg_csa:^8d} ")
        print(f"OSA: {num_psg_osa:^8d} ")
        print(f"MSA: {num_psg_msa:^8d} ")
        print(f"All: {num_psg_events:^8d}")

    df["Start"] = df["Start"].astype(int)
    df["End"] = df["End"].astype(int)
    return df


def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
    """Read the "disabled intervals" Excel sheet (columns id/start/end, seconds).

    Returns:
        pd.DataFrame with id/start/end coerced to int.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    df = pd.read_excel(path)
    df["id"] = df["id"].astype(int)
    df["start"] = df["start"].astype(int)
    df["end"] = df["end"].astype(int)
    return df


def read_mask_execl(path: Union[str, Path]):
    """Read a per-second label-mask file and derive segment lists.

    NOTE: despite the name, the file is read as CSV, not Excel; the
    misspelled name ("execl") is kept because it is re-exported from
    utils/__init__.py.

    Returns:
        (event_mask, event_list): column -> np.ndarray, plus [start, end)
        segment lists for the amp-change / disable masks.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    df = pd.read_csv(path)
    event_mask = {key: np.array(col) for key, col in df.to_dict(orient="list").items()}

    # inverted masks (1 - mask) yield the complementary segments
    event_list = {
        "RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
        "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
        "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
        "DisableSegment": event_mask_2_list(event_mask["Disable_Label"]),
    }
    return event_mask, event_list


def read_psg_mask_excel(path: Union[str, Path]):
    """Read a per-second PSG label-mask CSV as a dict of numpy arrays."""
    df = pd.read_csv(path)
    event_mask = {key: np.array(col) for key, col in df.to_dict(orient="list").items()}
    return event_mask


def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
    """Read selected PSG channels from a directory of per-channel txt files.

    Args:
        path_str: directory containing "<channel name>*.txt" files.
        channel_number: channel ids to load (keys of N2Chn).
        verbose: forwarded to read_signal_txt.

    Returns:
        dict: channel name -> {name, path, data, length, fs, second}.

    Raises:
        FileNotFoundError / NotADirectoryError: missing directory or channel.
    """
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"PSG Dir not found: {path}")
    if not path.is_dir():
        raise NotADirectoryError(f"PSG Dir not found: {path}")

    channel_data = {}
    for ch_id in channel_number:
        ch_name = N2Chn[ch_id]
        ch_path = list(path.glob(f"{ch_name}*.txt"))

        if not ch_path:
            raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")
        if len(ch_path) > 1:
            print(f"Warning!!! PSG Channel file more than one: {ch_path}")

        if ch_id == 8:
            # sleep-stage channel: read as strings, then map stage labels to ints
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
            for stage_str, stage_number in utils.Stage2N.items():
                np.place(ch_signal, ch_signal == stage_str, stage_number)
            ch_signal = ch_signal.astype(int)
        elif ch_id == 1:
            # R-peak channel: integer sample indices, no per-second truncation
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
        else:
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)

        channel_data[ch_name] = {
            "name": ch_name,
            "path": ch_path[0],
            "data": ch_signal,
            "length": ch_length,
            "fs": ch_fs,
            "second": ch_second,
        }

    return channel_data


def read_psg_label(path: Union[str, Path], verbose=True):
    """Read a PSG label CSV, keeping only rows with a non-null "Event type".

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # contains Chinese text -> explicit GBK encoding
    df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(df)}")

    # drop rows whose Event type is empty (manual annotations)
    df = df.dropna(subset=["Event type"], how='all').reset_index(drop=True)

    return df

# -- diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..c449802 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,12 @@
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label, read_raw_psg_label, read_psg_mask_excel
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
from .operation_tools import merge_short_gaps, remove_short_durations
from .operation_tools import collect_values
from .operation_tools import save_process_label
from .operation_tools import none_to_nan_mask
from .operation_tools import get_wake_mask
from .operation_tools import fill_spo2_anomaly
from .split_method import resp_split
# -- (import statement cut at the chunk boundary; continuation in next chunk)
# from .HYS_FileReader import read_mask_execl, read_psg_channel
# -- (continuation of the utils/__init__.py import cut at the chunk boundary)
# ... read_mask_execl, read_psg_channel
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel, adjust_sample_rate

# -- diff --git a/utils/event_map.py b/utils/event_map.py new file mode 100644 index 0000000..20c6c58 --- /dev/null +++ b/utils/event_map.py @@ -0,0 +1,42 @@
# apnea event type to number mapping
E2N = {
    "Hypopnea": 1,
    "Central apnea": 2,
    "Obstructive apnea": 3,
    "Mixed apnea": 4
}

# channel id -> PSG channel (file) name
N2Chn = {
    1: "Rpeak",
    2: "ECG_Sync",
    3: "Effort Tho",
    4: "Effort Abd",
    5: "Flow P",
    6: "Flow T",
    7: "SpO2",
    8: "5_class"
}

# sleep stage label -> numeric code
Stage2N = {
    "W": 5,
    "N1": 3,
    "N2": 2,
    "N3": 1,
    "R": 4,
}

# event_code -> plot color
# 0 black   background
# 1 pink    hypopnea
# 2 blue    central apnea
# 3 red     obstructive apnea
# 4 silver  mixed apnea
# 5 green   SpO2 desaturation
# 6 orange  large body movement
# 7 orange  small body movement
# 8 orange  deep breath
# 9 orange  impulse movement
# 10 orange invalid segment
ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange",
              "orange"]

# -- diff --git a/utils/filter_func.py b/utils/filter_func.py new file mode 100644 index 0000000..5a00c55 --- /dev/null +++ b/utils/filter_func.py @@ -0,0 +1,155 @@
from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage


@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Zero-phase Butterworth filter (SOS + sosfiltfilt).

    Args:
        data: 1-D signal.
        _type: "lowpass" | "bandpass" | "highpass".
        low_cut / high_cut: cutoff frequencies in Hz. NOTE: lowpass uses
            low_cut and highpass uses high_cut as its cutoff (surprising but
            preserved, callers rely on it).
        order: filter order.
        sample_rate: sampling rate in Hz.

    Raises:
        ValueError: on an unknown filter type.
    """
    if _type == "lowpass":
        sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "bandpass":
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "highpass":
        sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    else:
        raise ValueError("Please choose a type of fliter")


@timing_decorator()
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
    """Zero-phase Bessel filter (filtfilt, magnitude-normalized).

    Cutoff convention matches butterworth(): lowpass uses low_cut, highpass
    uses high_cut.

    Raises:
        ValueError: on an unknown filter type.
    """
    if _type == "lowpass":
        b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    elif _type == "bandpass":
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    elif _type == "highpass":
        b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    else:
        raise ValueError("Please choose a type of fliter")


@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """Integer-factor downsampling of a long signal, processed in chunks.

    Args:
        original_signal: array-like 1-D signal.
        original_fs / target_fs: sampling rates in Hz; the ratio must be an
            integer.
        chunk_size: requested samples per chunk; rounded down to a multiple
            of the decimation factor (see bug fix below).

    Returns:
        np.ndarray of length len(original_signal) // factor.

    Raises:
        ValueError: on non-positive rates, upsampling, or a non-integer
        factor.

    NOTE(review): chunks are decimated independently, so IIR filter state is
    not carried across chunk boundaries; minor seam artifacts are possible.
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("目标采样率必须小于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")

    # decimation factor must be an exact integer
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("降采样因子必须为整数倍")
    downsample_factor = int(downsample_factor)

    total_length = len(original_signal)
    output_length = total_length // downsample_factor
    downsampled_signal = np.zeros(output_length)

    # BUG FIX: align chunk boundaries to the decimation factor. decimate()
    # returns ceil(len/factor) samples, so with an unaligned chunk_size the
    # per-chunk outputs previously overlapped / shifted in the output array.
    chunk_size = max(downsample_factor, chunk_size - chunk_size % downsample_factor)

    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            # trailing partial second(s) of the last chunk are dropped
            chunk_downsampled = chunk_downsampled[:output_length - out_start]

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal


def upsample_signal(original_signal, original_fs, target_fs):
    """Upsample a signal with FFT-based resampling (scipy.signal.resample).

    Raises:
        ValueError: on non-positive rates or when target_fs <= original_fs.
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if target_fs <= original_fs:
        raise ValueError("目标采样率必须大于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")

    upsample_factor = target_fs / original_fs
    num_output_samples = int(len(original_signal) * upsample_factor)

    upsampled_signal = signal.resample(original_signal, num_output_samples)

    return upsampled_signal


def adjust_sample_rate(signal_data, original_fs, target_fs):
    """Dispatch to down- or up-sampling depending on the rate relation;
    returns the input unchanged when the rates match."""
    if original_fs == target_fs:
        return signal_data
    elif original_fs > target_fs:
        return downsample_signal_fast(signal_data, original_fs, target_fs)
    else:
        return upsample_signal(signal_data, original_fs, target_fs)


@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
    """High-pass by subtracting a moving average (reflect-padded box filter).

    Returns:
        np.ndarray: raw_data minus its `window_size_sec`-second moving mean.
    """
    kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
    filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    convolve_filter_signal = raw_data - filtered
    return convolve_filter_signal


@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    """Zero-phase IIR notch filter (e.g. 50 Hz mains removal)."""
    nyquist = 0.5 * sample_rate
    norm_notch_freq = notch_freq / nyquist
    b, a = signal.iirnotch(norm_notch_freq, quality_factor)
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data

# -- diff --git a/utils/operation_tools.py b/utils/operation_tools.py new file mode 100644 index 0000000..d6d7051 --- /dev/null +++ b/utils/operation_tools.py @@ -0,0 +1,380 @@
import time

from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import yaml
from numpy.ma.core import append  # NOTE(review): unused import, kept to avoid changing the module surface
from scipy.interpolate import PchipInterpolator

from utils.event_map import E2N
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK glyphs in matplotlib
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly

from scipy import ndimage, signal
from functools import wraps


# global runtime configuration
class Config:
    # when True, timing_decorator prints each wrapped call's elapsed time
    time_verbose = False


def timing_decorator():
    """Decorator factory: time the wrapped call and print the elapsed time
    when Config.time_verbose is True (checked at call time)."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            if Config.time_verbose:  # consult the global flag at run time
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒")
            return result
        return wrapper
    return decorator


@timing_decorator()
def read_auto(file_path):
    """Read a 1-D array from .txt (CSV), .npy, or .base64 (int-per-line).

    Raises:
        ValueError: for unsupported file extensions.
    """
    if file_path.suffix == '.txt':
        # read txt via pandas CSV parser, flattened to 1-D
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    elif file_path.suffix == '.npy':
        return np.load(file_path.__str__())
    elif file_path.suffix == '.base64':
        with open(file_path) as f:
            files = f.readlines()

        data = np.array(files, dtype=int)
        return data
    else:
        raise ValueError('这个文件类型不支持,需要自己写读取程序')


def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """Fill 0-gaps shorter than max_gap_sec between consecutive 1-runs.

    Args:
        state_sequence: 0/1 numpy array.
        time_points: timestamp (seconds) per entry of state_sequence.
        max_gap_sec: gaps of at most this duration are merged into 1.

    Returns:
        np.ndarray: merged 0/1 sequence (input left untouched).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    merged_sequence = state_sequence.copy()

    # locate run boundaries via the diff of a 0-padded copy
    transitions = np.diff(np.concatenate([[0], merged_sequence, [0]]))
    state_starts = np.where(transitions == 1)[0]
    state_ends = np.where(transitions == -1)[0] - 1

    # inspect each gap between consecutive 1-runs
    for i in range(len(state_starts) - 1):
        if state_ends[i] < len(time_points) and state_starts[i + 1] < len(time_points):
            gap_duration = time_points[state_starts[i + 1]] - time_points[state_ends[i]]
            if gap_duration <= max_gap_sec:
                merged_sequence[state_ends[i]:state_starts[i + 1]] = 1

    return merged_sequence


def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """Zero out 1-runs shorter than min_duration_sec.

    A run ending at the last sample is credited one extra step (the sequence
    has no explicit end timestamp).

    Returns:
        np.ndarray: filtered 0/1 sequence (input left untouched).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    filtered_sequence = state_sequence.copy()

    transitions = np.diff(np.concatenate([[0], filtered_sequence, [0]]))
    state_starts = np.where(transitions == 1)[0]
    state_ends = np.where(transitions == -1)[0] - 1

    for i in range(len(state_starts)):
        if state_starts[i] < len(time_points) and state_ends[i] < len(time_points):
            duration = time_points[state_ends[i]] - time_points[state_starts[i]]
            if state_ends[i] == len(time_points) - 1:
                # run reaches the last sample: add one step of duration
                duration += time_points[1] - time_points[0] if len(time_points) > 1 else 0

            if duration < min_duration_sec:
                filtered_sequence[state_starts[i]:state_ends[i] + 1] = 0

    return filtered_sequence


@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Apply `func` over sliding windows and expand to a per-second series.

    The signal is reflect-padded by half a window on each side, `func` is
    evaluated every `step_second`, and the result is repeated to one value
    per second of the original signal.

    Args:
        func: callable taking a window (np.ndarray of window samples).
        signal_data: 1-D signal at `sampling_rate` Hz.
        calc_mask: per-OUTPUT-SECOND mask (truthy -> NaN) or None.
        sampling_rate / window_second / step_second: geometry; step defaults
            to the window length (non-overlapping).

    Returns:
        (values_nan, values): per-second series with NaN at masked seconds,
        and a copy with NaNs linearly interpolated. When every second is
        masked, the interpolated copy stays all-NaN (previously np.interp
        raised on an empty support).
    """
    if step_second is None:
        step_second = window_second

    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate

    origin_seconds = len(signal_data) // sampling_rate

    # reflect padding so windows are centered on their step position
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)

    for i in range(num_segments):
        start = int(i * step_length)
        segment = data[start:start + window_length]
        values_nan[i] = func(segment)

    # expand per-step values to per-second resolution
    values_nan = values_nan.repeat(step_second)[:origin_seconds]

    if calc_mask is not None:
        # calc_mask carries one entry per output second
        for i in range(len(values_nan)):
            if calc_mask[i]:
                values_nan[i] = np.nan

        values = values_nan.copy()
        valid_mask = ~np.isnan(values)
        t = np.arange(len(values))
        if valid_mask.any():
            # interpolate across masked (NaN) stretches
            values = np.interp(t, t[valid_mask], values[valid_mask])
        # else: everything masked -> keep all-NaN instead of raising
    else:
        values = values_nan.copy()

    return values_nan, values


def load_dataset_conf(yaml_path):
    """Load a dataset YAML configuration file as a dict."""
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    """Build a per-second 0/1 mask from a DataFrame of start/end (seconds,
    end exclusive) disabled intervals."""
    disable_mask = np.zeros(signal_second, dtype=int)

    for _, row in disable_df.iterrows():
        disable_mask[row["start"]:row["end"]] = 1
    return disable_mask


def generate_event_mask(signal_second: int, event_df, use_correct=True, with_score=True):
    """Build per-second event-type and score masks from an event table.

    Args:
        signal_second: mask length in seconds.
        event_df: DataFrame with (correct_)Start/End/EventsType columns.
        use_correct: read the reviewed correct_* columns instead of raw ones.
        with_score: additionally build a per-second score mask.

    Returns:
        (event_mask, score_mask): int arrays (E2N codes / scores); score_mask
        is None when with_score is False. End second is inclusive.
    """
    event_mask = np.zeros(signal_second, dtype=int)
    score_mask = np.zeros(signal_second, dtype=int) if with_score else None

    if use_correct:
        start_name, end_name, event_type_name = "correct_Start", "correct_End", "correct_EventsType"
    else:
        start_name, end_name, event_type_name = "Start", "End", "Event type"

    # drop rows flagged with start = -1 (invalid / removed events)
    event_df = event_df[event_df[start_name] >= 0]

    for _, row in event_df.iterrows():
        start = row[start_name]
        end = row[end_name] + 1  # end second is inclusive
        event_mask[start:end] = E2N[row[event_type_name]]
        if with_score:
            score_mask[start:end] = row["score"]
    return event_mask, score_mask


def event_mask_2_list(mask, event_true=True):
    """Convert a 0/1 mask into a list of [start, end) index pairs.

    Args:
        mask: 0/1 array.
        event_true: when True, runs of 1 are extracted; when False, runs of 0.

    Returns:
        list[list[int]]: [start, end) pairs, in order.
    """
    if event_true:
        event_2_normal = -1
        normal_2_event = 1
        _append = 0
    else:
        event_2_normal = 1
        normal_2_event = -1
        _append = 1
    mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0]
    mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0]
    event_list = [[start, end] for start, end in zip(mask_start, mask_end)]
    return event_list


def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list:
    """Walk `arr` from `index` in `step` direction, collecting up to `limit`
    values whose mask entry (defaults to arr itself) is not NaN."""
    values = []
    count = 0
    mask = mask if mask is not None else arr
    while count < limit and 0 <= index < len(mask):
        if not np.isnan(mask[index]):
            values.append(arr[index])
            count += 1
        index += step
    return values


def save_process_label(save_path: Path, save_dict: dict):
    """Write a dict of equal-length columns to CSV (no index column)."""
    save_df = pd.DataFrame(save_dict)
    save_df.to_csv(save_path, index=False)


def none_to_nan_mask(mask, ref):
    """Normalize a mask: None -> all-NaN array shaped like ref; otherwise
    replace 0 entries with NaN (result is float)."""
    if mask is None:
        # BUG FIX: np.full_like(ref, np.nan) silently produced garbage for
        # integer ref dtypes (NaN cast to int); force a float result.
        return np.full(np.shape(ref), np.nan, dtype=float)
    else:
        # replace 0 with NaN, keep everything else
        mask = np.where(mask == 0, np.nan, mask)
        return mask


def get_wake_mask(sleep_stage_mask):
    """Per-entry wake mask from stage labels: sleep stages ('N1','N2','N3',
    'REM','R') -> 0, anything else (e.g. 'W') -> 1."""
    wake_mask = np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)
    return wake_mask


def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
    """Flag implausible SpO2 samples: out of [50, 100], sudden jumps larger
    than diff_thresh between neighbors, or NaN.

    Args:
        fs: sampling rate; currently unused, kept for interface stability.

    Returns:
        np.ndarray[bool]: True where the sample is anomalous.
    """
    anomaly = np.zeros(len(spo2), dtype=bool)

    # physiological range
    anomaly |= (spo2 < 50) | (spo2 > 100)

    # sudden jumps
    diff = np.abs(np.diff(spo2, prepend=spo2[0]))
    anomaly |= diff > diff_thresh

    # NaN samples
    anomaly |= np.isnan(spo2)

    return anomaly


def merge_close_anomalies(anomaly, fs, min_gap_duration):
    """Extend anomaly runs over valid gaps shorter than min_gap_duration
    (seconds), returning a new boolean array."""
    min_gap = int(min_gap_duration * fs)
    merged = anomaly.copy()

    i = 0
    n = len(anomaly)

    while i < n:
        if not anomaly[i]:
            i += 1
            continue

        # current anomaly run
        start = i
        while i < n and anomaly[i]:
            i += 1
        end = i

        # measure the valid gap that follows
        j = end
        while j < n and not anomaly[j]:
            j += 1

        if j < n and (j - end) < min_gap:
            merged[end:j] = True

    return merged


def fill_spo2_anomaly(
    spo2_data,
    spo2_fs,
    max_fill_duration,
    min_gap_duration,
):
    """Repair anomalous SpO2 stretches.

    Anomalies are detected, nearby runs merged, then each run is either:
      - set to NaN when longer than max_fill_duration (seconds),
      - forward/backward-filled when only one side has valid data,
      - PCHIP-interpolated between the bracketing valid samples otherwise.

    Returns:
        (spo2, valid_mask): repaired float array and a boolean mask that is
        False where data remains invalid (NaN).
    """
    spo2 = spo2_data.astype(float).copy()
    n = len(spo2)

    anomaly = detect_spo2_anomaly(spo2, spo2_fs)
    anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)

    max_len = int(max_fill_duration * spo2_fs)

    valid_mask = ~anomaly

    i = 0
    while i < n:
        if not anomaly[i]:
            i += 1
            continue

        start = i
        while i < n and anomaly[i]:
            i += 1
        end = i

        seg_len = end - start

        # over-long anomaly: mark unusable
        if seg_len > max_len:
            spo2[start:end] = np.nan
            valid_mask[start:end] = False
            continue

        has_left = start > 0 and valid_mask[start - 1]
        has_right = end < n and valid_mask[end]

        # anomaly at the very start: fill from the right side only
        if not has_left and has_right:
            spo2[start:end] = spo2[end]
            continue

        # anomaly at the very end: fill from the left side only
        if has_left and not has_right:
            spo2[start:end] = spo2[start - 1]
            continue

        # both neighbors valid -> PCHIP between them
        if has_left and has_right:
            x = np.array([start - 1, end])
            y = np.array([spo2[start - 1], spo2[end]])

            interp = PchipInterpolator(x, y)
            spo2[start:end] = interp(np.arange(start, end))
            continue

        # neither side valid (degenerate case)
        spo2[start:end] = np.nan
        valid_mask[start:end] = False

    return spo2, valid_mask
from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage


@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Zero-phase Butterworth filtering via second-order sections.

    Parameters:
        data: 1-D signal.
        _type: 'lowpass', 'bandpass' or 'highpass'.
        low_cut / high_cut: cut-off frequencies in Hz.
            NOTE(review): lowpass uses *low_cut* as its cut-off while
            highpass uses *high_cut* — confirm this convention with callers.
        order: filter order.
        sample_rate: sampling rate in Hz.

    Raises:
        ValueError: for an unknown *_type*.
    """
    if _type == "lowpass":
        sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "bandpass":
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "highpass":
        sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    else:
        # Typo fix in the error message ("fliter" -> "filter").
        raise ValueError("Please choose a type of filter")


def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
    """Zero-phase Bessel filtering (magnitude-normalised design).

    Same parameter convention as :func:`butterworth`; raises ValueError
    for an unknown *_type*.
    """
    if _type == "lowpass":
        b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    elif _type == "bandpass":
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    elif _type == "highpass":
        b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    else:
        # Typo fix in the error message ("fliter" -> "filter").
        raise ValueError("Please choose a type of filter")


@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """Integer-factor downsampling of a long signal, processed in chunks
    to bound memory usage.

    Parameters:
        original_signal: array-like, the raw signal.
        original_fs: original sampling rate (Hz).
        target_fs: target sampling rate (Hz); must divide original_fs.
        chunk_size: samples per processing chunk (default 100000); rounded
            down to a multiple of the decimation factor.

    Returns:
        downsampled_signal: 1-D numpy array of length len(signal) // factor.

    Raises:
        ValueError: for target >= original, non-positive rates, or a
        non-integer decimation factor.
    """
    # Input validation.
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("目标采样率必须小于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")

    # The decimation factor must be an integer.
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("降采样因子必须为整数倍")
    downsample_factor = int(downsample_factor)

    # Bug fix: force chunk boundaries onto multiples of the decimation
    # factor.  scipy's decimate() returns ceil(len/q) samples per chunk,
    # so a misaligned chunk_size made successive chunks drift and
    # overwrite each other in the output array.
    chunk_size = max(downsample_factor, (chunk_size // downsample_factor) * downsample_factor)

    # Total output length.
    total_length = len(original_signal)
    output_length = total_length // downsample_factor

    downsampled_signal = np.zeros(output_length)

    # NOTE(review): the anti-alias filter restarts at every chunk, so small
    # transients can appear at chunk boundaries; acceptable while
    # chunk_size >> filter length, otherwise use an overlap-save scheme.
    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        # Zero-phase IIR decimation of this chunk.
        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        # Place the result at the matching output offset.
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            chunk_downsampled = chunk_downsampled[:output_length - out_start]
            out_end = output_length

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal


@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
    """Remove the slow baseline with a moving-average (boxcar) filter and
    return the residual (input minus its *window_size_sec*-second mean)."""
    kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
    # 'reflect' padding avoids edge droop at both ends of the recording.
    filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    convolve_filter_signal = raw_data - filtered
    return convolve_filter_signal


# Notch (band-stop) filter, typically used to suppress mains interference.
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    """Zero-phase IIR notch at *notch_freq* Hz with the given quality factor."""
    nyquist = 0.5 * sample_rate
    norm_notch_freq = notch_freq / nyquist
    b, a = signal.iirnotch(norm_notch_freq, quality_factor)
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data

def check_split(event_mask, current_start, window_sec, verbose=False):
    """Decide whether the window [current_start, current_start + window_sec)
    is usable for training.

    The window is rejected when the union (OR) of the respiratory movement
    mask and the respiratory low-amplitude mask covers more than 2/3 of it.
    (The original comment said "AND", but the code computes the union.)
    """
    resp_movement_mask = event_mask["Resp_Movement_Label"][current_start : current_start + window_sec]
    resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start : current_start + window_sec]

    # Union of movement and low-amplitude seconds.
    low_move_mask = resp_movement_mask | resp_low_amp_mask
    if low_move_mask.sum() > 2/3 * window_sec:
        if verbose:
            print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3")
        return False
    return True


def resp_split(dataset_config, event_mask, event_list, verbose=False):
    """Cut enable/disable intervals into fixed-length training windows.

    Parameters:
        dataset_config: dict providing "window_sec" and "stride_sec".
        event_mask: dict of per-second label arrays (see check_split).
        event_list: dict with "EnableSegment" / "DisableSegment" interval lists.

    Returns:
        (segment_list, disable_segment_list): lists of (start, end) windows;
        the first holds accepted windows, the second rejected/disabled ones.
    """
    enable_list = event_list["EnableSegment"]
    disable_list = event_list["DisableSegment"]

    window_sec = dataset_config["window_sec"]
    stride_sec = dataset_config["stride_sec"]

    segment_list = []
    disable_segment_list = []

    # Slide a window over every enable interval.  If the leftover tail is at
    # least stride/2 long, emit one extra window ending exactly at
    # enable_end (provided the interval can hold a full window and that
    # window was not already produced by the regular stride).
    for enable_start, enable_end in enable_list:
        current_start = enable_start
        while current_start + window_sec <= enable_end:
            if check_split(event_mask, current_start, window_sec, verbose):
                segment_list.append((current_start, current_start + window_sec))
            else:
                disable_segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # Bug fix: the old tail test required the remainder to still fit a
        # full window, which is impossible once the while-loop has exited,
        # so the tail window was never emitted.  Require instead that the
        # interval itself can hold a window, and skip tails that coincide
        # with a window already produced in-stride.
        tail_start = enable_end - window_sec
        if (
            enable_end - current_start >= stride_sec / 2
            and tail_start >= enable_start
            and (tail_start - enable_start) % stride_sec != 0
        ):
            if check_split(event_mask, tail_start, window_sec, verbose):
                segment_list.append((tail_start, enable_end))
            else:
                disable_segment_list.append((tail_start, enable_end))

    # Same windowing for disable intervals; every window goes straight to
    # the disable list without quality checks.
    for disable_start, disable_end in disable_list:
        current_start = disable_start
        while current_start + window_sec <= disable_end:
            disable_segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # Same tail-window fix as for the enable intervals.
        tail_start = disable_end - window_sec
        if (
            disable_end - current_start >= stride_sec / 2
            and tail_start >= disable_start
            and (tail_start - disable_start) % stride_sec != 0
        ):
            disable_segment_list.append((tail_start, disable_end))

    return segment_list, disable_segment_list


from utils.operation_tools import timing_decorator
import numpy as np
import pandas as pd


@timing_decorator()
def statistic_amplitude_metrics(data, aml_interval=None, time_interval=None):
    """Bucket contiguous valid (non-NaN) stretches of a 1 Hz amplitude
    series by mean amplitude and by duration, and summarise the result.

    Parameters:
        data: 1-D series sampled at 1 Hz; movement regions are NaN-filled.
        aml_interval: amplitude bucket edges, default [200, 500, 1000, 2000].
        time_interval: duration bucket edges in seconds,
            default [60, 300, 1800, 3600].

    Returns:
        summary: dict of aggregate statistics.
        details: tuple (confusion_matrix, segment_count_matrix,
            confusion_matrix_percent, valid_signal_length, total_duration,
            time_labels, amp_labels).
        (Docstring fix: the code returns summary first; the old docstring
        claimed the opposite order.)
    """
    if aml_interval is None:
        aml_interval = [200, 500, 1000, 2000]

    if time_interval is None:
        time_interval = [60, 300, 1800, 3600]
    # Normalise the input.
    if not isinstance(data, np.ndarray):
        data = np.array(data)

    # Whole-recording duration, NaN seconds included.
    total_duration = len(data)

    # Human-readable bucket labels for both axes.
    amp_labels = [f"0-{aml_interval[0]}"]
    for i in range(len(aml_interval) - 1):
        amp_labels.append(f"{aml_interval[i]}-{aml_interval[i + 1]}")
    amp_labels.append(f"{aml_interval[-1]}+")

    time_labels = [f"0-{time_interval[0]}"]
    for i in range(len(time_interval) - 1):
        time_labels.append(f"{time_interval[i]}-{time_interval[i + 1]}")
    time_labels.append(f"{time_interval[-1]}+")

    # Accumulated seconds and segment counts per (amplitude, duration) cell.
    result_matrix = np.zeros((len(amp_labels), len(time_labels)))
    segment_count_matrix = np.zeros((len(amp_labels), len(time_labels)))

    # Total number of valid (non-NaN) samples.
    valid_signal_length = np.sum(~np.isnan(data))

    # Pad with NaN so segments touching either end still produce edges.
    signal_padded = np.concatenate(([np.nan], data, [np.nan]))
    diff = np.diff(np.isnan(signal_padded).astype(int))

    # Segment starts: NaN -> value transitions.
    segment_starts = np.where(diff == -1)[0]
    # Segment ends: value -> NaN transitions.
    segment_ends = np.where(diff == 1)[0]

    # Classify every contiguous valid segment into the matrix.
    for start, end in zip(segment_starts, segment_ends):
        segment = data[start:end]
        duration = end - start  # seconds (1 Hz sampling)
        mean_amplitude = np.nanmean(segment)

        # Amplitude bucket (lower edges inclusive).
        if mean_amplitude <= aml_interval[0]:
            amp_idx = 0
        elif mean_amplitude > aml_interval[-1]:
            amp_idx = len(aml_interval)
        else:
            amp_idx = np.searchsorted(aml_interval, mean_amplitude)

        # Duration bucket.
        if duration <= time_interval[0]:
            time_idx = 0
        elif duration > time_interval[-1]:
            time_idx = len(time_interval)
        else:
            time_idx = np.searchsorted(time_interval, duration)

        # Accumulate seconds and segment count for this cell.
        result_matrix[amp_idx, time_idx] += duration
        segment_count_matrix[amp_idx, time_idx] += 1

    # DataFrame view for reporting and downstream processing.
    confusion_matrix = pd.DataFrame(result_matrix, index=amp_labels, columns=time_labels)

    # Row totals; the '总计' column is intentionally part of the output.
    confusion_matrix['总计'] = confusion_matrix.sum(axis=1)
    row_totals = confusion_matrix['总计'].copy()

    # Percentages relative to the full recording duration.
    confusion_matrix_percent = confusion_matrix.div(total_duration) * 100

    # Aggregate statistics.
    summary = {
        'total_duration': total_duration,
        'total_valid_signal': valid_signal_length,
        'amplitude_distribution': row_totals.to_dict(),
        'amplitude_percent': row_totals.div(total_duration) * 100,
        'time_distribution': confusion_matrix.sum(axis=0).to_dict(),
        'time_percent': confusion_matrix.sum(axis=0).div(total_duration) * 100
    }

    return summary, (confusion_matrix, segment_count_matrix, confusion_matrix_percent, valid_signal_length,
                     total_duration, time_labels, amp_labels)