feat: Add utility functions for signal processing and event mapping
- Created a new module `utils/__init__.py` to consolidate utility imports.
- Added `event_map.py` for mapping apnea event types to numerical values and colors.
- Implemented various filtering functions in `filter_func.py`, including Butterworth, Bessel, downsampling, and notch filters.
- Developed `operation_tools.py` for dataset configuration loading, event mask generation, and signal processing utilities.
- Introduced `split_method.py` for segmenting data based on movement and amplitude criteria.
- Added `statistics_metrics.py` for calculating amplitude metrics and generating confusion matrices.
- Included a new Excel file for additional data storage.
Commit: 8ee5980906

---

**File: AGENT_BACKGROUND_CN.md** (new, 470 lines)
# Project Background (Agent Quick-Start)

## 1. Purpose

This document helps an agent new to the repository build a mental model quickly. It answers the following questions:

- What does this repository do?
- What is the core processing pipeline?
- Where should you start reading the code?
- What do the input data and output artifacts look like?
- Which modules are usable, and which are still placeholders or half-finished?
---

## 2. The Project in One Sentence

`DataPrepare` is a sleep-apnea data preprocessing repository. It turns raw BCG / PSG signals and apnea labels into trainable datasets.

The current main pipeline has two steps:

1. In `event_mask_process/`, turn the raw annotations and rule-based detection results into per-second label masks.
2. In `dataset_builder/`, cut fixed-length windows according to these masks and save the signals and window indices as `npz` datasets.
---

## 3. Repository Overview

| Directory | Role | Notes |
| --- | --- | --- |
| `dataset_config/` | YAML configs per dataset | Paths, sampling rates, thresholds, windowing parameters |
| `event_mask_process/` | Generates per-second label masks | Main entry point #1 |
| `dataset_builder/` | Cuts signals into windows and saves datasets | Main entry point #2 |
| `signal_method/` | Signal preprocessing, rule-based detection, feature computation | Core of the rule logic |
| `utils/` | File reading, label conversion, windowing, filtering, general helpers | Low-level support |
| `draw_tools/` | Whole-night signal plots, segment plots, statistics plots | Debugging and visualization |
| `dataset_tools/` | Helper scripts | Paired copying, SHHS annotation checks |
| `output/` | Sample intermediate results | A few processed examples ship with the repo |
---

## 4. Main Pipelines

### 4.1 HYS (BCG pipeline)

Relevant files:

- `event_mask_process/HYS_process.py`
- `dataset_builder/HYS_dataset.py`

Processing steps:

1. Read the raw synchronized BCG signal from `root_path/OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt`.
2. Read the manually corrected apnea labels from `root_path/Label/<id>/SA Label_corrected.csv`.
3. Apply a 50 Hz notch filter to the raw BCG, then split it into:
   - the respiration component `resp_data`
   - the BCG component `bcg_data`
4. Run rule-based detection according to the config and produce per-second masks:
   - `Resp_LowAmp_Label`
   - `Resp_Movement_Label`
   - `Resp_AmpChange_Label`
   - `BCG_LowAmp_Label`
   - `BCG_Movement_Label`
   - `BCG_AmpChange_Label`
   - `Disable_Label` (from `排除区间.xlsx`, the exclusion-interval spreadsheet)
   - `SA_Label` / `SA_Score` (from the manual labels)
5. Save the results to `output/HYS/<id>/<id>_Processed_Labels.csv`.
6. During dataset building, read the per-second labels and cut windows by `window_sec` and `stride_sec`.
7. Save the processed signals (`Signals/*.npz`), window indices (`Segments_List/*.npz`), and a copy of the labels (`Labels/*.csv`).
### 4.2 HYS_PSG (PSG pipeline)

Relevant files:

- `event_mask_process/HYS_PSG_process.py`
- `dataset_builder/HYS_PSG_dataset.py`

Processing steps:

1. Read the PSG channels from `root_path/PSG_Aligned/<id>/`, including:
   - `Rpeak`
   - `ECG_Sync`
   - `Effort Tho`
   - `Effort Abd`
   - `Flow P`
   - `Flow T`
   - `SpO2`
   - `5_class`
2. Read `SA Label_Sync.csv` from the same directory.
3. In `HYS_PSG_process.py`, generate a simplified set of per-second labels:
   - `SA_Label`
   - `SA_Score`
   - `Disable_Label` (mainly derived from wake periods in the sleep staging)
   - the remaining `Resp_*` / `BCG_*` masks are currently all zero
4. In `HYS_PSG_dataset.py`, resample the effort belts, flows, SpO2, and RRI to a common rate and align their lengths.
5. Fill or blank out SpO2 anomalies, cut windows, and save the PSG dataset.
### 4.3 Shared Design

For both HYS and HYS_PSG the core design is the same:

1. First turn the whole-night recording into per-second masks.
2. Then extract training segments from the masks and the windowing rules.

So if you later need to change the "usability" logic or the windowing logic, look first at:

- `event_mask_process/`
- `utils/split_method.py`
---

## 5. Data and Directory Conventions

### 5.1 External data directories

The repository itself contains no raw data. The actual data directories are absolute paths specified in the YAML configs, for example:

- `/mnt/disk_wd/marques_dataset/DataCombine2023/HYS`
- `/mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1`

So when running in a new environment, the first thing to change is usually not the code but the absolute paths in `dataset_config/*.yaml`.
### 5.2 HYS data directory layout

The code assumes the external directory contains at least:

- `OrgBCG_Aligned/<id>/OrgBCG_Sync_<fs>.txt`
- `PSG_Aligned/<id>/...`
- `Label/<id>/SA Label_corrected.csv`

### 5.3 HYS_PSG data directory layout

The code assumes the files under `PSG_Aligned/<id>/` follow these naming patterns:

- `Rpeak*.txt`
- `ECG_Sync*.txt`
- `Effort Tho*.txt`
- `Effort Abd*.txt`
- `Flow P*.txt`
- `Flow T*.txt`
- `SpO2*.txt`
- `5_class*.txt`
- `SA Label_Sync.csv`

### 5.4 Intermediate results

The in-repo `output/` directory holds intermediate labels and plots, not the final training dataset.
The final dataset is normally written to the external directory pointed to by `dataset_save_path` in the YAML.
---

## 6. Core Entry Scripts

### 6.1 Main entry points

| Task | Script |
| --- | --- |
| HYS per-second label generation | `event_mask_process/HYS_process.py` |
| HYS dataset building | `dataset_builder/HYS_dataset.py` |
| HYS_PSG per-second label generation | `event_mask_process/HYS_PSG_process.py` |
| HYS_PSG dataset building | `dataset_builder/HYS_PSG_dataset.py` |
### 6.2 Helper scripts

| Script | Role |
| --- | --- |
| `dataset_tools/resp_pair_copy.py` | Copies raw HYS/ZD5Y BCG and PSG files into a paired dataset |
| `dataset_tools/shhs_annotations_check.py` | Tallies the event-type combinations in the SHHS XML annotations |
| `event_mask_process/SHHS1_process.py` | SHHS1 entry-point placeholder, not yet implemented |
### 6.3 What the entry scripts have in common

- Most scripts have no command-line argument interface.
- The config file path is usually hard-coded inside `if __name__ == '__main__':`.
- Execution depends on the absolute paths in the YAML.

So before an agent "runs a script", it should confirm:

1. The external data directory exists on the current machine.
2. The YAML paths match the current environment.
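Since the scripts all start by reading a YAML config, a minimal sketch of such a loader may help orient a first read. This assumes standard PyYAML; the repo's actual `utils.load_dataset_conf` may differ in details, and the keys below are only illustrative:

```python
import os
import tempfile

import yaml  # PyYAML, one of the repository's inferred dependencies


def load_dataset_conf(yaml_path):
    # Sketch of a YAML config loader; the repo's utils.load_dataset_conf may differ.
    with open(yaml_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# Demo with a tiny config resembling the keys described in this document
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("root_path: /mnt/data/HYS\n"
            "dataset_config:\n  window_sec: 30\n  stride_sec: 10\n")
    path = f.name

conf = load_dataset_conf(path)
os.unlink(path)
```

Scripts then pull paths and windowing parameters out of `conf` directly, which is why fixing the YAML is usually the whole "porting" effort.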
---

## 7. Key Modules

### 7.1 `utils/`

#### `utils/HYS_FileReader.py`

Reads all of the input file formats; one of the most important low-level modules in the repository:

- `read_signal_txt`: reads a single-channel txt and infers `fs` from the sampling rate in the file name
- `read_label_csv`: reads the manually corrected HYS labels
- `read_raw_psg_label`: reads the raw synchronized PSG labels
- `read_disable_excel`: reads `排除区间.xlsx` (the exclusion-interval spreadsheet)
- `read_mask_execl`: reads a processed per-second label CSV and builds event segment lists
- `read_psg_channel`: reads multi-channel data from a PSG folder by channel name
#### `utils/operation_tools.py`

Handles label conversion, segment extraction, and general processing:

- `load_dataset_conf`: reads the YAML config
- `generate_event_mask`: converts an event table into a per-second `SA_Label`
- `generate_disable_mask`: converts the Excel exclusion intervals into a per-second `Disable_Label`
- `event_mask_2_list`: converts a 0/1 mask into a list of `[start, end]` pairs
- `merge_short_gaps` / `remove_short_durations`: duration-based post-processing of per-second masks
- `fill_spo2_anomaly`: repairs anomalous SpO2 segments
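The mask-to-segment conversion is the pivot of the whole design, so here is a minimal sketch of what `event_mask_2_list` plausibly does. This is an assumed reading, not the repo's implementation; the boundary convention (`[start, end)` here) may differ:

```python
import numpy as np


def event_mask_2_list(mask):
    """Turn a per-second 0/1 mask into a list of [start, end) index pairs.

    Sketch only; the repo's utils.event_mask_2_list may use a different
    boundary convention.
    """
    mask = np.asarray(mask).astype(int)
    # Pad with zeros so runs touching either end of the recording are closed
    diff = np.diff(np.concatenate(([0], mask, [0])))
    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0]
    return [[int(s), int(e)] for s, e in zip(starts, ends)]


# 0 0 1 1 1 0 1 -> one run covering seconds 2..4, one covering second 6
segments = event_mask_2_list([0, 0, 1, 1, 1, 0, 1])
```

The same helper, applied to `1 - Disable_Label`, yields the `EnableSegment` list that the windowing code consumes.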
#### `utils/split_method.py`

The actual windowing rules live here:

- By default, slide a window of `window_sec` with step `stride_sec`
- Only generate usable windows inside `EnableSegment`
- If `Resp_Movement_Label | Resp_LowAmp_Label` covers more than 2/3 of a window's duration, the window is moved to `disable_segment_list`
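The three rules above can be sketched as follows. This is a toy restatement of the described behaviour under assumed semantics, not the repo's `split_method` code (segment boundaries taken as `[start, end)` in seconds):

```python
def sliding_windows(enable_segments, bad_mask, window_sec, stride_sec, bad_ratio=2 / 3):
    """Slide a window inside each enabled segment; windows whose
    movement/low-amplitude seconds exceed bad_ratio of the window length
    go to the disable list. Sketch only."""
    segment_list, disable_segment_list = [], []
    for start, end in enable_segments:
        for s in range(start, end - window_sec + 1, stride_sec):
            e = s + window_sec
            if sum(bad_mask[s:e]) > bad_ratio * window_sec:
                disable_segment_list.append([s, e])
            else:
                segment_list.append([s, e])
    return segment_list, disable_segment_list


# 20-second enabled segment; seconds 0-9 flagged as movement / low amplitude
bad = [1] * 10 + [0] * 10
good, bad_windows = sliding_windows([[0, 20]], bad, window_sec=10, stride_sec=5)
```

The window at second 0 is fully flagged and lands in the disable list; the windows starting at seconds 5 and 10 stay usable.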
#### `utils/filter_func.py`

Filtering and sample-rate utilities:

- Butterworth / Bessel filters
- Notch filtering
- Integer-factor downsampling
- Automatic up/downsampling
- Moving-average detrending
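As a rough sketch of the kind of helpers this module provides (function names and parameters here are illustrative, not the repo's API), a zero-phase Butterworth band-pass and a 50 Hz mains notch in SciPy look like:

```python
import numpy as np
from scipy import signal


def butter_bandpass(data, fs, low, high, order=4):
    # Zero-phase Butterworth band-pass (illustrative helper, not the repo's API)
    sos = signal.butter(order, [low, high], btype="bandpass", fs=fs, output="sos")
    return signal.sosfiltfilt(sos, data)


def notch_50hz(data, fs, q=30.0):
    # 50 Hz mains notch, the kind applied to the raw BCG before splitting
    b, a = signal.iirnotch(50.0, q, fs=fs)
    return signal.filtfilt(b, a, data)


fs = 1000
t = np.arange(fs) / fs
# 2 Hz "signal" plus 50 Hz mains interference
x = np.sin(2 * np.pi * 2 * t) + 0.5 * np.sin(2 * np.pi * 50 * t)
clean = notch_50hz(x, fs)
bp = butter_bandpass(x, fs, 1.0, 5.0)
```

Zero-phase filtering (`filtfilt` / `sosfiltfilt`) matters here because phase distortion would shift event boundaries relative to the per-second labels.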
### 7.2 `signal_method/`

#### `signal_method/signal_process.py`

Signal preprocessing:

- `signal_filter_split`: splits the raw OrgBCG signal into a respiration component and a BCG component
- `psg_effort_filter`: processes the PSG effort-belt / flow signals
- `rpeak2hr` / `rpeak2rri_interpolation`: derive HR / RRI from R-peaks
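The R-peak to evenly-sampled RRI step is worth a sketch, since the builder later resamples RRI alongside the other channels. This is an assumed reading of what `rpeak2rri_interpolation` does (peak indices in, uniformly sampled interval series out); the repo's version may differ:

```python
import numpy as np


def rpeak_to_rri(rpeak_indices, ecg_fs, rri_fs=4):
    """Convert R-peak sample indices into an evenly sampled RRI series.

    Sketch of the assumed semantics of signal_method.rpeak2rri_interpolation.
    """
    t_peaks = np.asarray(rpeak_indices) / ecg_fs   # peak times in seconds
    rri = np.diff(t_peaks)                         # inter-beat intervals (s)
    t_rri = t_peaks[1:]                            # stamp each RRI at the later beat
    t_uniform = np.arange(t_rri[0], t_rri[-1], 1 / rri_fs)
    return np.interp(t_uniform, t_rri, rri)        # linear interpolation onto a grid


# 60 bpm ECG at 100 Hz: one peak every 100 samples -> RRI of 1.0 s everywhere
rri = rpeak_to_rri(np.arange(0, 1000, 100), ecg_fs=100, rri_fs=4)
```

Interpolating onto a fixed grid is what lets RRI be trimmed and windowed with the same indices as the effort and flow channels.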
#### `signal_method/rule_base_event.py`

The main rule-based detection logic:

- `detect_movement`: detects body movement from sliding-window standard deviation and local amplitude comparison
- `movement_revise`: second-pass correction of the movement mask
- `detect_low_amplitude_signal`: detects low-amplitude regions
- `position_based_sleep_recognition_v2/v3`: marks posture/amplitude-change segments from amplitude changes around movements
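To make the "sliding-window standard deviation" idea concrete, here is a toy version of movement detection. This is only an assumed reading of `detect_movement`; the real function also does local amplitude comparison, and the threshold logic here is illustrative:

```python
import numpy as np


def detect_movement(x, fs, win_sec=2, ratio=3.0):
    """Mark a second as movement when its windowed std exceeds `ratio` times
    the recording-wide median std. Toy sketch, not the repo's detector."""
    n_sec = len(x) // fs
    stds = np.array([np.std(x[i * fs:(i + win_sec) * fs]) for i in range(n_sec)])
    thresh = ratio * np.median(stds)
    return (stds > thresh).astype(int)


rng = np.random.default_rng(0)
x = rng.normal(0, 1, 30 * 10)   # 30 s of calm signal at 10 Hz
x[100:150] *= 20                # a burst of body movement around seconds 10-15
mask = detect_movement(x, fs=10)
```

Using the whole-night median as the baseline makes the threshold robust to slow amplitude drift, which is the same concern the `*_AmpChange_Label` masks address.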
#### `signal_method/normalize_method.py`

`normalize_resp_signal_by_segment` applies per-segment z-score normalization over the "usable" segments.
For HYS the segments are usually the complement of `Resp_AmpChange_Label`, which reduces the effect of amplitude drift across the night.
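Per-segment z-scoring can be sketched as below, under the assumption that segments are `[start_sec, end_sec)` pairs; the repo's `normalize_resp_signal_by_segment` may handle edges and empty segments differently:

```python
import numpy as np


def normalize_by_segment(x, fs, segments):
    """Z-score each [start_sec, end_sec) segment independently.

    Sketch of the assumed behaviour of normalize_resp_signal_by_segment.
    """
    out = np.asarray(x, dtype=float).copy()
    for start_sec, end_sec in segments:
        s, e = int(start_sec * fs), int(end_sec * fs)
        seg = out[s:e]
        std = seg.std()
        if std > 0:  # leave flat segments untouched to avoid division by zero
            out[s:e] = (seg - seg.mean()) / std
    return out


# First 10 s have 10x the amplitude of the last 10 s (fs = 20 Hz)
x = np.concatenate([np.sin(np.linspace(0, 20, 200)) * 5,
                    np.sin(np.linspace(0, 20, 200)) * 0.5])
y = normalize_by_segment(x, fs=20, segments=[[0, 10], [10, 20]])
```

After normalization both halves have unit variance, so a later amplitude-based rule or model sees comparable scales across the night.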
### 7.3 `draw_tools/`

Mainly for manually inspecting processing quality:

- `draw_signal_with_mask`: plots the whole-night raw HYS signal plus the rule masks
- `draw_psg_signal`: plots the whole-night HYS_PSG PSG signals
- `draw_psg_label` / `draw_psg_bcg_label`: exports per-window segment plots

### 7.4 `dataset_builder/`

Its core job is to turn "whole-night recording + per-second masks" into training samples:

- Save the processed multi-channel signals as `npz`
- Save the window start/end lists as `npz`
- Copy the label CSVs into the dataset directory
---

## 8. Label, Channel, and Encoding Conventions

### 8.1 Apnea event codes

Defined in `utils/event_map.py`:

| Event | Code |
| --- | --- |
| `Hypopnea` | 1 |
| `Central apnea` | 2 |
| `Obstructive apnea` | 3 |
| `Mixed apnea` | 4 |
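Reconstructed from the tables in this section, the two mappings are plain dicts. `N2Chn` is the name actually imported from `utils` elsewhere in this commit; the event-map variable name here (`EVENT2N`) is illustrative:

```python
# Sketch of the mappings defined in utils/event_map.py, reconstructed from
# the tables in this document. EVENT2N is an illustrative name; N2Chn is the
# identifier imported by dataset_builder/HYS_PSG_dataset.py.
EVENT2N = {
    "Hypopnea": 1,
    "Central apnea": 2,
    "Obstructive apnea": 3,
    "Mixed apnea": 4,
}

N2Chn = {
    1: "Rpeak", 2: "ECG_Sync", 3: "Effort Tho", 4: "Effort Abd",
    5: "Flow P", 6: "Flow T", 7: "SpO2", 8: "5_class",
}
```

The channel numbers are what `read_psg_channel` receives (the `[1, 2, ..., 8]` list in the PSG dataset builder), which it resolves to file-name patterns via this map.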
### 8.2 PSG channel number mapping

Also defined in `utils/event_map.py`:

| Number | Channel |
| --- | --- |
| 1 | `Rpeak` |
| 2 | `ECG_Sync` |
| 3 | `Effort Tho` |
| 4 | `Effort Abd` |
| 5 | `Flow P` |
| 6 | `Flow T` |
| 7 | `SpO2` |
| 8 | `5_class` |

### 8.3 Sleep stage codes

`5_class` is converted to integers on read:

| Stage | Code |
| --- | --- |
| `N3` | 1 |
| `N2` | 2 |
| `N1` | 3 |
| `R` | 4 |
| `W` | 5 |
### 8.4 Per-second label CSV columns

The core columns of `Processed_Labels.csv` are:

- `Second`
- `SA_Label`
- `SA_Score`
- `Disable_Label`
- `Resp_LowAmp_Label`
- `Resp_Movement_Label`
- `Resp_AmpChange_Label`
- `BCG_LowAmp_Label`
- `BCG_Movement_Label`
- `BCG_AmpChange_Label`

Examples in the repo:

- `output/HYS/220/220_Processed_Labels.csv`
- `output/HYS_PSG/220/220_Processed_Labels.csv`
---

## 9. Main Config Files

### 9.1 HYS

`dataset_config/HYS_config.yaml` mainly controls:

- The list of sample IDs
- The raw data root directory
- The directory for intermediate labels
- Filtering and downsampling parameters for respiration / BCG
- Thresholds for low-amplitude, movement, and amplitude-change detection
- Dataset window length and stride
### 9.2 HYS_PSG

`dataset_config/HYS_PSG_config.yaml` mainly controls:

- The list of sample IDs
- The common target sampling rate `target_fs`
- Effort-belt / flow filtering parameters
- SpO2 anomaly-filling parameters
- Dataset output paths

### 9.3 Other configs

- `dataset_config/ZD5Y_config.yaml`: another set of BCG rule configs
- `dataset_config/SHHS1_config.yaml`: reserved for SHHS1
- `dataset_config/RESP_PAIR_HYS_config.yaml`: paired raw-data copying for HYS
- `dataset_config/RESP_PAIR_ZD5Y_config.yaml`: paired raw-data copying for ZD5Y
---

## 10. Outputs

### 10.1 Intermediate outputs

Typical locations:

- `output/HYS/<id>/<id>_Processed_Labels.csv`
- `output/HYS/<id>/<id>_Signal_Plots.png`
- `output/HYS_PSG/<id>/<id>_Processed_Labels.csv`
- `output/HYS_PSG/<id>/<id>_Signal_Plots.png`
- `output/HYS_PSG/<id>/<id>_Signal_Plots_fill.png`

### 10.2 Final dataset outputs

Written by `dataset_builder/*` to the external directory specified in the YAML. The usual layout is:

- `Signals/`
- `Segments_List/`
- `Labels/`
For HYS, `Signals/*.npz` mainly stores:

- `bcg_signal_notch`
- `bcg_signal`
- `resp_signal`

For HYS_PSG, `Signals/*.npz` mainly stores:

- `Effort Tho`
- `Effort Abd`
- `Effort`
- `Flow P`
- `Flow T`
- `SpO2`
- `HR`
- `RRI`
- `5_class`

`Segments_List/*.npz` mainly stores:

- `segment_list`
- `disable_segment_list`
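A downstream consumer would slice training windows out of these arrays using the saved second-index pairs. The sketch below builds tiny synthetic stand-ins for the two `npz` files (key names from this document, data and file names invented) and shows the slicing:

```python
import tempfile
from pathlib import Path

import numpy as np

# Synthetic stand-ins for Signals/*.npz and Segments_List/*.npz
tmp = Path(tempfile.mkdtemp())
np.savez_compressed(tmp / "demo_signals.npz",
                    SpO2=np.arange(600.0))            # 6 s of "SpO2" at 100 Hz
np.savez_compressed(tmp / "demo_segments.npz",
                    segment_list=np.array([[0, 2], [2, 4]]),   # [start_sec, end_sec)
                    disable_segment_list=np.array([[4, 6]]))

fs = 100  # the common target_fs
sig = np.load(tmp / "demo_signals.npz")["SpO2"]
seg = np.load(tmp / "demo_segments.npz")["segment_list"]

# Convert second indices to sample indices and cut the windows
windows = [sig[s * fs:e * fs] for s, e in seg]
```

Note that the real HYS_PSG `Signals/*.npz` stores dict-valued entries (the per-channel metadata dicts saved via `**psg_data`), so loading it requires `np.load(..., allow_pickle=True)` and an extra `.item()` per key.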
---

## 11. Suggested Reading Order

If this is your first contact with the repository, read in this order:

1. `dataset_config/HYS_config.yaml`
2. `event_mask_process/HYS_process.py`
3. `utils/operation_tools.py`
4. `utils/split_method.py`
5. `dataset_builder/HYS_dataset.py`
6. `signal_method/signal_process.py`
7. `signal_method/rule_base_event.py`
8. `utils/HYS_FileReader.py`

For the PSG pipeline, continue with:

1. `dataset_config/HYS_PSG_config.yaml`
2. `event_mask_process/HYS_PSG_process.py`
3. `dataset_builder/HYS_PSG_dataset.py`
4. `draw_tools/draw_label.py`
---

## 12. Where to Start for Common Changes

For different kinds of changes, start with these files:

| Need | Look first at |
| --- | --- |
| Adjust thresholds, sampling rates, window lengths | `dataset_config/*.yaml` |
| Change per-second label generation | `event_mask_process/*.py` |
| Change movement / low-amplitude / amplitude-change rules | `signal_method/rule_base_event.py` |
| Change filtering and resampling | `signal_method/signal_process.py`, `utils/filter_func.py` |
| Change windowing rules | `utils/split_method.py` |
| Change input file parsing | `utils/HYS_FileReader.py` |
| Change event codes or channel mappings | `utils/event_map.py` |
| Change plot styling | `draw_tools/*.py` |
---

## 13. Current Implementation Status and Caveats

These points matter for agents:

1. There is no dependency manifest (no `requirements.txt` / `pyproject.toml`); dependencies must be inferred from the imports. They currently include at least `numpy`, `pandas`, `scipy`, `matplotlib`, `seaborn`, `yaml`, `tqdm`, `rich`, `polars`, `lxml`, and `mne`.
2. Most scripts depend on absolute paths; when migrating environments, change the YAML first.
3. `event_mask_process/SHHS1_process.py` is essentially empty, more placeholder than implementation.
4. `event_mask_process/HYS_PSG_process.py` is a "simplified label generator": it only uses the SA labels and sleep stages, and the `Resp_*` / `BCG_*` masks are not actually implemented yet.
5. `dataset_builder/HYS_dataset.py` keeps the segment-visualization entry point, but the actual drawing calls are commented out, so no segment plots are exported by default.
6. `utils/signal_process.py` and `utils/filter_func.py` partially duplicate each other; the versions exported by `utils/__init__.py` and widely used are the ones in `filter_func.py`.
7. `README.md` is currently very sparse; the real project logic has to be understood from the source.
---

## 14. The Most Valuable Takeaways

If you can only remember a few things, remember these:

1. This is a "per-second masks first, fixed windows second" data-preparation repository.
2. The HYS pipeline is more complete than HYS_PSG; the rule detection mainly serves HYS.
3. The windowing logic is centralized in `utils/split_method.py`, not scattered across the builders.
4. The raw data is not in the repository; the repo holds only code and some sample intermediate results.
5. Most runtime problems are not logic bugs but mismatched paths, file names, sampling rates, or external data layout.
---

**File: README.md** (new, 191 lines)
# DataPrepare

`DataPrepare` is a preprocessing repository for sleep-apnea data. It turns raw BCG / PSG signals and event labels into trainable datasets.

The current main pipeline is:

1. Generate per-second label masks
2. Cut the data into fixed-length windows
3. Export visualizations for manual inspection

See [AGENT_BACKGROUND_CN.md](./AGENT_BACKGROUND_CN.md) for a more detailed project background.
## What the Project Does

The repository serves two kinds of data:

- `HYS`: centered on `OrgBCG`; combines the manually corrected apnea labels, extracts the respiration and BCG components, and generates per-second usability masks
- `HYS_PSG`: centered on multi-channel PSG signals; organizes the effort belts, flows, SpO2, RRI, sleep stages, and synchronized apnea labels

Both pipelines share the same design:

1. First turn the whole-night recording into per-second labels
2. Then cut fixed-length windows according to the labels
## Repository Layout

| Directory | Role |
| --- | --- |
| `dataset_config/` | Dataset config files: paths, sampling rates, thresholds, windowing parameters |
| `event_mask_process/` | Per-second label mask generation scripts |
| `dataset_builder/` | Dataset slicing and saving scripts |
| `signal_method/` | Signal preprocessing, rule-based detection, feature computation |
| `utils/` | General helpers: file reading, label conversion, windowing, filtering |
| `draw_tools/` | Whole-night, segment, and statistics plots |
| `dataset_tools/` | Helper scripts, e.g. paired data copying and SHHS annotation checks |
| `output/` | Sample intermediate results shipped with the repo |
## Core Pipelines

### HYS

Scripts:

- `event_mask_process/HYS_process.py`
- `dataset_builder/HYS_dataset.py`

Outline:

1. Read `OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt`
2. Read `Label/<id>/SA Label_corrected.csv`
3. Notch-filter the raw signal and extract the respiration and BCG components
4. Detect low amplitude, movement, and amplitude changes; combine with `排除区间.xlsx` (the exclusion-interval spreadsheet) to produce per-second labels
5. Save `Processed_Labels.csv`
6. Cut training windows by `window_sec` and `stride_sec` and save them as `npz`
### HYS_PSG

Scripts:

- `event_mask_process/HYS_PSG_process.py`
- `dataset_builder/HYS_PSG_dataset.py`

Outline:

1. Read the multi-channel signals under `PSG_Aligned/<id>/`
2. Read `SA Label_Sync.csv`
3. Generate per-second labels from the synchronized labels and sleep stages
4. Resample the effort belts, flows, SpO2, and RRI to a common rate and align their lengths
5. Cut windows and save the PSG dataset
## Quick Start

### 1. Check the configs first

All main scripts depend on absolute paths in `dataset_config/*.yaml`.
On a new machine or with a new data directory, update these config files first:

- `dataset_config/HYS_config.yaml`
- `dataset_config/HYS_PSG_config.yaml`
- `dataset_config/ZD5Y_config.yaml`
- `dataset_config/SHHS1_config.yaml`

Key fields to check:

- `root_path`
- `mask_save_path` or `save_path`
- `dataset_save_path`
- `dataset_visual_path`
- `select_ids`
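As a hypothetical fragment showing how these keys relate (values invented for illustration; check the actual `dataset_config/*.yaml` files for the real structure and nesting):

```yaml
# Illustrative only - mirrors the key names listed above, not a real config
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: ./output/HYS
select_ids: [219, 220]
dataset_config:
  dataset_save_path: /mnt/disk_wd/datasets/HYS_dataset
  dataset_visual_path: /mnt/disk_wd/datasets/HYS_dataset_visual
  window_sec: 30
  stride_sec: 10
```

The `dataset_config:` nesting matches how `dataset_builder/HYS_PSG_dataset.py` reads `conf["dataset_config"]["dataset_save_path"]`.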
### 2. Generate per-second labels

HYS:

```bash
python event_mask_process/HYS_process.py
```

HYS_PSG:

```bash
python event_mask_process/HYS_PSG_process.py
```

This normally produces the following under `output/`:

- `*_Processed_Labels.csv`
- `*_Signal_Plots.png`
### 3. Build the dataset

HYS:

```bash
python dataset_builder/HYS_dataset.py
```

HYS_PSG:

```bash
python dataset_builder/HYS_PSG_dataset.py
```

The output directory is controlled by `dataset_save_path` in the corresponding YAML and usually contains:

- `Signals/`
- `Segments_List/`
- `Labels/`
### 4. Visualization and helper tools

Copy the raw data as paired datasets:

```bash
python dataset_tools/resp_pair_copy.py
```

Check the SHHS XML annotations:

```bash
python dataset_tools/shhs_annotations_check.py
```
## Input and Output Conventions

### Input

The repository contains no raw data; the raw data directories are specified in the YAML. The code assumes the external data directory contains at least:

- `OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt`
- `PSG_Aligned/<id>/...`
- `Label/<id>/SA Label_corrected.csv`
- `PSG_Aligned/<id>/SA Label_Sync.csv`

### Intermediate output

The in-repo `output/` holds intermediate labels and plot samples, not the final training dataset. For example:

- `output/HYS/<id>/<id>_Processed_Labels.csv`
- `output/HYS_PSG/<id>/<id>_Processed_Labels.csv`

### Final output

Written by `dataset_builder/` to the external directory configured in the YAML. The typical layout is:

- `Signals/*.npz`
- `Segments_List/*.npz`
- `Labels/*.csv`
## Key Files

If you are reading this repository for the first time, start with:

1. `dataset_config/HYS_config.yaml`
2. `event_mask_process/HYS_process.py`
3. `utils/operation_tools.py`
4. `utils/split_method.py`
5. `dataset_builder/HYS_dataset.py`
6. `signal_method/rule_base_event.py`
## Current Status and Caveats

- There is no dependency manifest; common dependencies include `numpy`, `pandas`, `scipy`, `matplotlib`, `seaborn`, `yaml`, `tqdm`, `rich`, `polars`, `lxml`, `mne`
- Most scripts have no command-line interface; config paths are hard-coded in each script's `__main__`
- `event_mask_process/SHHS1_process.py` is still mostly a placeholder
- `event_mask_process/HYS_PSG_process.py` is a simplified implementation; the `Resp_*` / `BCG_*` masks are not fleshed out yet
- The files in `output/` are best used to understand formats and results; they are not a complete dataset
## Related Documents

- Detailed project background: [AGENT_BACKGROUND_CN.md](./AGENT_BACKGROUND_CN.md)
---

**File: dataset_builder/HYS_PSG_dataset.py** (new, 395 lines)
import multiprocessing
import sys
from pathlib import Path

import os
import numpy as np

os.environ['DISPLAY'] = "localhost:10.0"

sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

# Project-local imports must come after the sys.path tweak above
from utils import N2Chn
import utils
import signal_method
import draw_tools
import shutil
import gc
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)

    total_seconds = min(
        psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
    )
    for i in N2Chn.values():
        if i == "Rpeak":
            continue
        length = int(total_seconds * psg_data[i]["fs"])
        psg_data[i]["data"] = psg_data[i]["data"][:length]
        psg_data[i]["length"] = length
        psg_data[i]["second"] = total_seconds

    psg_data["HR"] = {
        "name": "HR",
        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
                                       psg_data["Rpeak"]["fs"]),
        "fs": psg_data["ECG_Sync"]["fs"],
        "length": psg_data["ECG_Sync"]["length"],
        "second": psg_data["ECG_Sync"]["second"]
    }

    # Preprocessing and filtering
    tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
    abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
    flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
    flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])

    rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")

    event_mask, event_list = utils.read_mask_execl(mask_excel_path)

    enable_list = [[0, psg_data["Effort Tho"]["second"]]]
    normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
    normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
    normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
    normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)

    # Resample everything to the common target rate (100 Hz)
    target_fs = conf["target_fs"]
    normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
    normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
    normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
    normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
    spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
    normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
    rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)

    # Trim all channels to the same length
    min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal),
                     len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal),
                     len(rri))
    min_length = min_length - min_length % target_fs  # keep a whole number of seconds
    normalized_tho_signal = normalized_tho_signal[:min_length]
    normalized_abd_signal = normalized_abd_signal[:min_length]
    normalized_flowp_signal = normalized_flowp_signal[:min_length]
    normalized_flowt_signal = normalized_flowt_signal[:min_length]
    spo2_data_filt = spo2_data_filt[:min_length]
    normalized_effort_signal = normalized_effort_signal[:min_length]
    rri = rri[:min_length]

    tho_second = min_length / target_fs
    for i in event_mask.keys():
        event_mask[i] = event_mask[i][:int(tho_second)]

    spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
                                                                     spo2_fs=target_fs,
                                                                     max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"],
                                                                     min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"])
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
        show=show
    )

    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt_fill,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
        show=show
    )
    spo2_disable_mask = spo2_disable_mask[::target_fs]
    min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))

    if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
        print(f"Warning: Data length mismatch! Truncating to {min_len}.")
    event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]

    event_list = {
        "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
        "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}

    spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"])

    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")

    # Copy the mask CSV into the Labels folder
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)

    # Copy the synchronized SA label CSV into the Labels folder
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")
    # Save the processed signals and the extracted segment lists
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"

    # Rebuild psg_data with the processed signals
    # (the "name" fields use underscores instead of spaces)
    psg_data = {
        "Effort Tho": {
            "name": "Effort_Tho",
            "data": normalized_tho_signal,
            "fs": target_fs,
            "length": len(normalized_tho_signal),
            "second": len(normalized_tho_signal) / target_fs
        },
        "Effort Abd": {
            "name": "Effort_Abd",
            "data": normalized_abd_signal,
            "fs": target_fs,
            "length": len(normalized_abd_signal),
            "second": len(normalized_abd_signal) / target_fs
        },
        "Effort": {
            "name": "Effort",
            "data": normalized_effort_signal,
            "fs": target_fs,
            "length": len(normalized_effort_signal),
            "second": len(normalized_effort_signal) / target_fs
        },
        "Flow P": {
            "name": "Flow_P",
            "data": normalized_flowp_signal,
            "fs": target_fs,
            "length": len(normalized_flowp_signal),
            "second": len(normalized_flowp_signal) / target_fs
        },
        "Flow T": {
            "name": "Flow_T",
            "data": normalized_flowt_signal,
            "fs": target_fs,
            "length": len(normalized_flowt_signal),
            "second": len(normalized_flowt_signal) / target_fs
        },
        "SpO2": {
            "name": "SpO2",
            "data": spo2_data_filt_fill,
            "fs": target_fs,
            "length": len(spo2_data_filt_fill),
            "second": len(spo2_data_filt_fill) / target_fs
        },
        "HR": {
            "name": "HR",
            "data": psg_data["HR"]["data"],
            "fs": psg_data["HR"]["fs"],
            "length": psg_data["HR"]["length"],
            "second": psg_data["HR"]["second"]
        },
        "RRI": {
            "name": "RRI",
            "data": rri,
            "fs": target_fs,
            "length": len(rri),
            "second": len(rri) / target_fs
        },
        "5_class": {
            "name": "Stage",
            "data": psg_data["5_class"]["data"],
            "fs": psg_data["5_class"]["fs"],
            "length": psg_data["5_class"]["length"],
            "second": psg_data["5_class"]["second"]
        }
    }
    np.savez_compressed(save_signal_path, **psg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")

    if draw_segment:
        total_len = len(segment_list) + len(disable_segment_list)
        if verbose:
            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")

        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=segment_list,
            save_path=visual_path / f"{samp_id}" / "enable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )

        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=disable_segment_list,
            save_path=visual_path / f"{samp_id}" / "disable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )

    # Explicitly delete large objects
    try:
        del psg_data
        del normalized_tho_signal, normalized_abd_signal
        del normalized_flowp_signal, normalized_flowt_signal
        del normalized_effort_signal
        del spo2_data_filt, spo2_data_filt_fill
        del rri
        del event_mask, event_list
        del segment_list, disable_segment_list
    except NameError:
        pass

    # Force garbage collection
    gc.collect()
def multiprocess_entry(_progress, task_id, _id):
    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)


def multiprocess_with_tqdm(args_list, n_processes):
    from concurrent.futures import ProcessPoolExecutor
    from rich import progress

    with progress.Progress(
        "[progress.description]{task.description}",
        progress.BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        progress.MofNCompleteColumn(),
        progress.TimeRemainingColumn(),
        progress.TimeElapsedColumn(),
        refresh_per_second=1,  # slightly slower updates
        transient=False
    ) as progress:
        futures = []
        with multiprocessing.Manager() as manager:
            _progress = manager.dict()
            overall_progress_task = progress.add_task("[green]All jobs progress:")
            with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
                for i_args in range(len(args_list)):
                    task_id = progress.add_task(f"task {i_args}", visible=True)
                    futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
                # monitor the progress:
                while (n_finished := sum([future.done() for future in futures])) < len(
                    futures
                ):
                    progress.update(
                        overall_progress_task, completed=n_finished, total=len(futures)
                    )
                    for task_id, update_data in _progress.items():
                        desc = update_data.get("desc", "")
                        # update the progress bar for this task:
                        progress.update(
                            task_id,
                            completed=update_data.get("progress", 0),
                            total=update_data.get("total", 0),
                            description=desc
                        )

                # raise any errors:
                for future in futures:
                    future.result()
def multiprocess_with_pool(args_list, n_processes):
    """Use a Pool where each worker restarts after a fixed number of tasks."""
    from multiprocessing import Pool

    # maxtasksperchild: restart each worker after this many tasks (frees memory)
    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        results = []
        for samp_id in args_list:
            result = pool.apply_async(
                build_HYS_dataset_segment,
                args=(samp_id, False, True, False, None, None)
            )
            results.append(result)

        # Wait for all tasks to finish
        for i, result in enumerate(results):
            try:
                result.get()
                print(f"Completed {i + 1}/{len(args_list)}")
            except Exception as e:
                print(f"Error processing task {i}: {e}")

        pool.close()
        pool.join()
if __name__ == '__main__':
|
||||
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
|
||||
|
||||
conf = utils.load_dataset_conf(yaml_path)
|
||||
select_ids = conf["select_ids"]
|
||||
root_path = Path(conf["root_path"])
|
||||
mask_path = Path(conf["mask_save_path"])
|
||||
save_path = Path(conf["dataset_config"]["dataset_save_path"])
|
||||
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
|
||||
dataset_config = conf["dataset_config"]
|
||||
|
||||
visual_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_processed_signal_path = save_path / "Signals"
|
||||
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_segment_list_path = save_path / "Segments_List"
|
||||
save_segment_list_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_processed_label_path = save_path / "Labels"
|
||||
save_processed_label_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# print(f"select_ids: {select_ids}")
|
||||
# print(f"root_path: {root_path}")
|
||||
# print(f"save_path: {save_path}")
|
||||
# print(f"visual_path: {visual_path}")
|
||||
|
||||
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||
psg_signal_root_path = root_path / "PSG_Aligned"
|
||||
print(select_ids)
|
||||
|
||||
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
|
||||
|
||||
# for samp_id in select_ids:
|
||||
# print(f"Processing sample ID: {samp_id}")
|
||||
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||
|
||||
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
|
||||
multiprocess_with_pool(args_list=select_ids, n_processes=8)
|
||||
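The worker-recycling pattern used in `multiprocess_with_pool` can be shown in isolation. A minimal sketch, where the `square` task stands in for a per-sample job such as `build_HYS_dataset_segment` (the task function and counts here are illustrative, not part of the repository):

```python
from multiprocessing import Pool


def square(x):
    # stand-in for a heavy per-sample job
    return x * x


if __name__ == "__main__":
    # maxtasksperchild=2: each worker process is replaced after finishing two
    # tasks, which returns any memory it accumulated back to the OS
    with Pool(processes=2, maxtasksperchild=2) as pool:
        async_results = [pool.apply_async(square, args=(n,)) for n in range(6)]
        # .get() blocks until done and re-raises any exception from the worker
        values = [r.get() for r in async_results]
    print(values)  # → [0, 1, 4, 9, 16, 25]
```

Because results are collected in submission order, the output order is deterministic even though workers finish asynchronously.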
233 dataset_builder/HYS_dataset.py Normal file
@@ -0,0 +1,233 @@
import multiprocessing
import sys
from pathlib import Path

import os
import numpy as np

os.environ['DISPLAY'] = "localhost:10.0"

sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

import utils
import signal_method
import draw_tools
import shutil


def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    if verbose:
        print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")

    event_mask, event_list = utils.read_mask_execl(mask_excel_path)

    bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)

    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
    normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])

    # if the signal sampling rate is too high, downsample it
    if signal_fs == 1000:
        bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
        bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs,
                                                      target_fs=100)
        signal_fs = 100

    if bcg_fs == 1000:
        bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
        bcg_fs = 100

    # draw_tools.draw_signal_with_mask(samp_id=samp_id,
    #                                  signal_data=resp_signal,
    #                                  signal_fs=resp_fs,
    #                                  resp_data=normalized_resp_signal,
    #                                  resp_fs=resp_fs,
    #                                  bcg_data=bcg_signal,
    #                                  bcg_fs=bcg_fs,
    #                                  signal_disable_mask=event_mask["Disable_Label"],
    #                                  resp_low_amp_mask=event_mask["Resp_LowAmp_Label"],
    #                                  resp_movement_mask=event_mask["Resp_Movement_Label"],
    #                                  resp_change_mask=event_mask["Resp_AmpChange_Label"],
    #                                  resp_sa_mask=event_mask["SA_Label"],
    #                                  bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"],
    #                                  bcg_movement_mask=event_mask["BCG_Movement_Label"],
    #                                  bcg_change_mask=event_mask["BCG_AmpChange_Label"],
    #                                  show=show,
    #                                  save_path=None)

    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")

    # copy the mask into the processed Labels folder
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)

    # copy SA Label_corrected.csv into the processed Labels folder
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")

    # save the processed signals and the extracted segment list
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"

    bcg_data = {
        "bcg_signal_notch": {
            "name": "BCG_Signal_Notch",
            "data": bcg_signal_notch,
            "fs": signal_fs,
            "length": len(bcg_signal_notch),
            "second": len(bcg_signal_notch) // signal_fs
        },
        "bcg_signal": {
            "name": "BCG_Signal_Raw",
            "data": bcg_signal,
            "fs": bcg_fs,
            "length": len(bcg_signal),
            "second": len(bcg_signal) // bcg_fs
        },
        "resp_signal": {
            "name": "Resp_Signal",
            "data": normalized_resp_signal,
            "fs": resp_fs,
            "length": len(normalized_resp_signal),
            "second": len(normalized_resp_signal) // resp_fs
        }
    }

    np.savez_compressed(save_signal_path, **bcg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")

    if draw_segment:
        psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
        psg_data["HR"] = {
            "name": "HR",
            "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]),
            "fs": psg_data["ECG_Sync"]["fs"],
            "length": psg_data["ECG_Sync"]["length"],
            "second": psg_data["ECG_Sync"]["second"]
        }

        psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
        psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
        total_len = len(segment_list) + len(disable_segment_list)
        if verbose:
            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")

        # draw_tools.draw_psg_bcg_label(psg_data=psg_data,
        #                               psg_label=psg_event_mask,
        #                               bcg_data=bcg_data,
        #                               event_mask=event_mask,
        #                               segment_list=segment_list,
        #                               save_path=visual_path / f"{samp_id}" / "enable",
        #                               verbose=verbose,
        #                               multi_p=multi_p,
        #                               multi_task_id=multi_task_id
        #                               )
        #
        # draw_tools.draw_psg_bcg_label(
        #     psg_data=psg_data,
        #     psg_label=psg_event_mask,
        #     bcg_data=bcg_data,
        #     event_mask=event_mask,
        #     segment_list=disable_segment_list,
        #     save_path=visual_path / f"{samp_id}" / "disable",
        #     verbose=verbose,
        #     multi_p=multi_p,
        #     multi_task_id=multi_task_id
        # )
def multiprocess_entry(_progress, task_id, _id):
    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)


def multiprocess_with_pool(args_list, n_processes):
    """Use Pool; restart each worker after a fixed number of tasks."""
    from multiprocessing import Pool

    # maxtasksperchild: restart each worker after this many tasks (releases memory)
    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        results = []
        for samp_id in args_list:
            result = pool.apply_async(
                build_HYS_dataset_segment,
                args=(samp_id, False, True, False, None, None)
            )
            results.append(result)

        # wait for all tasks to finish
        for i, result in enumerate(results):
            try:
                result.get()
                print(f"Completed {i + 1}/{len(args_list)}")
            except Exception as e:
                print(f"Error processing task {i}: {e}")

        pool.close()
        pool.join()


if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/HYS_config.yaml"

    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(conf["dataset_config"]["dataset_save_path"])
    visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
    dataset_config = conf["dataset_config"]

    visual_path.mkdir(parents=True, exist_ok=True)

    save_processed_signal_path = save_path / "Signals"
    save_processed_signal_path.mkdir(parents=True, exist_ok=True)

    save_segment_list_path = save_path / "Segments_List"
    save_segment_list_path.mkdir(parents=True, exist_ok=True)

    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)

    # print(f"select_ids: {select_ids}")
    # print(f"root_path: {root_path}")
    # print(f"save_path: {save_path}")
    # print(f"visual_path: {visual_path}")

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
    print(select_ids)

    # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)

    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)

    # multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
    multiprocess_with_pool(args_list=select_ids, n_processes=16)
0 dataset_builder/__init__.py Normal file
139 dataset_config/HYS_PSG_config.yaml Normal file
@@ -0,0 +1,139 @@
select_ids:
  - 54
  - 88
  - 220
  - 221
  - 229
  - 282
  - 286
  - 541
  - 579
  - 582
  - 670
  - 671
  - 683
  - 684
  - 735
  - 933
  - 935
  - 950
  - 952
  - 960
  - 962
  - 967
  - 1302

root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG

target_fs: 100

dataset_config:
  window_sec: 180
  stride_sec: 60
  dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset
  dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization


effort:
  downsample_fs: 10

effort_filter:
  filter_type: bandpass
  low_cut: 0.05
  high_cut: 0.5
  order: 3

average_filter:
  window_size_sec: 20

flow:
  downsample_fs: 10

flow_filter:
  filter_type: bandpass
  low_cut: 0.05
  high_cut: 0.5
  order: 3

spo2_fill__anomaly:
  max_fill_duration: 30
  min_gap_duration: 10
  nan_to_num_value: 95

#ecg:
#  downsample_fs: 100
#
#ecg_filter:
#  filter_type: bandpass
#  low_cut: 0.5
#  high_cut: 40
#  order: 5


#resp:
#  downsample_fs_1: None
#  downsample_fs_2: 10
#
#resp_filter:
#  filter_type: bandpass
#  low_cut: 0.05
#  high_cut: 0.5
#  order: 3
#
#resp_low_amp:
#  window_size_sec: 30
#  stride_sec:
#  amplitude_threshold: 3
#  merge_gap_sec: 60
#  min_duration_sec: 60
#
#resp_movement:
#  window_size_sec: 20
#  stride_sec: 1
#  std_median_multiplier: 4
#  compare_intervals_sec:
#    - 60
#    - 120
##    - 180
#  interval_multiplier: 3
#  merge_gap_sec: 30
#  min_duration_sec: 1
#
#resp_movement_revise:
#  up_interval_multiplier: 3
#  down_interval_multiplier: 2
#  compare_intervals_sec: 30
#  merge_gap_sec: 10
#  min_duration_sec: 1
#
#resp_amp_change:
#  mav_calc_window_sec: 4
#  threshold_amplitude: 0.25
#  threshold_energy: 0.4
#
#
#bcg:
#  downsample_fs: 100
#
#bcg_filter:
#  filter_type: bandpass
#  low_cut: 1
#  high_cut: 10
#  order: 10
#
#bcg_low_amp:
#  window_size_sec: 1
#  stride_sec:
#  amplitude_threshold: 8
#  merge_gap_sec: 20
#  min_duration_sec: 3
#
#
#bcg_movement:
#  window_size_sec: 2
#  stride_sec:
#  merge_gap_sec: 20
#  min_duration_sec: 4
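The `window_sec: 180` / `stride_sec: 60` pair in `dataset_config` controls how each recording is cut into overlapping windows. A minimal sketch of that sliding-window arithmetic (a hypothetical helper; the repository's actual cutting logic lives in `utils.resp_split` and may differ):

```python
def sliding_windows(total_sec: int, window_sec: int, stride_sec: int) -> list[tuple[int, int]]:
    """Return (start_sec, end_sec) pairs for every full window that fits."""
    return [
        (start, start + window_sec)
        for start in range(0, total_sec - window_sec + 1, stride_sec)
    ]


# a 6-minute recording with the config above yields 4 overlapping windows
print(sliding_windows(360, 180, 60))
# → [(0, 180), (60, 240), (120, 300), (180, 360)]
```

With a 60 s stride against a 180 s window, consecutive windows overlap by two thirds, so every second of signal (away from the edges) appears in up to three windows.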
86 dataset_config/HYS_config.yaml Normal file
@@ -0,0 +1,86 @@
select_ids:
  - 1302
  - 286
  - 950
  - 220
  - 229
  - 541
  - 582
  - 670
  - 684
  - 960

root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS

resp:
  downsample_fs_1: 100
  downsample_fs_2: 10

resp_filter:
  filter_type: bandpass
  low_cut: 0.05
  high_cut: 0.6
  order: 3

resp_low_amp:
  window_size_sec: 30
  stride_sec:
  amplitude_threshold: 3
  merge_gap_sec: 60
  min_duration_sec: 60

resp_movement:
  window_size_sec: 20
  stride_sec: 1
  std_median_multiplier: 4
  compare_intervals_sec:
    - 60
    - 120
#    - 180
  interval_multiplier: 3
  merge_gap_sec: 30
  min_duration_sec: 1

resp_movement_revise:
  up_interval_multiplier: 3
  down_interval_multiplier: 2
  compare_intervals_sec: 30
  merge_gap_sec: 10
  min_duration_sec: 1

resp_amp_change:
  mav_calc_window_sec: 4
  threshold_amplitude: 0.25
  threshold_energy: 0.4


bcg:
  downsample_fs: 100

bcg_filter:
  filter_type: bandpass
  low_cut: 1
  high_cut: 10
  order: 10

bcg_low_amp:
  window_size_sec: 1
  stride_sec:
  amplitude_threshold: 8
  merge_gap_sec: 20
  min_duration_sec: 3


bcg_movement:
  window_size_sec: 2
  stride_sec:
  merge_gap_sec: 20
  min_duration_sec: 4


dataset_config:
  window_sec: 180
  stride_sec: 60
  dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
  dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization
56 dataset_config/RESP_PAIR_HYS_config.yaml Normal file
@@ -0,0 +1,56 @@
select_ids:
  - 1000
  - 1004
  - 1006
  - 1009
  - 1010
  - 1300
  - 1301
  - 1302
  - 1308
  - 1314
  - 1354
  - 1374
  - 1378
  - 1478
  - 220
  - 221
  - 229
  - 282
  - 285
  - 286
  - 54
  - 541
  - 579
  - 582
  - 670
  - 671
  - 683
  - 684
  - 686
  - 703
  - 704
  - 726
  - 735
  - 736
  - 88
  - 893
  - 933
  - 935
  - 939
  - 950
  - 952
  - 954
  - 955
  - 956
  - 960
  - 961
  - 962
  - 967
  - 969
  - 971
  - 972


root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
pair_file_path: /mnt/disk_wd/marques_dataset/Resp_Pair_Dataset/HYS/Raw
32 dataset_config/RESP_PAIR_ZD5Y_config.yaml Normal file
@@ -0,0 +1,32 @@
select_ids:
  - 3103
  - 3105
  - 3106
  - 3107
  - 3108
  - 3110
  - 3203
  - 3204
  - 3205
  - 3209
  - 3211
  - 3301
  - 3303
  - 3304
  - 3307
  - 3308
  - 3309
  - 3403
  - 3405
  - 3406
  - 3407
  - 3408
  - 3504


root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y
pair_file_path: /mnt/disk_wd/marques_dataset/Resp_Pair_Dataset/ZD5Y/Raw
41 dataset_config/SHHS1_config.yaml Normal file
@@ -0,0 +1,41 @@
root_path: /mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1
mask_save_path: /mnt/disk_code/marques/dataprepare/output/shhs1

effort_target_fs: 10
ecg_target_fs: 100


dataset_config:
  window_sec: 180
  stride_sec: 60
  dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset
  dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset/visualization


effort:
  downsample_fs: 10

effort_filter:
  filter_type: bandpass
  low_cut: 0.05
  high_cut: 0.5
  order: 3

average_filter:
  window_size_sec: 20

flow:
  downsample_fs: 10

flow_filter:
  filter_type: bandpass
  low_cut: 0.05
  high_cut: 0.5
  order: 3

spo2_fill__anomaly:
  max_fill_duration: 30
  min_gap_duration: 10
  nan_to_num_value: 95
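The `average_filter: window_size_sec: 20` entry suggests a moving-average smoothing stage on the downsampled effort signal. A minimal pure-Python sketch of such a filter, with simple boundary handling (the actual implementation in `utils/filter_func.py` may differ):

```python
def moving_average(signal, fs, window_size_sec):
    """Centered moving average over a window of window_size_sec seconds.

    Near the edges, the window is truncated to whatever samples exist,
    so the output has the same length as the input.
    """
    n = int(window_size_sec * fs)  # window length in samples
    out = []
    for i in range(len(signal)):
        lo = max(0, i - n // 2)
        hi = min(len(signal), i + n // 2 + 1)
        out.append(sum(signal[lo:hi]) / (hi - lo))
    return out


# a constant signal passes through unchanged
print(moving_average([5.0] * 5, fs=1, window_size_sec=2))
# → [5.0, 5.0, 5.0, 5.0, 5.0]
```

For long recordings, a `numpy.convolve`-based version would be far faster; the loop form above is only meant to make the windowing explicit.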
88 dataset_config/ZD5Y_config.yaml Normal file
@@ -0,0 +1,88 @@
select_ids:
  - 3103
  - 3105
  - 3106
  - 3107
  - 3108
  - 3110
  - 3203
  - 3204
  - 3205
  - 3207
  - 3208
  - 3209
  - 3212
  - 3301
  - 3303
  - 3307
  - 3403
  - 3504

root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y
save_path: /mnt/disk_code/marques/dataprepare/output/ZD5Y

resp:
  downsample_fs_1: 100
  downsample_fs_2: 10

resp_filter:
  filter_type: bandpass
  low_cut: 0.01
  high_cut: 0.7
  order: 3

resp_low_amp:
  window_size_sec: 30
  stride_sec:
  amplitude_threshold: 3
  merge_gap_sec: 60
  min_duration_sec: 60

resp_movement:
  window_size_sec: 20
  stride_sec: 1
  std_median_multiplier: 5
  compare_intervals_sec:
    - 60
    - 120
#    - 180
  interval_multiplier: 3.5
  merge_gap_sec: 30
  min_duration_sec: 1

resp_movement_revise:
  up_interval_multiplier: 3
  down_interval_multiplier: 2
  compare_intervals_sec: 30
  merge_gap_sec: 10
  min_duration_sec: 1

resp_amp_change:
  mav_calc_window_sec: 1
  threshold_amplitude: 0.25
  threshold_energy: 0.4


bcg:
  downsample_fs: 100

bcg_filter:
  filter_type: bandpass
  low_cut: 1
  high_cut: 10
  order: 10

bcg_low_amp:
  window_size_sec: 1
  stride_sec:
  amplitude_threshold: 8
  merge_gap_sec: 20
  min_duration_sec: 3


bcg_movement:
  window_size_sec: 2
  stride_sec:
  merge_gap_sec: 20
  min_duration_sec: 4
67 dataset_tools/resp_pair_copy.py Normal file
@@ -0,0 +1,67 @@
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import shutil


def copy_one_resp_pair(one_id):
    sync_type = "Sync"

    org_bcg_file_path = sync_bcg_path / f"{one_id}"
    dest_bcg_file_path = pair_file_path / f"{one_id}"
    dest_bcg_file_path.mkdir(parents=True, exist_ok=True)
    if not list(org_bcg_file_path.glob("OrgBCG_Sync_*.txt")):
        if not list(org_bcg_file_path.glob("OrgBCG_RoughCut_*.txt")):
            print(f"No OrgBCG files found for ID {one_id}.")
            return
        else:
            sync_type = "RoughCut"
            print(f"Using RoughCut files for ID {one_id}.")

    for file in org_bcg_file_path.glob(f"OrgBCG_{sync_type}_*.txt"):
        shutil.copyfile(file, dest_bcg_file_path / f"{one_id}_{file.name}".replace("_RoughCut", "").replace("_Sync", ""))
    psg_file_path = sync_psg_path / f"{one_id}"
    dest_psg_file_path = pair_file_path / f"{one_id}"
    dest_psg_file_path.mkdir(parents=True, exist_ok=True)

    # check that all of the expected PSG files exist before copying
    psg_file_patterns = [
        f"5_class_{sync_type}_*.txt",
        f"Effort Abd_{sync_type}_*.txt",
        f"Effort Tho_{sync_type}_*.txt",
        f"Flow P_{sync_type}_*.txt",
        f"Flow T_{sync_type}_*.txt",
        f"SA Label_Sync.csv",
        f"SpO2_{sync_type}_*.txt"
    ]
    for pattern in psg_file_patterns:
        if not list(psg_file_path.glob(pattern)):
            print(f"No PSG files found for ID {one_id} with pattern {pattern}.")
            return
    for pattern in psg_file_patterns:
        for file in psg_file_path.glob(pattern):
            shutil.copyfile(file, dest_psg_file_path / f"{one_id}_{file.name.replace('_RoughCut', '').replace('_Sync', '')}")


if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/RESP_PAIR_ZD5Y_config.yaml"

    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    sync_bcg_path = root_path / "OrgBCG_Aligned"
    sync_psg_path = root_path / "PSG_Aligned"
    pair_file_path = Path(conf["pair_file_path"])

    # copy_one_resp_pair(961)

    for samp_id in select_ids:
        print(f"Processing {samp_id}...")
        copy_one_resp_pair(samp_id)
100 dataset_tools/shhs_annotations_check.py Normal file
@@ -0,0 +1,100 @@
import argparse
from pathlib import Path
from lxml import etree
from tqdm import tqdm
from collections import Counter


def main():
    # Set the target folder path: edit the path here, or pass it in when
    # running the script. Defaults to the current directory '.'
    # target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1"
    target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2"

    folder_path = Path(target_dir)

    if not folder_path.exists():
        print(f"Error: path '{folder_path}' does not exist.")
        return

    # 1. Collect all XML files (flat structure, no recursion into subdirectories)
    xml_files = list(folder_path.glob("*.xml"))
    total_files = len(xml_files)

    if total_files == 0:
        print(f"No XML files found in '{folder_path}'.")
        return

    print(f"Found {total_files} XML files, starting processing...")

    # counter for (EventType, EventConcept) combinations
    stats_counter = Counter()

    # 2. Iterate over the files with a tqdm progress bar
    for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"):
        try:
            # parse with lxml
            tree = etree.parse(str(xml_file))
            root = tree.getroot()

            # 3. Locate the ScoredEvent nodes.
            # The SHHS XML structure is usually PSGAnnotation -> ScoredEvents -> ScoredEvent,
            # so we simply look up every ScoredEvent node.
            events = root.findall(".//ScoredEvent")

            for event in events:
                # extract EventType
                type_node = event.find("EventType")
                # handle a missing node or empty text
                e_type = type_node.text.strip() if (type_node is not None and type_node.text) else "N/A"

                # extract EventConcept
                concept_node = event.find("EventConcept")
                e_concept = concept_node.text.strip() if (concept_node is not None and concept_node.text) else "N/A"

                # 4. Combine and count, keyed by the tuple (EventType, EventConcept)
                key = (e_type, e_concept)
                stats_counter[key] += 1

        except etree.XMLSyntaxError:
            print(f"\n[Warning] Malformed file, skipping: {xml_file.name}")
        except Exception as e:
            print(f"\n[Error] Failed to process {xml_file.name}: {e}")

    # 5. Print the results to the terminal
    if stats_counter:
        # --- compute column widths dynamically ---
        # widest EventType, at least 9 characters
        max_type_width = max((len(k[0]) for k in stats_counter.keys()), default=9)
        max_type_width = max(max_type_width, 9)

        # widest EventConcept, at least 12 characters
        max_conc_width = max((len(k[1]) for k in stats_counter.keys()), default=12)
        max_conc_width = max(max_conc_width, 12)

        # total table width
        total_line_width = max_type_width + max_conc_width + 10 + 6

        print("\n" + "=" * total_line_width)
        print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}")
        print("-" * total_line_width)

        # sort by name: sorted() orders the (EventType, EventConcept) tuples
        # lexicographically, i.e. by EventType first, then by EventConcept
        for (e_type, e_concept), count in sorted(stats_counter.items()):
            print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}")

        print("=" * total_line_width)

    else:
        print("\nNo event data extracted.")

    print("=" * 90)
    print(f"Done. Scanned {total_files} files in total.")


if __name__ == "__main__":
    main()
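The counting logic above can be exercised without `lxml` or real annotation files. A stdlib-only sketch using `xml.etree.ElementTree` on an inline snippet shaped like an SHHS annotation file (the XML content below is illustrative, not real SHHS data):

```python
import xml.etree.ElementTree as ET
from collections import Counter

xml_text = """
<PSGAnnotation>
  <ScoredEvents>
    <ScoredEvent><EventType>Respiratory</EventType><EventConcept>Hypopnea</EventConcept></ScoredEvent>
    <ScoredEvent><EventType>Respiratory</EventType><EventConcept>Hypopnea</EventConcept></ScoredEvent>
    <ScoredEvent><EventType>Arousal</EventType><EventConcept>Arousal</EventConcept></ScoredEvent>
  </ScoredEvents>
</PSGAnnotation>
"""

root = ET.fromstring(xml_text)
stats = Counter()
for event in root.findall(".//ScoredEvent"):
    # findtext returns None when the child is missing, hence the fallback
    e_type = (event.findtext("EventType") or "N/A").strip()
    e_concept = (event.findtext("EventConcept") or "N/A").strip()
    stats[(e_type, e_concept)] += 1

print(stats[("Respiratory", "Hypopnea")])  # → 2
```

`ElementTree` shares the `findall`/`find` API used in the script, so this sketch mirrors the real traversal; `lxml` is mainly preferable for speed and stricter error reporting on large annotation sets.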
2 draw_tools/__init__.py Normal file
@@ -0,0 +1,2 @@
from .draw_statics import draw_signal_with_mask, draw_psg_signal
from .draw_label import draw_psg_bcg_label, draw_psg_label
337 draw_tools/draw_label.py Normal file
@@ -0,0 +1,337 @@
from matplotlib.axes import Axes
|
||||
from matplotlib.gridspec import GridSpec
|
||||
from matplotlib.colors import PowerNorm
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import seaborn as sns
|
||||
import numpy as np
|
||||
from tqdm.rich import tqdm
|
||||
import utils
|
||||
import gc
|
||||
# 添加with_prediction参数
|
||||
|
||||
psg_bcg_chn_name2ax = {
|
||||
"SpO2": 0,
|
||||
"Flow T": 1,
|
||||
"Flow P": 2,
|
||||
"Effort Tho": 3,
|
||||
"Effort Abd": 4,
|
||||
"HR": 5,
|
||||
"resp": 6,
|
||||
"bcg": 7,
|
||||
"Stage": 8,
|
||||
"resp_twinx": 9,
|
||||
"bcg_twinx": 10,
|
||||
}
|
||||
|
||||
psg_chn_name2ax = {
|
||||
"SpO2": 0,
|
||||
"Flow T": 1,
|
||||
"Flow P": 2,
|
||||
"Effort Tho": 3,
|
||||
"Effort Abd": 4,
|
||||
"Effort": 5,
|
||||
"HR": 6,
|
||||
"RRI": 7,
|
||||
"Stage": 8,
|
||||
}
|
||||
|
||||
|
||||
resp_chn_name2ax = {
|
||||
"resp": 0,
|
||||
"bcg": 1,
|
||||
}
|
||||
|
||||
|
||||
def create_psg_bcg_figure():
|
||||
fig = plt.figure(figsize=(12, 8), dpi=200)
|
||||
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
|
||||
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
|
||||
axes = []
|
||||
for i in range(9):
|
||||
ax = fig.add_subplot(gs[i])
|
||||
axes.append(ax)
|
||||
|
||||
axes[psg_bcg_chn_name2ax["SpO2"]].grid(True)
|
||||
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100))
|
||||
axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[psg_bcg_chn_name2ax["Flow T"]].grid(True)
|
||||
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[psg_bcg_chn_name2ax["Flow P"]].grid(True)
|
||||
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True)
|
||||
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True)
|
||||
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[psg_bcg_chn_name2ax["HR"]].grid(True)
|
||||
axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[psg_bcg_chn_name2ax["resp"]].grid(True)
|
||||
axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
|
||||
axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx())
|
||||
|
||||
axes[psg_bcg_chn_name2ax["bcg"]].grid(True)
|
||||
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[psg_bcg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
|
||||
axes.append(axes[psg_bcg_chn_name2ax["bcg"]].twinx())
|
||||
|
||||
axes[psg_bcg_chn_name2ax["Stage"]].grid(True)
|
||||
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def create_psg_figure():
    fig = plt.figure(figsize=(12, 8), dpi=200)
    gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = []
    for i in range(9):
        ax = fig.add_subplot(gs[i])
        axes.append(ax)

    axes[psg_chn_name2ax["SpO2"]].grid(True)
    # axes[0].xaxis.set_major_formatter(Params.FORMATTER)
    axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
    axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Flow T"]].grid(True)
    # axes[1].xaxis.set_major_formatter(Params.FORMATTER)
    axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Flow P"]].grid(True)
    # axes[2].xaxis.set_major_formatter(Params.FORMATTER)
    axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Effort Tho"]].grid(True)
    # axes[3].xaxis.set_major_formatter(Params.FORMATTER)
    axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Effort Abd"]].grid(True)
    # axes[4].xaxis.set_major_formatter(Params.FORMATTER)
    axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Effort"]].grid(True)
    axes[psg_chn_name2ax["Effort"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["HR"]].grid(True)
    axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["RRI"]].grid(True)
    axes[psg_chn_name2ax["RRI"]].tick_params(axis='x', colors="white")

    axes[psg_chn_name2ax["Stage"]].grid(True)

    return fig, axes

def create_resp_figure():
    fig = plt.figure(figsize=(12, 6), dpi=100)
    gs = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = []
    for i in range(2):
        ax = fig.add_subplot(gs[i])
        axes.append(ax)

    axes[0].grid(True)
    # axes[0].xaxis.set_major_formatter(Params.FORMATTER)
    axes[0].tick_params(axis='x', colors="white")

    axes[1].grid(True)
    # axes[1].xaxis.set_major_formatter(Params.FORMATTER)
    axes[1].tick_params(axis='x', colors="white")

    return fig, axes

def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
    signal_fs = signal_data["fs"]
    chn_signal = signal_data["data"]
    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
    ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black',
            label=signal_data["name"])
    if event_mask is not None:
        if multi_labels is None and event_codes is not None:
            for event_code in event_codes:
                mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code
                y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float)
                np.place(y, y == 0, np.nan)
                ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
        elif multi_labels == "resp" and event_codes is not None:
            ax.set_ylim(-6, 6)
            # Set up a second y-axis
            ax2.cla()
            ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
                     color='blue', alpha=0.8, label='Low Amplitude Mask')
            ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
                     color='orange', alpha=0.8, label='Movement Mask')
            ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
                     color='green', alpha=0.8, label='Amplitude Change Mask')
            for event_code in event_codes:
                sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
                score_mask = event_mask["SA_Score"][segment_start:segment_end].repeat(signal_fs)
                # y = (sa_mask * score_mask).astype(float)
                y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float)
                np.place(y, y == 0, np.nan)
                ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
                ax2.plot(time_axis, score_mask, color="orange")
            ax2.set_ylim(-4, 5)
        elif multi_labels == "bcg" and event_codes is not None:
            # Set up a second y-axis
            ax2.cla()
            ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
                     color='blue', alpha=0.8, label='Low Amplitude Mask')
            ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
                     color='orange', alpha=0.8, label='Movement Mask')
            ax2.plot(time_axis, event_mask["BCG_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
                     color='green', alpha=0.8, label='Amplitude Change Mask')

            ax2.set_ylim(-4, 4)

    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)

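The per-second label handling above relies on two tricks: upsampling a 1 Hz event mask to the signal rate with `numpy.repeat`, and replacing non-event samples with NaN so `ax.plot` leaves visible gaps. A minimal sketch of that masking step, with illustrative array values (and using boolean indexing instead of the `np.place(y, y == 0, np.nan)` idiom, which would also wipe genuine zero samples):

```python
import numpy as np

def highlight_event(signal, event_mask, fs, event_code):
    """Return a copy of `signal` that is NaN everywhere except where the
    per-second `event_mask` equals `event_code` (upsampled to `fs` Hz)."""
    sample_mask = np.repeat(event_mask, fs) == event_code  # 1 Hz -> fs Hz
    y = signal.astype(float).copy()
    y[~sample_mask] = np.nan  # matplotlib skips NaN, leaving gaps in the line
    return y

signal = np.arange(8, dtype=float)    # 4 seconds of signal at fs=2
event_mask = np.array([0, 2, 2, 0])   # one label per second
y = highlight_event(signal, event_mask, fs=2, event_code=2)
# y is NaN outside seconds 1-2 and keeps the original samples inside them
```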
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    stage_signal = stage_data["data"]
    stage_fs = stage_data["fs"]
    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * stage_fs)
    ax.plot(time_axis, stage_signal[segment_start * stage_fs:segment_end * stage_fs], color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)

def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
    spo2_signal = spo2_data["data"]
    spo2_fs = spo2_data["fs"]
    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * spo2_fs)
    ax.plot(time_axis, spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs], color='black', label=spo2_data["name"])

    if spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() < 85:
        ax.set_ylim((spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() - 5, 100))
    else:
        ax.set_ylim((85, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)

def score_mask2alpha(score_mask):
    alpha_mask = np.zeros_like(score_mask, dtype=float)
    alpha_mask[score_mask <= 0] = 0
    alpha_mask[score_mask == 1] = 0.9
    alpha_mask[score_mask == 2] = 0.6
    alpha_mask[score_mask == 3] = 0.1
    return alpha_mask

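`score_mask2alpha` maps discrete confidence scores to plot transparency: score 1 draws nearly opaque, score 3 nearly invisible, and any other value stays fully transparent. Restated as a standalone snippet for quick verification (same logic as the function above):

```python
import numpy as np

def score_mask2alpha(score_mask):
    # Higher score -> lower alpha; values outside 1..3 stay at 0 (transparent).
    alpha_mask = np.zeros_like(score_mask, dtype=float)
    alpha_mask[score_mask == 1] = 0.9
    alpha_mask[score_mask == 2] = 0.6
    alpha_mask[score_mask == 3] = 0.1
    return alpha_mask

alphas = score_mask2alpha(np.array([0, 1, 2, 3, 4]))
# 0 -> 0.0, 1 -> 0.9, 2 -> 0.6, 3 -> 0.1, 4 -> 0.0
```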
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
                       multi_p=None, multi_task_id=None):
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    for mask in event_mask.keys():
        if mask.startswith("Resp_") or mask.startswith("BCG_"):
            event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)

    event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0)

    # event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
    # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)

    fig, axes = create_psg_bcg_figure()
    for i, (segment_start, segment_end) in enumerate(segment_list):
        for ax in axes:
            ax.cla()

        plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
                               ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]])
        plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
                               ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]])

        if save_path is not None:
            fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
            # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")

        if multi_p is not None:
            multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}

    plt.close(fig)
    plt.close('all')
    gc.collect()

def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True,
                   multi_p=None, multi_task_id=None):
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    if multi_p is None:
        # Print the data length of every channel in psg_data
        for i in range(len(psg_data.keys())):
            chn_name = list(psg_data.keys())[i]
            print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}")
        # Length of psg_label
        print(f"psg_label length: {len(psg_label)}")

    fig, axes = create_psg_figure()
    for i, (segment_start, segment_end) in enumerate(segment_list):
        for ax in axes:
            ax.cla()

        plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end)

        if save_path is not None:
            fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
            # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")

        if multi_p is not None:
            multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
    plt.close(fig)
    plt.close('all')
    gc.collect()

371	draw_tools/draw_statics.py	Normal file
@@ -0,0 +1,371 @@
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np

def draw_ax_confusion_matrix(ax: Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    # Build the text matrix used for heatmap annotations
    text_matrix = np.empty((len(amp_labels), len(time_labels)), dtype=object)
    percent_matrix = np.zeros((len(amp_labels), len(time_labels)))

    # Fill the text matrix and the percentage matrix
    for i in range(len(amp_labels)):
        for j in range(len(time_labels)):
            val = confusion_matrix.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent

    # Draw the heatmap with an adjusted colorbar position
    sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='',
                              xticklabels=time_labels, yticklabels=amp_labels,
                              cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
                              norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
                              cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15},
                              # annot_kws={'fontsize': 12}
                              )

    # Add per-row statistics (right side)
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(len(time_labels) + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(len(time_labels) + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
                ha='center', va='center')

    # Add per-column statistics (bottom)
    col_sums = segment_count_matrix.sum(axis=0)
    col_percents = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, len(amp_labels) + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, len(amp_labels) + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
                ha='center', va='center', rotation=0)

    # Move the x-axis ticks and label to the top
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    # Title and axis labels
    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep tick labels horizontal
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    # Adjust the colorbar label position
    cbar = sns_heatmap.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Legend note explaining the cell format
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Totals: overall segment count and valid-duration percentage
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length / total_duration * 100
    ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)",
            ha='center', va='center')

def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    # Plot the signal trace
    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # Shade movement (red) and low-amplitude (blue) regions with axvspan
    for start, end in movement_position_list:
        if start < len(original_times) and end < len(original_times):
            ax.axvspan(start, end, color='red', alpha=0.3)

    for start, end in low_amp_position_list:
        if start < len(original_times) and end < len(original_times):
            ax.axvspan(start, end, color='blue', alpha=0.2)

    # If an AML list is given, draw horizontal threshold lines
    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed', linewidth=2, alpha=0.5, label=f'{aml} aml')

    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values, color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    # Fix the y-axis range
    ax.set_ylim((-2000, 2000))

    # Patches representing the shaded regions in the legend
    red_patch = mpatches.Patch(color='red', alpha=0.2, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area')

    # Add the new legend entries while keeping the existing ones
    handles, labels = ax.get_legend_handles_labels()  # existing legend entries
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right',
              bbox_to_anchor=(1, 1.4),
              framealpha=0.5)

    # Title and axis labels
    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')

    # Enable the grid
    ax.grid(True, linestyle='--', alpha=0.7)

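`draw_ax_amp` merges proxy patches for the shaded regions with the legend entries matplotlib collected automatically. The pattern generalizes: build `Patch` proxies for artists that get no legend entry of their own (like unlabeled `axvspan` regions), then prepend them to `get_legend_handles_labels()`. A minimal self-contained sketch with illustrative values:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for scripted use
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 0], label='signal')  # gets a legend entry automatically
ax.axvspan(0.5, 1.0, color='red', alpha=0.3)   # no label -> no legend entry

# Proxy patch stands in for the shaded region in the legend
red_patch = mpatches.Patch(color='red', alpha=0.3, label='Movement Area')
handles, labels = ax.get_legend_handles_labels()
legend = ax.legend(handles=[red_patch] + handles,
                   labels=['Movement Area'] + labels)
# legend now lists 'Movement Area' first, then 'signal'
```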
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):

    # Create the figure
    fig = plt.figure(figsize=(18, 10))

    gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

    signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
    bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
    resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))
    # Subplot 1: raw BCG signal
    ax1 = fig.add_subplot(gs[0])
    draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[200, 500, 1000])

    ax2 = fig.add_subplot(gs[1])
    param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                   'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']
    params = dict(zip(param_names, bcg_statistic_info))
    params['ax'] = ax2
    draw_ax_confusion_matrix(**params)

    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    draw_ax_amp(ax=ax3, signal_name='RESP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[100, 300, 500])
    ax4 = fig.add_subplot(gs[3])
    params = dict(zip(param_names, resp_statistic_info))
    params['ax'] = ax4
    draw_ax_confusion_matrix(**params)

    # Figure-level title
    fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

    if save_path is not None:
        # Save the figure
        plt.savefig(save_path, dpi=300)

    if show:
        plt.show()

    plt.close()

def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs,
                          signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask,
                          resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask, show=False, save_path=None
                          ):
    # Row 1: raw signal with powerline noise removed; right axis marks unusable intervals, left axis is signal amplitude
    # Row 2: respiration component; right axis shows low-amplitude, movement, amplitude-change and SA masks, left axis is respiration amplitude
    # Row 3: BCG component; right axis shows low-amplitude, movement and amplitude-change masks, left axis is BCG amplitude
    # If a mask is None, generate an all-NaN mask
    def _none_to_nan_mask(mask, ref):
        if mask is None:
            return np.full_like(ref, np.nan)
        else:
            # Replace zeros in the mask with NaN, keep everything else
            mask = np.where(mask == 0, np.nan, mask)
            return mask

    signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
    resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data)
    resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data)
    resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data)
    resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data)
    bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data)
    bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data)
    bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data)

    fig = plt.figure(figsize=(18, 10))
    ax0 = fig.add_subplot(3, 1, 1)
    ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
    ax0.set_title(f'Sample {samp_id} - Raw Signal Data')
    ax0.set_ylabel('Amplitude')
    # ax0.set_xticklabels([])

    ax0_twin = ax0.twinx()
    ax0_twin.plot(np.linspace(0, len(signal_disable_mask), len(signal_disable_mask)), signal_disable_mask,
                  color='red', alpha=0.5)
    ax0_twin.autoscale(enable=False, axis='y', tight=True)
    ax0_twin.set_ylim((-2, 2))
    ax0_twin.set_ylabel('Disable Mask')
    ax0_twin.set_yticks([0, 1])
    ax0_twin.set_yticklabels(['Enabled', 'Disabled'])
    ax0_twin.grid(False)
    ax0_twin.legend(['Disable Mask'], loc='upper right')

    ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
    ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='gray', alpha=0.5)
    resp_data_no_movement = resp_data.copy()
    resp_data_no_movement[resp_movement_mask.repeat(int(resp_fs)) == 1] = np.nan
    ax1.plot(np.linspace(0, len(resp_data_no_movement) // resp_fs, len(resp_data_no_movement)), resp_data_no_movement, color='orange')
    ax1.set_ylabel('Amplitude')
    # ax1.set_xticklabels([])
    ax1_twin = ax1.twinx()
    ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask * -1,
                  color='blue', alpha=0.5, label='Low Amplitude Mask')
    ax1_twin.plot(np.linspace(0, len(resp_movement_mask), len(resp_movement_mask)), resp_movement_mask * -2,
                  color='red', alpha=0.5, label='Movement Mask')
    ax1_twin.plot(np.linspace(0, len(resp_change_mask), len(resp_change_mask)), resp_change_mask * -3,
                  color='green', alpha=0.5, label='Amplitude Change Mask')
    ax1_twin.plot(np.linspace(0, len(resp_sa_mask), len(resp_sa_mask)), resp_sa_mask,
                  color='purple', alpha=0.5, label='SA Mask')
    ax1_twin.autoscale(enable=False, axis='y', tight=True)
    ax1_twin.set_ylim((-4, 5))
    # ax1_twin.set_ylabel('Resp Masks')
    # ax1_twin.set_yticks([0, 1])
    # ax1_twin.set_yticklabels(['No', 'Yes'])
    ax1_twin.grid(False)

    ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
    ax1.set_title(f'Sample {samp_id} - Respiration Component')

    ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
    ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
    ax2.set_ylabel('Amplitude')
    ax2.set_xlabel('Time (s)')
    ax2_twin = ax2.twinx()
    ax2_twin.plot(np.linspace(0, len(bcg_low_amp_mask), len(bcg_low_amp_mask)), bcg_low_amp_mask * -1,
                  color='blue', alpha=0.5, label='Low Amplitude Mask')
    ax2_twin.plot(np.linspace(0, len(bcg_movement_mask), len(bcg_movement_mask)), bcg_movement_mask * -2,
                  color='red', alpha=0.5, label='Movement Mask')
    ax2_twin.plot(np.linspace(0, len(bcg_change_mask), len(bcg_change_mask)), bcg_change_mask * -3,
                  color='green', alpha=0.5, label='Amplitude Change Mask')
    ax2_twin.autoscale(enable=False, axis='y', tight=True)
    ax2_twin.set_ylim((-4, 2))
    ax2_twin.set_ylabel('BCG Masks')
    # ax2_twin.set_yticks([0, 1])
    # ax2_twin.set_yticklabels(['No', 'Yes'])
    ax2_twin.grid(False)
    ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right')
    # ax2.set_title(f'Sample {samp_id} - BCG Component')

    ax0_twin._lim_lock = False
    ax1_twin._lim_lock = False
    ax2_twin._lim_lock = False

    def on_lims_change(event_ax):
        if getattr(event_ax, '_lim_lock', False):
            return
        try:
            event_ax._lim_lock = True

            if event_ax == ax0_twin:
                # Re-lock the twin axis y-range
                ax0_twin.set_ylim(-2, 2)
            elif event_ax == ax1_twin:
                ax1_twin.set_ylim(-4, 5)
            elif event_ax == ax2_twin:
                ax2_twin.set_ylim(-4, 2)

        finally:
            event_ax._lim_lock = False

    ax0_twin.callbacks.connect('ylim_changed', on_lims_change)
    ax1_twin.callbacks.connect('ylim_changed', on_lims_change)
    ax2_twin.callbacks.connect('ylim_changed', on_lims_change)
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    if show:
        plt.show()

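The `ylim_changed` callback above pins each twin axis to a fixed y-range even when the user zooms on the shared x-axis: calling `set_ylim` inside the handler would re-fire the same event, so a `_lim_lock` flag guards against recursion. The pattern in isolation, headless and with an illustrative range:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
twin = ax.twinx()
twin._lim_lock = False  # reentrancy guard

def on_lims_change(event_ax):
    if getattr(event_ax, '_lim_lock', False):
        return  # the set_ylim below re-fires this callback; ignore that call
    try:
        event_ax._lim_lock = True
        event_ax.set_ylim(-2, 2)  # snap back to the locked range
    finally:
        event_ax._lim_lock = False

twin.callbacks.connect('ylim_changed', on_lims_change)
twin.set_ylim(-10, 10)  # simulates a zoom on the twin axis
print(twin.get_ylim())  # (-2.0, 2.0): the callback restored the lock
```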
def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, spo2_signal, effort_signal, rri_signal, event_mask, fs,
                    show=False, save_path=None):
    sa_mask = event_mask.repeat(fs)
    fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True)
    time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal))
    axs[0].plot(time_axis, tho_signal, label='THO', color='black')
    axs[0].set_title(f'Sample {samp_id} - PSG Signal Data')
    axs[0].set_ylabel('THO Amplitude')
    axs[0].legend(loc='upper right')

    ax0_twin = axs[0].twinx()
    ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
    ax0_twin.autoscale(enable=False, axis='y', tight=True)
    ax0_twin.set_ylim((-4, 5))

    axs[1].plot(time_axis, abd_signal, label='ABD', color='black')
    axs[1].set_ylabel('ABD Amplitude')
    axs[1].legend(loc='upper right')

    ax1_twin = axs[1].twinx()
    ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
    ax1_twin.autoscale(enable=False, axis='y', tight=True)
    ax1_twin.set_ylim((-4, 5))

    axs[2].plot(time_axis, effort_signal, label='EFFO', color='black')
    axs[2].set_ylabel('EFFO Amplitude')
    axs[2].legend(loc='upper right')

    ax2_twin = axs[2].twinx()
    ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
    ax2_twin.autoscale(enable=False, axis='y', tight=True)
    ax2_twin.set_ylim((-4, 5))

    axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black')
    axs[3].set_ylabel('FLOWP Amplitude')
    axs[3].legend(loc='upper right')

    ax3_twin = axs[3].twinx()
    ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
    ax3_twin.autoscale(enable=False, axis='y', tight=True)
    ax3_twin.set_ylim((-4, 5))

    axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black')
    axs[4].set_ylabel('FLOWT Amplitude')
    axs[4].legend(loc='upper right')

    ax4_twin = axs[4].twinx()
    ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
    ax4_twin.autoscale(enable=False, axis='y', tight=True)
    ax4_twin.set_ylim((-4, 5))

    axs[5].plot(time_axis, rri_signal, label='RRI', color='black')
    axs[5].set_ylabel('RRI Amplitude')
    axs[5].legend(loc='upper right')

    axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black')
    axs[6].set_ylabel('SPO2 Amplitude')
    axs[6].set_xlabel('Time (s)')
    axs[6].legend(loc='upper right')

    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    if show:
        plt.show()

183	event_mask_process/HYS_PSG_process.py	Normal file
@@ -0,0 +1,183 @@
"""
This script processes the HYS (respiratory institute) data and covers the following steps:
1. Data loading and preprocessing
   Read the signals and labels from the given paths and run initial preprocessing,
   including filtering and denoising.
2. Data cleaning and outlier handling
3. Output statistics after cleaning
4. Data saving
   Save the processed data to the given path for later use;
   mainly the positions and labels of the split segments.
5. Visualization
   Provide before/after visual comparisons to help understand the changes,
   and plot availability trends showing usable, movement and low-amplitude intervals.


# Rule-based labeling and removal of low-amplitude intervals
# Rule-based labeling and removal of sustained high-amplitude movement
# Removal of manually labeled unusable intervals
"""
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os


os.environ['DISPLAY'] = "localhost:10.0"

def process_one_signal(samp_id, show=False):
    tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
    abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
    flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
    flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
    spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
    stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))

    if not tho_signal_path:
        raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
    tho_signal_path = tho_signal_path[0]
    print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
    if not abd_signal_path:
        raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
    abd_signal_path = abd_signal_path[0]
    print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
    if not flowp_signal_path:
        raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
    flowp_signal_path = flowp_signal_path[0]
    print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
    if not flowt_signal_path:
        raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
    flowt_signal_path = flowt_signal_path[0]
    print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
    if not spo2_signal_path:
        raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
    spo2_signal_path = spo2_signal_path[0]
    print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
    if not stage_signal_path:
        raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
    stage_signal_path = stage_signal_path[0]
    print(f"Processing 5_class_Sync signal file: {stage_signal_path}")

    # Materialize the glob into a list so the emptiness check below works
    # (a bare generator is always truthy)
    label_path = list((label_root_path / f"{samp_id}").glob("SA Label_Sync.csv"))
    if not label_path:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = label_path[0]
    print(f"Processing Label_corrected file: {label_path}")

    # Save the processed data and labels
    save_samp_path = save_path / f"{samp_id}"
    save_samp_path.mkdir(parents=True, exist_ok=True)

    # Read the signal data
    tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
    # abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
    # flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
    # flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
    # spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
    stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)

    # Preprocessing and filtering
    # tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs)
    # abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs)
    # flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs)
    # flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs)

    # Downsampling
    # old_tho_fs = tho_fs
    # tho_fs = conf["effort"]["downsample_fs"]
    # tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs)
    # old_abd_fs = abd_fs
    # abd_fs = conf["effort"]["downsample_fs"]
    # abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs)
    # old_flowp_fs = flowp_fs
    # flowp_fs = conf["effort"]["downsample_fs"]
    # flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs)
    # old_flowt_fs = flowt_fs
    # flowt_fs = conf["effort"]["downsample_fs"]
    # flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs)

    # SpO2 is not downsampled
    # spo2_data_filt = spo2_data_raw
    # spo2_fs = spo2_fs

    label_data = utils.read_raw_psg_label(path=label_path)
    event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
    # Positions where event_mask > 0 become 1, everything else 0
    score_mask = np.where(event_mask > 0, 1, 0)

    # Derive unusable intervals from the sleep stages
    wake_mask = utils.get_wake_mask(stage_data_raw)
    # Drop wake intervals shorter than 60 seconds
    wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60)
|
||||
# 合并短于120秒的觉醒区间
|
||||
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
|
||||
|
||||
disable_label = wake_mask
|
||||
|
||||
disable_label = disable_label[:tho_second]
|
||||
|
||||
|
||||
# 复制事件文件 到保存路径
|
||||
sa_label_save_name = f"{samp_id}_" + label_path.name
|
||||
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
|
||||
#
|
||||
# 新建一个dataframe,分别是秒数、SA标签,
|
||||
save_dict = {
|
||||
"Second": np.arange(tho_second),
|
||||
"SA_Label": event_mask,
|
||||
"SA_Score": score_mask,
|
||||
"Disable_Label": disable_label,
|
||||
"Resp_LowAmp_Label": np.zeros_like(event_mask),
|
||||
"Resp_Movement_Label": np.zeros_like(event_mask),
|
||||
"Resp_AmpChange_Label": np.zeros_like(event_mask),
|
||||
"BCG_LowAmp_Label": np.zeros_like(event_mask),
|
||||
"BCG_Movement_Label": np.zeros_like(event_mask),
|
||||
"BCG_AmpChange_Label": np.zeros_like(event_mask)
|
||||
}
|
||||
|
||||
mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
|
||||
utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
|
||||
# disable_df_path = project_root_path / "排除区间.xlsx"
|
||||
#
|
||||
conf = utils.load_dataset_conf(yaml_path)
|
||||
|
||||
root_path = Path(conf["root_path"])
|
||||
save_path = Path(conf["mask_save_path"])
|
||||
select_ids = conf["select_ids"]
|
||||
#
|
||||
print(f"select_ids: {select_ids}")
|
||||
print(f"root_path: {root_path}")
|
||||
print(f"save_path: {save_path}")
|
||||
#
|
||||
org_signal_root_path = root_path / "PSG_Aligned"
|
||||
label_root_path = root_path / "PSG_Aligned"
|
||||
#
|
||||
# all_samp_disable_df = utils.read_disable_excel(disable_df_path)
|
||||
#
|
||||
# process_one_signal(select_ids[0], show=True)
|
||||
# #
|
||||
for samp_id in select_ids:
|
||||
print(f"Processing sample ID: {samp_id}")
|
||||
process_one_signal(samp_id, show=False)
|
||||
print(f"Finished processing sample ID: {samp_id}\n\n")
|
||||
pass
|
||||
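The per-second label assembly above can be sketched in isolation. A minimal example (all values hypothetical) of how `score_mask` is binarized from the event mask and how the disable label is truncated to the signal length, mirroring the lines in the script:

```python
import numpy as np

# Hypothetical per-second event mask, as utils.generate_event_mask would
# return it (0 = no event; positive integers = apnea event types).
event_mask = np.array([0, 0, 2, 2, 3, 0])

# Mirrors the script: any event second becomes 1, everything else 0.
score_mask = np.where(event_mask > 0, 1, 0)

# The disable label is truncated to the signal length in seconds,
# as done with disable_label[:tho_second] in the script.
tho_second = 5
disable_label = np.ones(8, dtype=int)[:tho_second]

print(score_mask.tolist(), len(disable_label))
# → [0, 0, 1, 1, 1, 0] 5
```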
246
event_mask_process/HYS_process.py
Normal file
@@ -0,0 +1,246 @@
"""
This script processes the HYS (respiratory institute) data and covers:
1. Data reading and preprocessing
   Reads the signals and labels from the given paths and applies initial
   preprocessing such as filtering and denoising.
2. Data cleaning and outlier handling
3. Output of post-cleaning statistics
4. Data saving
   Saves the processed data to the given path for later use; mainly the
   positions and labels of the segmented data.
5. Visualization
   Draws before/after comparisons and availability trend plots showing
   usable intervals, movement intervals, low-amplitude intervals, etc.


# Rule-based labelling and removal of low-amplitude intervals
# Rule-based labelling and removal of continuous high-amplitude movement
# Removal of manually marked unusable intervals
"""
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os

os.environ['DISPLAY'] = "localhost:10.0"


def process_one_signal(samp_id, show=False):
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    label_path = list((label_root_path / f"{samp_id}").glob("SA Label_corrected.csv"))
    if not label_path:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = label_path[0]
    print(f"Processing Label_corrected file: {label_path}")

    # Save the processed data and labels
    save_samp_path = save_path / f"{samp_id}"
    save_samp_path.mkdir(parents=True, exist_ok=True)

    signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True)

    signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)

    # Downsampling
    old_resp_fs = resp_fs
    resp_fs = conf["resp"]["downsample_fs_2"]
    resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs)
    bcg_fs = conf["bcg"]["downsample_fs"]
    bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)

    label_data = utils.read_label_csv(path=label_path)
    event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)

    manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
        all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # Detect low-amplitude intervals in the respiratory signal
    resp_low_amp_conf = conf.get("resp_low_amp", None)
    if resp_low_amp_conf is not None:
        resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_low_amp_conf
        )
        print(
            f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}")
    else:
        resp_low_amp_mask, resp_low_amp_position_list = None, None
        print("resp_low_amp_mask is None")

    # Detect high-amplitude movement artefacts in the respiratory signal
    resp_movement_conf = conf.get("resp_movement", None)
    if resp_movement_conf is not None:
        raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(
            f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        resp_movement_mask, resp_movement_position_list = None, None
        print("resp_movement_mask is None")

    resp_movement_revise_conf = conf.get("resp_movement_revise", None)
    if resp_movement_mask is not None and resp_movement_revise_conf is not None:
        resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_movement_revise_conf,
            verbose=False
        )
        print(
            f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        print("resp_movement_mask revise is skipped")

    # Detect abrupt amplitude-change intervals in the respiratory signal
    resp_amp_change_conf = conf.get("resp_amp_change", None)
    if resp_amp_change_conf is not None and resp_movement_mask is not None:
        resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_amp_change_conf,
            verbose=False)
        print(
            f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
    else:
        resp_amp_change_mask = None
        print("amp_change_mask is None")

    # Detect low-amplitude intervals in the BCG signal
    bcg_low_amp_conf = conf.get("bcg_low_amp", None)
    if bcg_low_amp_conf is not None:
        bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_low_amp_conf
        )
        print(
            f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}")
    else:
        bcg_low_amp_mask, bcg_low_amp_position_list = None, None
        print("bcg_low_amp_mask is None")

    # Detect high-amplitude movement artefacts in the BCG signal
    bcg_movement_conf = conf.get("bcg_movement", None)
    if bcg_movement_conf is not None:
        raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_movement_conf
        )
        print(
            f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}")
    else:
        bcg_movement_mask = None
        print("bcg_movement_mask is None")

    # Detect abrupt amplitude-change intervals in the BCG signal
    if bcg_movement_mask is not None:
        bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
            signal_data=bcg_data,
            movement_mask=bcg_movement_mask,
            sampling_rate=bcg_fs)
        print(
            f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}")
    else:
        bcg_amp_change_mask = None
        print("bcg_amp_change_mask is None")

    # If the raw sampling rate is too high, downsample signal_data
    if signal_fs == 1000:
        signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
        signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs,
                                                       target_fs=100)
        signal_fs = 100

    draw_tools.draw_signal_with_mask(samp_id=samp_id,
                                     signal_data=signal_data,
                                     signal_fs=signal_fs,
                                     resp_data=resp_data,
                                     resp_fs=resp_fs,
                                     bcg_data=bcg_data,
                                     bcg_fs=bcg_fs,
                                     signal_disable_mask=manual_disable_mask,
                                     resp_low_amp_mask=resp_low_amp_mask,
                                     resp_movement_mask=resp_movement_mask,
                                     resp_change_mask=resp_amp_change_mask,
                                     resp_sa_mask=event_mask,
                                     bcg_low_amp_mask=bcg_low_amp_mask,
                                     bcg_movement_mask=bcg_movement_mask,
                                     bcg_change_mask=bcg_amp_change_mask,
                                     show=show,
                                     save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")

    # Copy the event file to the save path
    sa_label_save_name = f"{samp_id}_" + label_path.name
    shutil.copyfile(label_path, save_samp_path / sa_label_save_name)

    # Build a new dataframe with: second index, SA label, SA quality score,
    # disable label, Resp low-amplitude / movement / amplitude-change labels,
    # and the BCG equivalents
    save_dict = {
        "Second": np.arange(signal_second),
        "SA_Label": event_mask,
        "SA_Score": score_mask,
        "Disable_Label": manual_disable_mask,
        "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
        "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second,
                                                                                                 dtype=int),
        "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second,
                                                                                                       dtype=int),
        "BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
        "BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second,
                                                                                              dtype=int),
        "BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second,
                                                                                                    dtype=int)
    }

    mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
    utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)


if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
    disable_df_path = project_root_path / "排除区间.xlsx"

    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    save_path = Path(conf["mask_save_path"])

    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"

    all_samp_disable_df = utils.read_disable_excel(disable_df_path)

    # process_one_signal(select_ids[6], show=True)
    for samp_id in select_ids:
        print(f"Processing sample ID: {samp_id}")
        process_one_signal(samp_id, show=False)
        print(f"Finished processing sample ID: {samp_id}\n\n")
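The rule-based detectors above return window-level masks and convert them to per-second masks with `mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]`. A minimal sketch of that conversion in isolation (values hypothetical):

```python
import numpy as np

# Window-level mask -> per-second mask, as done at the end of the
# detect_movement / detect_low_amplitude_signal detectors.
sampling_rate = 100
stride_sec = 2
signal_len = 7 * sampling_rate           # 7 seconds of signal
window_mask = np.array([0, 1, 0, 1])     # one value per 2-second window

# Each window value is repeated stride_sec times, then truncated to the
# whole-second length of the original signal.
per_second = window_mask.repeat(stride_sec)[: signal_len // sampling_rate]
print(per_second.tolist())
# → [0, 0, 1, 1, 0, 0, 1]
```

The truncation matters when the signal length is not a multiple of the stride: the repeated mask can overshoot by up to one window, and slicing to `len(signal) // fs` keeps it aligned with the per-second label columns.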
42
event_mask_process/SHHS1_process.py
Normal file
@@ -0,0 +1,42 @@
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent

import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
import mne
from tqdm import tqdm
import xml.etree.ElementTree as ET
import re

# Extract sleep stages, event labels, and unusable intervals


def process_one_signal(samp_id, show=False):
    # Not implemented yet; this script is still a stub
    pass


if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/SHHS1_config.yaml"

    conf = utils.load_dataset_conf(yaml_path)

    root_path = Path(conf["root_path"])
    save_path = Path(conf["mask_save_path"])

    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"
0
event_mask_process/__init__.py
Normal file
6
signal_method/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .rule_base_event import detect_low_amplitude_signal, detect_movement
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
from .rule_base_event import movement_revise
from .time_metrics import calc_mav_by_slide_windows
from .signal_process import signal_filter_split, rpeak2hr, psg_effort_filter, rpeak2rri_interpolation
from .normalize_method import normalize_resp_signal_by_segment
52
signal_method/normalize_method.py
Normal file
@@ -0,0 +1,52 @@
import utils
import pandas as pd
import numpy as np
from scipy import signal


def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
    # Z-score each segment of the respiratory signal, segmented by its
    # amplitude-change intervals
    normalized_resp_signal = np.zeros_like(resp_signal)
    # Fill everything with NaN first
    normalized_resp_signal[:] = np.nan

    resp_signal_no_movement = resp_signal.copy()
    resp_signal_no_movement[np.array(movement_mask == 1).repeat(resp_fs)] = np.nan

    for i in range(len(enable_list)):
        enable_start = enable_list[i][0] * resp_fs
        enable_end = enable_list[i][1] * resp_fs
        segment = resp_signal_no_movement[enable_start:enable_end]

        # print(f"Normalizing segment {i+1}/{len(enable_list)}: start={enable_start}, end={enable_end}, length={len(segment)}")

        segment_mean = np.nanmean(segment)
        segment_std = np.nanstd(segment)
        if segment_std == 0:
            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")

        # Normalize together with the movement up to the next enable interval
        if i <= len(enable_list) - 2:
            enable_end = enable_list[i + 1][0] * resp_fs
        raw_segment = resp_signal[enable_start:enable_end]
        normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std

    # If the enable intervals don't start at 0, normalize the leading part too
    if enable_list[0][0] > 0:
        new_enable_start = 0
        enable_start = enable_list[0][0] * resp_fs
        enable_end = enable_list[0][1] * resp_fs
        segment = resp_signal_no_movement[enable_start:enable_end]

        segment_mean = np.nanmean(segment)
        segment_std = np.nanstd(segment)
        if segment_std == 0:
            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")

        raw_segment = resp_signal[new_enable_start:enable_start]
        normalized_resp_signal[new_enable_start:enable_start] = (raw_segment - segment_mean) / segment_std

    return normalized_resp_signal
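The core idea of the segment-wise normalization above can be sketched in a few lines: statistics are computed on a NaN-masked copy that excludes movement samples, then applied to the raw segment so movement samples are still scaled. A minimal example (values hypothetical):

```python
import numpy as np

# Minimal sketch of the per-segment z-score in normalize_resp_signal_by_segment:
# the mean/std come from the segment with movement samples set to NaN,
# and are then applied to the raw (unmasked) segment.
resp = np.array([1.0, 2.0, 3.0, 100.0, 4.0, 5.0])
movement = np.array([False, False, False, True, False, False])

masked = resp.copy()
masked[movement] = np.nan            # exclude movement from the statistics

mean = np.nanmean(masked)            # (1 + 2 + 3 + 4 + 5) / 5 = 3.0
std = np.nanstd(masked)
normalized = (resp - mean) / std     # raw values, movement included, are scaled

print(round(float(mean), 2), round(float(normalized[0]), 2))
# → 3.0 -1.41
```

Masking before computing the statistics keeps a single large artefact (the `100.0` here) from inflating the standard deviation and flattening the rest of the segment.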
688
signal_method/rule_base_event.py
Normal file
@@ -0,0 +1,688 @@
|
||||
import utils
|
||||
from utils.operation_tools import timing_decorator
|
||||
import numpy as np
|
||||
from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values
|
||||
from signal_method.time_metrics import calc_mav_by_slide_windows
|
||||
|
||||
|
||||
@timing_decorator()
|
||||
def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=None,
|
||||
std_median_multiplier=4.5, compare_intervals_sec=[30, 60],
|
||||
interval_multiplier=2.5,
|
||||
merge_gap_sec=10, min_duration_sec=5,
|
||||
low_amplitude_periods=None):
|
||||
"""
|
||||
检测信号中的体动状态,结合两种方法:标准差比较和前后窗口幅值对比。
|
||||
使用反射填充处理信号边界。
|
||||
|
||||
参数:
|
||||
- signal_data: numpy array,输入的信号数据
|
||||
- sampling_rate: int,信号的采样率(Hz)
|
||||
- window_size_sec: float,分析窗口的时长(秒),默认值为 2 秒
|
||||
- stride_sec: float,窗口滑动步长(秒),默认值为None(等于window_size_sec,无重叠)
|
||||
- std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5
|
||||
- compare_intervals_sec: list,用于比较的时间间隔列表(秒),默认为 [30, 60]
|
||||
- interval_multiplier: float,间隔中位数的上限乘数,默认值为 2.5
|
||||
- merge_gap_sec: float,要合并的体动状态之间的最大间隔(秒),默认值为 10 秒
|
||||
- min_duration_sec: float,要保留的体动状态的最小持续时间(秒),默认值为 5 秒
|
||||
- low_amplitude_periods: numpy array,低幅值期间的掩码(1表示低幅值期间),默认为None
|
||||
|
||||
返回:
|
||||
- movement_mask: numpy array,体动状态的掩码(1表示体动,0表示睡眠)
|
||||
"""
|
||||
# 计算窗口大小(样本数)
|
||||
window_samples = int(window_size_sec * sampling_rate)
|
||||
|
||||
# 如果未指定步长,设置为窗口大小(无重叠)
|
||||
if stride_sec is None:
|
||||
stride_sec = window_size_sec
|
||||
|
||||
# 计算步长(样本数)
|
||||
stride_samples = int(stride_sec * sampling_rate)
|
||||
|
||||
# 确保步长至少为1
|
||||
stride_samples = max(1, stride_samples)
|
||||
|
||||
# 计算需要的最大填充大小(基于比较间隔)
|
||||
max_interval_samples = int(max(compare_intervals_sec) * sampling_rate)
|
||||
|
||||
# 应用反射填充以正确处理边界
|
||||
# 填充大小为最大比较间隔的一半,以确保边界有足够的上下文
|
||||
pad_size = max_interval_samples
|
||||
padded_signal = np.pad(signal_data, pad_size, mode='reflect')
|
||||
|
||||
# 计算填充后的窗口数量
|
||||
num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)
|
||||
|
||||
# 初始化窗口标准差数组
|
||||
window_std = np.zeros(num_windows)
|
||||
# 计算每个窗口的标准差
|
||||
# 分窗计算标准差
|
||||
for i in range(num_windows):
|
||||
start_idx = i * stride_samples
|
||||
end_idx = min(start_idx + window_samples, len(padded_signal))
|
||||
|
||||
# 处理窗口,包括可能不完整的最后一个窗口
|
||||
window_data = padded_signal[start_idx:end_idx]
|
||||
if len(window_data) > 0:
|
||||
window_std[i] = np.std(window_data, ddof=1)
|
||||
else:
|
||||
window_std[i] = 0
|
||||
|
||||
# 计算原始信号对应的窗口索引范围
|
||||
# 填充后,原始信号从pad_size开始
|
||||
orig_start_window = pad_size // stride_samples
|
||||
if stride_sec == 1:
|
||||
orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
|
||||
else:
|
||||
orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1
|
||||
|
||||
# 只保留原始信号对应的窗口标准差
|
||||
original_window_std = window_std[orig_start_window:orig_end_window]
|
||||
num_original_windows = len(original_window_std)
|
||||
|
||||
# 创建时间点数组(秒)
|
||||
time_points = np.arange(num_original_windows) * stride_sec
|
||||
|
||||
# # 如果提供了低幅值期间的掩码,则在计算全局中位数时排除这些期间
|
||||
# if low_amplitude_periods is not None and len(low_amplitude_periods) == num_original_windows:
|
||||
# valid_std = original_window_std[low_amplitude_periods == 0]
|
||||
# if len(valid_std) == 0: # 如果所有窗口都在低幅值期间
|
||||
# valid_std = original_window_std # 使用全部窗口
|
||||
# else:
|
||||
# valid_std = original_window_std
|
||||
|
||||
valid_std = original_window_std ##20250418新修改
|
||||
|
||||
# ---------------------- 方法一:基于STD的体动判定 ----------------------#
|
||||
# 计算所有有效窗口标准差的中位数
|
||||
median_std = np.median(valid_std)
|
||||
|
||||
# 当窗口标准差大于中位数的倍数,判定为体动状态
|
||||
std_movement = np.where((original_window_std > (median_std * std_median_multiplier)), 1, 0)
|
||||
|
||||
# ------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------#
|
||||
amplitude_movement = np.zeros(num_original_windows, dtype=int)
|
||||
|
||||
# 定义基于时间粒度的比较间隔索引
|
||||
compare_intervals_idx = [int(interval // stride_sec) for interval in compare_intervals_sec]
|
||||
|
||||
# 逐窗口判断
|
||||
for win_idx in range(num_original_windows):
|
||||
# 全局索引(在填充后的窗口数组中)
|
||||
global_win_idx = win_idx + orig_start_window
|
||||
|
||||
# 对每个比较间隔进行检查
|
||||
for interval_idx in compare_intervals_idx:
|
||||
# 确定比较范围的结束索引(在填充后的窗口数组中)
|
||||
end_idx = min(global_win_idx + interval_idx, len(window_std))
|
||||
|
||||
# 提取相应时间范围内的标准差值
|
||||
if global_win_idx < end_idx:
|
||||
interval_std = window_std[global_win_idx:end_idx]
|
||||
|
||||
# 计算该间隔的中位数
|
||||
interval_median = np.median(interval_std)
|
||||
|
||||
# 计算上下阈值
|
||||
upper_threshold = interval_median * interval_multiplier
|
||||
|
||||
# 检查当前窗口是否超出阈值范围,如果超出则直接标记为体动
|
||||
if window_std[global_win_idx] > upper_threshold:
|
||||
amplitude_movement[win_idx] = 1
|
||||
break # 一旦确定为体动就不需要继续检查其他间隔
|
||||
|
||||
# 将两种方法的结果合并:只要其中一种方法判定为体动,最终结果就是体动
|
||||
movement_mask = np.logical_or(std_movement, amplitude_movement).astype(int)
|
||||
raw_movement_mask = movement_mask
|
||||
|
||||
# 如果需要合并间隔小的体动状态
|
||||
if merge_gap_sec > 0:
|
||||
movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec)
|
||||
|
||||
# 如果需要移除短时体动状态
|
||||
if min_duration_sec > 0:
|
||||
movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec)
|
||||
|
||||
# raw_movement_mask, movement_mask恢复对应秒数,而不是点数
|
||||
raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
|
||||
movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
|
||||
|
||||
# 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除
|
||||
removed_movement_mask = (raw_movement_mask - movement_mask) > 0
|
||||
removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0]
|
||||
removed_movement_end = np.where(np.diff(np.concatenate([removed_movement_mask, [0]])) == -1)[0]
|
||||
|
||||
for start, end in zip(removed_movement_start, removed_movement_end):
|
||||
# print(start ,end)
|
||||
# 计算剔除的体动区域的幅值
|
||||
if np.nanmax(signal_data[start * sampling_rate:(end + 1) * sampling_rate]) > median_std * std_median_multiplier:
|
||||
movement_mask[start:end + 1] = 1
|
||||
|
||||
# raw体动起止位置 [[start, end], [start, end], ...]
|
||||
raw_movement_position_list = event_mask_2_list(raw_movement_mask)
|
||||
|
||||
# merge体动起止位置 [[start, end], [start, end], ...]
|
||||
movement_position_list = event_mask_2_list(movement_mask)
|
||||
|
||||
return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list
|
||||
|
||||
|
||||
def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float,
|
||||
down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec, verbose=False):
|
||||
"""
|
||||
基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置精细修正
|
||||
|
||||
参数:
|
||||
- signal_data: numpy array,输入的信号数据
|
||||
- sampling_rate: int,信号的采样率(Hz)
|
||||
- movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠)
|
||||
|
||||
返回:
|
||||
- revised_movement_mask: numpy array,修正后的体动掩码
|
||||
"""
|
||||
window_size = sampling_rate
|
||||
stride_size = sampling_rate
|
||||
|
||||
time_points = np.arange(len(signal_data))
|
||||
|
||||
compare_size = int(compare_intervals_sec // (stride_size / sampling_rate))
|
||||
|
||||
_, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate,
|
||||
window_second=4, step_second=1,
|
||||
inner_window_second=4)
|
||||
|
||||
# 往左右两边取compare_size个点的mav,取平均值
|
||||
for start, end in movement_list:
|
||||
left_points = start - 20
|
||||
right_points = end + 20
|
||||
|
||||
left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask)
|
||||
right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask)
|
||||
|
||||
left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0
|
||||
right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0
|
||||
|
||||
# if left_value_metrics == 0:
|
||||
# value_metrics = right_value_metrics
|
||||
# elif right_value_metrics == 0:
|
||||
# value_metrics = left_value_metrics
|
||||
# else:
|
||||
# value_metrics = np.mean([left_value_metrics, right_value_metrics])
|
||||
|
||||
if left_value_metrics == 0:
|
||||
left_value_metrics = right_value_metrics
|
||||
elif right_value_metrics == 0:
|
||||
right_value_metrics = left_value_metrics
|
||||
|
||||
if verbose:
|
||||
print(f"Revising movement from index {start} to {end}, left_metric: {left_value_metrics:.2f}, right_metric: {right_value_metrics:.2f}")
|
||||
|
||||
for i in range(left_points, right_points):
|
||||
if i < 0 or i >= len(mav):
|
||||
continue
|
||||
if i < start:
|
||||
value_metrics = left_value_metrics
|
||||
elif i > end:
|
||||
value_metrics = right_value_metrics
|
||||
else:
|
||||
value_metrics = (left_value_metrics + right_value_metrics) / 2
|
||||
|
||||
if mav[i] > (value_metrics * up_interval_multiplier) and movement_mask[i] == 0:
|
||||
movement_mask[i] = 1
|
||||
if verbose:
|
||||
print(f"Normal revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * up_interval_multiplier:.2f}")
|
||||
elif mav[i] < (value_metrics * down_interval_multiplier) and movement_mask[i] == 1:
|
||||
movement_mask[i] = 0
|
||||
if verbose:
|
||||
print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * down_interval_multiplier:.2f}")
|
||||
else:
|
||||
if verbose:
|
||||
print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_metrics * up_interval_multiplier:.2f}, down_threshold: {value_metrics * down_interval_multiplier:.2f}")
|
||||
#
|
||||
# 逐秒遍历mav,判断是否需要修正
|
||||
# print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}")
|
||||
# for i in range(left_points, right_points):
|
||||
# if i < 0 or i >= len(mav):
|
||||
# continue
|
||||
# # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}")
|
||||
# if mav[i] > (value_metrics * up_interval_multiplier):
|
||||
# movement_mask[i] = 1
|
||||
# # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}")
|
||||
# elif mav[i] < (value_metrics * down_interval_multiplier):
|
||||
# movement_mask[i] = 0
|
||||
# # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}")
|
||||
# # else:
|
||||
# # print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}")
|
||||
# #
|
||||
# 如果需要合并间隔小的体动状态
|
||||
if merge_gap_sec > 0:
|
||||
movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec)
|
||||
|
||||
# 如果需要移除短时体动状态
|
||||
if min_duration_sec > 0:
|
||||
movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec)
|
||||
|
||||
movement_list = event_mask_2_list(movement_mask)
|
||||
return movement_mask, movement_list
|
||||
|
||||
|
||||
@timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
                                amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):
    """
    Detect low-amplitude states in a signal by comparing windowed RMS values
    against a fixed threshold.

    Args:
        signal_data: numpy array, input signal
        sampling_rate: int, sampling rate in Hz
        window_size_sec: float, analysis window length in seconds (default 1)
        stride_sec: float, window stride in seconds; None means stride == window size (no overlap)
        amplitude_threshold: float, RMS threshold; values at or below it count as low amplitude (default 50)
        merge_gap_sec: float, maximum gap (seconds) between states to merge (default 10)
        min_duration_sec: float, minimum duration (seconds) of states to keep (default 5)

    Returns:
        low_amplitude_mask: numpy array, 1 Hz mask (1 = low amplitude, 0 = normal)
        low_amplitude_position_list: list of [start, end] pairs in seconds
    """
    # Window size in samples
    window_samples = int(window_size_sec * sampling_rate)

    # Default stride: one window (no overlap)
    if stride_sec is None:
        stride_sec = window_size_sec

    # Stride in samples; keep at least one second worth of samples
    stride_samples = int(stride_sec * sampling_rate)
    stride_samples = max(sampling_rate, stride_samples)

    # Reflect-pad the signal boundaries
    pad_size = window_samples // 2
    padded_signal = np.pad(signal_data, pad_size, mode='reflect')

    # Number of windows over the padded signal
    num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)

    # RMS per window; the last window may be shorter than window_samples
    rms_values = np.zeros(num_windows)
    for i in range(num_windows):
        start_idx = i * stride_samples
        end_idx = min(start_idx + window_samples, len(signal_data))
        window_data = signal_data[start_idx:end_idx]
        if len(window_data) > 0:
            rms_values[i] = np.sqrt(np.mean(window_data ** 2))
        else:
            rms_values[i] = 0

    # Initial mask: 1 where RMS is at or below the threshold, else 0
    low_amplitude_mask = np.where(rms_values <= amplitude_threshold, 1, 0)

    # Window index range corresponding to the original (unpadded) signal
    orig_start_window = pad_size // stride_samples
    if stride_sec == 1:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
    else:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1

    # Keep only the windows covering the original signal
    low_amplitude_mask = low_amplitude_mask[orig_start_window:orig_end_window]
    num_original_windows = len(low_amplitude_mask)

    # Time point (seconds) of each window
    time_points = np.arange(num_original_windows) * stride_sec

    # Merge states separated by short gaps, if requested
    if merge_gap_sec > 0:
        low_amplitude_mask = merge_short_gaps(low_amplitude_mask, time_points, merge_gap_sec)

    # Remove short-lived states, if requested
    if min_duration_sec > 0:
        low_amplitude_mask = remove_short_durations(low_amplitude_mask, time_points, min_duration_sec)

    low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]

    # Start/end positions of low-amplitude states: [[start, end], ...]
    low_amplitude_position_list = event_mask_2_list(low_amplitude_mask)

    return low_amplitude_mask, low_amplitude_position_list

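Stripped of padding and post-processing, the core RMS-thresholding step of `detect_low_amplitude_signal` can be sketched on a synthetic signal (toy sampling rate and threshold, non-overlapping 1 s windows):

```python
import numpy as np

fs = 10
# 3 s of strong signal followed by 2 s of weak signal
signal = np.concatenate([np.ones(30) * 100.0, np.ones(20) * 5.0])

# Non-overlapping 1 s windows; RMS per window
windows = signal.reshape(-1, fs)
rms = np.sqrt(np.mean(windows ** 2, axis=1))

# Windows whose RMS falls at or below the threshold are marked low-amplitude
low_mask = (rms <= 50).astype(int)
print(low_mask.tolist())  # → [0, 0, 0, 1, 1]
```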
def get_typical_segment_for_continues_signal(signal_data, sampling_rate=100, window_size=30, step_size=1):
    """
    Pick ten typical segments from a continuous signal (not implemented yet).

    :param signal_data: signal data
    :param sampling_rate: sampling rate
    :param window_size: window size in seconds
    :param step_size: step size in seconds
    :return: list of typical segments
    """
    pass

# Sleep-posture recognition based on movement position and amplitude.
# The movement mask splits the night into valid (movement-free) segments; each
# segment gets an amplitude metric, and a significant difference between two
# segments is taken as a posture change.
# Each metric window is 10 s long, so the amplitude metric is computed over 10 s.
# Only adjacent segments are compared: a change is attributed to the 30 s around
# each movement, and segments shorter than 30 s are compared at their actual length.

@timing_decorator()
def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rate=100, window_size_sec=30,
                                        interval_to_movement=10):
    mav_calc_window_sec = 2  # window size (seconds) for the MAV calculation

    # Start/end positions of valid (movement-free) segments
    valid_mask = 1 - movement_mask
    valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0]
    valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0]

    movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0]
    movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0]

    segment_left_average_amplitude = []
    segment_right_average_amplitude = []
    segment_left_average_energy = []
    segment_right_average_energy = []

    for start, end in zip(valid_starts, valid_ends):
        start *= sampling_rate
        end *= sampling_rate
        if end - start <= sampling_rate:  # skip segments shorter than 1 second
            continue
        elif end - start < (window_size_sec * interval_to_movement + 1) * sampling_rate:
            left_start = start
            left_end = min(end, left_start + window_size_sec * sampling_rate)
            right_start = max(start, end - window_size_sec * sampling_rate)
            right_end = end
        else:
            left_start = start + interval_to_movement * sampling_rate
            left_end = left_start + window_size_sec * sampling_rate
            right_start = end - interval_to_movement * sampling_rate - window_size_sec * sampling_rate
            right_end = end

        # Truncate so each window length is a multiple of the MAV window
        if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0:
            left_end = left_start + ((left_end - left_start) // (mav_calc_window_sec * sampling_rate)) * (
                    mav_calc_window_sec * sampling_rate)
        if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0:
            right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * (
                    mav_calc_window_sec * sampling_rate)

        # Amplitude metric of each window: mean peak-to-peak per MAV window
        # (axis=1 so max/min are taken within each window, not across windows)
        left_windows = signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate)
        right_windows = signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate)
        left_mav = np.mean(np.max(left_windows, axis=1)) - np.mean(np.min(left_windows, axis=1))
        right_mav = np.mean(np.max(right_windows, axis=1)) - np.mean(np.min(right_windows, axis=1))
        segment_left_average_amplitude.append(left_mav)
        segment_right_average_amplitude.append(right_mav)

        left_energy = np.sum(np.abs(signal_data[left_start:left_end] ** 2))
        right_energy = np.sum(np.abs(signal_data[right_start:right_end] ** 2))
        segment_left_average_energy.append(left_energy)
        segment_right_average_energy.append(right_energy)

    position_changes = []
    position_change_times = []
    # Thresholds for a significant change (tune as needed)
    threshold_amplitude = 0.1  # amplitude-change threshold
    threshold_energy = 0.1  # energy-change threshold

    for i in range(1, len(segment_left_average_amplitude)):
        # Relative change of the amplitude metric
        left_amplitude_change = abs(segment_left_average_amplitude[i] - segment_left_average_amplitude[i - 1]) / max(
            segment_left_average_amplitude[i - 1], 1e-6)
        right_amplitude_change = abs(segment_right_average_amplitude[i] - segment_right_average_amplitude[i - 1]) / max(
            segment_right_average_amplitude[i - 1], 1e-6)

        # Relative change of the energy metric
        left_energy_change = abs(segment_left_average_energy[i] - segment_left_average_energy[i - 1]) / max(
            segment_left_average_energy[i - 1], 1e-6)
        right_energy_change = abs(segment_right_average_energy[i] - segment_right_average_energy[i - 1]) / max(
            segment_right_average_energy[i - 1], 1e-6)

        # A posture change requires one side to exceed both thresholds
        left_significant_change = (left_amplitude_change > threshold_amplitude) and (
                left_energy_change > threshold_energy)
        right_significant_change = (right_amplitude_change > threshold_amplitude) and (
                right_energy_change > threshold_energy)

        if left_significant_change or right_significant_change:
            # Record the change as the start/end of the separating movement
            position_changes.append(1)
            position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
        else:
            position_changes.append(0)  # 0 means no posture change

    return position_changes, position_change_times

def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100):
    """
    :param signal_data: input signal
    :param movement_mask: movement mask sampled at 1 Hz
    :param sampling_rate: signal sampling rate in Hz
    :return: (position_changes, position_change_times)
    """
    mav_calc_window_sec = 2  # window size (seconds) for the MAV calculation

    # Start/end positions of valid (movement-free) segments
    valid_mask = 1 - movement_mask
    valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0]
    valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0]

    movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0]
    movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0]

    segment_average_amplitude = []
    segment_average_energy = []

    # Blank out movement periods so they do not bias the metrics
    signal_data_no_movement = signal_data.copy()
    for start, end in zip(movement_start, movement_end):
        signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan

    for start, end in zip(valid_starts, valid_ends):
        start *= sampling_rate
        end *= sampling_rate
        if end - start <= sampling_rate:  # skip segments shorter than 1 second
            continue

        # Truncate so the segment length is a multiple of the MAV window
        if (end - start) % (mav_calc_window_sec * sampling_rate) != 0:
            end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * (
                    mav_calc_window_sec * sampling_rate)

        # Amplitude metric of the segment: mean peak-to-peak per MAV window
        windows = signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate)
        mav = np.nanmean(np.nanmax(windows, axis=1)) - np.nanmean(np.nanmin(windows, axis=1))
        segment_average_amplitude.append(mav)

        energy = np.nansum(np.abs(signal_data_no_movement[start:end] ** 2))
        segment_average_energy.append(energy)

    position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
    position_change_times = []
    # Thresholds for a significant change (tune as needed)
    threshold_amplitude = 0.1  # amplitude-change threshold
    threshold_energy = 0.1  # energy-change threshold

    for i in range(1, len(segment_average_amplitude)):
        # Relative change of the amplitude metric
        amplitude_change = abs(segment_average_amplitude[i] - segment_average_amplitude[i - 1]) / max(
            segment_average_amplitude[i - 1], 1e-6)

        # Relative change of the energy metric
        energy_change = abs(segment_average_energy[i] - segment_average_energy[i - 1]) / max(
            segment_average_energy[i - 1], 1e-6)

        significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy)

        if significant_change:
            # Record the change as the start/end of the separating movement
            position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1
            position_change_times.append((movement_start[i - 1], movement_end[i - 1]))

    return position_changes, position_change_times

def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate, mav_calc_window_sec,
                                        threshold_amplitude, threshold_energy, verbose=False):
    """
    :param signal_data: input signal
    :param movement_mask: movement mask sampled at 1 Hz
    :param movement_list: [[start, end], ...] movement segments in seconds
    :param sampling_rate: signal sampling rate in Hz
    :param mav_calc_window_sec: window size (seconds) for the MAV calculation
    :param threshold_amplitude: relative amplitude-change threshold
    :param threshold_energy: relative energy-change threshold (currently unused)
    :param verbose: print progress information
    :return: (position_changes, position_change_list)
    """
    # Start/end positions of valid (movement-free) segments
    valid_list = utils.event_mask_2_list(movement_mask, event_true=False)

    # Blank out movement periods so they do not bias the metrics
    signal_data_no_movement = signal_data.copy()
    for start, end in movement_list:
        signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan

    if len(valid_list) < 2:
        return [], []

    def calc_mav(data_segment):
        # Truncate the segment to a multiple of the MAV window
        if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0:
            data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))]
        windows = data_segment.reshape(-1, mav_calc_window_sec * sampling_rate)
        return np.nanstd(np.nanmax(windows, axis=1) - np.nanmin(windows, axis=1))

    def calc_energy(data_segment):
        return np.nansum(np.abs(data_segment ** 2)) / (len(data_segment) // sampling_rate)

    def calc_mav_by_quantiles(data_segment):
        # Per-window peak-to-peak values, trimmed to the 20th-80th percentile
        if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0:
            data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))]
        windows = data_segment.reshape(-1, mav_calc_window_sec * sampling_rate)
        mav_values = np.nanmax(windows, axis=1) - np.nanmin(windows, axis=1)
        q20 = np.nanpercentile(mav_values, 20)
        q80 = np.nanpercentile(mav_values, 80)
        mav_values = mav_values[(mav_values >= q20) & (mav_values <= q80)]
        return np.nanmean(mav_values)

    position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
    position_change_list = []

    pre_valid_start = valid_list[0][0] * sampling_rate
    pre_valid_end = valid_list[0][1] * sampling_rate

    if verbose:
        print(f"Total movement segments to analyze: {len(movement_list)}")
        print(f"Total valid segments available: {len(valid_list)}")

    for i in range(len(movement_list)):
        if verbose:
            print(f"Analyzing movement segment {i + 1}/{len(movement_list)}")

        if i + 1 >= len(valid_list):
            if verbose:
                print("No more valid segments to compare. Ending analysis.")
            break

        next_valid_start = valid_list[i + 1][0] * sampling_rate
        next_valid_end = valid_list[i + 1][1] * sampling_rate

        movement_start = movement_list[i][0]
        movement_end = movement_list[i][1]

        # Skip movements shorter than 1 second; extend the left window instead
        if movement_end - movement_start <= 1:
            if verbose:
                print(f"Skipping movement segment {i + 1} due to insufficient length. "
                      f"movement start: {movement_start}, movement end: {movement_end}")
            pre_valid_end = next_valid_end
            continue

        # Amplitude metric of the segments before and after the movement
        left_mav = calc_mav_by_quantiles(signal_data_no_movement[pre_valid_start:pre_valid_end])
        right_mav = calc_mav_by_quantiles(signal_data_no_movement[next_valid_start:next_valid_end])

        # Relative change of the amplitude metric
        amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6)

        significant_change = (amplitude_change > threshold_amplitude)
        if significant_change:
            if verbose:
                print(f"Significant position change detected between {movement_start}s and {movement_end}s: "
                      f"left {pre_valid_start} to {pre_valid_end}, left_mav={left_mav:.2f}, "
                      f"right {next_valid_start} to {next_valid_end}, right_mav={right_mav:.2f}, "
                      f"amplitude_change={amplitude_change:.2f}")
            # Record the change as the start/end of the separating movement
            position_changes[movement_start:movement_end] = 1
            position_change_list.append(movement_list[i])
            # Advance both comparison windows past this movement
            pre_valid_start = next_valid_start
            pre_valid_end = next_valid_end
        else:
            if verbose:
                print(f"No significant position change between {movement_start}s and {movement_end}s: "
                      f"left {pre_valid_start} to {pre_valid_end}, left_mav={left_mav:.2f}, "
                      f"right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}")
            # Only extend the left window
            pre_valid_end = next_valid_end

    return position_changes, position_change_list

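The quantile-trimmed MAV used by v3 can be illustrated on synthetic windows: per-window peak-to-peak values are trimmed to the 20th-80th percentile, so a single outlier window does not inflate the segment metric.

```python
import numpy as np

# Five 1 s windows (10 samples each); the last one is an outlier burst
amps = np.array([1.0, 1.1, 1.2, 1.3, 10.0])
data = np.concatenate([a * np.tile([0.5, -0.5], 5) for a in amps])

# Per-window peak-to-peak values: [1.0, 1.1, 1.2, 1.3, 10.0]
ptp = np.ptp(data.reshape(5, 10), axis=1)

# Keep only values inside the 20th-80th percentile band, then average
q20, q80 = np.nanpercentile(ptp, 20), np.nanpercentile(ptp, 80)
trimmed = ptp[(ptp >= q20) & (ptp <= q80)]
print(float(np.mean(trimmed)))  # the outlier window is discarded
```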
62
signal_method/shhs_tools.py
Normal file
@@ -0,0 +1,62 @@
import xml.etree.ElementTree as ET

ANNOTATION_MAP = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 4,
    "REM sleep|5": 5,
    "Unscored|9": 9,
    "Movement|6": 6
}

SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']


def parse_sleep_annotations(annotation_path):
    """Parse sleep-stage annotations."""
    try:
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        events = []
        for scored_event in root.findall('.//ScoredEvent'):
            event_type = scored_event.find('EventType').text
            if event_type != "Stages|Stages":
                continue
            description = scored_event.find('EventConcept').text
            start = float(scored_event.find('Start').text)
            duration = float(scored_event.find('Duration').text)
            if description not in ANNOTATION_MAP:
                continue
            events.append({
                'onset': start,
                'duration': duration,
                'description': description,
                'stage': ANNOTATION_MAP[description]
            })
        return events
    except Exception:
        return None


def extract_osa_events(annotation_path):
    """Extract sleep-apnea events."""
    try:
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        events = []
        for scored_event in root.findall('.//ScoredEvent'):
            event_concept = scored_event.find('EventConcept').text
            event_type = event_concept.split('|')[0].strip()
            if event_type in SA_EVENTS:
                start = float(scored_event.find('Start').text)
                duration = float(scored_event.find('Duration').text)
                if duration >= 10:  # keep only events of at least 10 seconds
                    events.append({
                        'start': start,
                        'duration': duration,
                        'end': start + duration,
                        'type': event_type
                    })
        return events
    except Exception:
        return []
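The parser above assumes the SHHS/NSRR annotation XML layout. A minimal, self-contained sketch of the same extraction logic on an inline fragment (the XML text here is invented for illustration):

```python
import xml.etree.ElementTree as ET

xml_text = """<PSGAnnotation>
  <ScoredEvents>
    <ScoredEvent>
      <EventConcept>Obstructive apnea|Obstructive Apnea</EventConcept>
      <Start>120.0</Start><Duration>15.0</Duration>
    </ScoredEvent>
    <ScoredEvent>
      <EventConcept>Hypopnea|Hypopnea</EventConcept>
      <Start>300.0</Start><Duration>8.0</Duration>
    </ScoredEvent>
  </ScoredEvents>
</PSGAnnotation>"""

root = ET.fromstring(xml_text)
events = []
for ev in root.findall('.//ScoredEvent'):
    etype = ev.find('EventConcept').text.split('|')[0].strip()
    start = float(ev.find('Start').text)
    dur = float(ev.find('Duration').text)
    # Same filter as extract_osa_events: apnea types, at least 10 s long
    if etype in ('Central apnea', 'Hypopnea', 'Obstructive apnea') and dur >= 10:
        events.append({'start': start, 'duration': dur, 'end': start + dur, 'type': etype})

print(events)  # only the 15 s obstructive apnea survives the 10 s minimum
```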
107
signal_method/signal_process.py
Normal file
@@ -0,0 +1,107 @@
import numpy as np
from scipy.interpolate import interp1d

import utils


def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
    # Filtering: start with a 50 Hz notch filter to suppress mains interference
    if verbose:
        print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)

    # Respiration branch: lowpass, downsample, moving average, then bandpass
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)

    # BCG branch: bandpass on the notch-filtered signal
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)

    return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs


def psg_effort_filter(conf, effort_data_raw, effort_fs):
    # Bessel filter followed by a moving average
    effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"],
                                 low_cut=conf["effort_filter"]["low_cut"],
                                 high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
                                 sample_rate=effort_fs)
    effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs,
                                         window_size_sec=conf["average_filter"]["window_size_sec"])
    return effort_data_raw, effort_data_2, effort_fs


def rpeak2hr(rpeak_indices, signal_length, ecg_fs):
    hr_signal = np.zeros(signal_length)
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        if rri == 0:
            continue
        hr = 60 * ecg_fs / rri  # heart rate in bpm
        hr = min(max(hr, 30), 120)  # clip to a plausible range
        hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = hr
    # Fill the heart rate after the last R peak
    if len(rpeak_indices) > 1:
        hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
    return hr_signal


def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs):
    rri_signal = np.zeros(signal_length)
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri
    # Fill the RRI after the last R peak
    if len(rpeak_indices) > 1:
        rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]]

    # Zero out physiologically implausible intervals
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs:
            rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0

    return rri_signal


def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs):
    r_peak_time = np.asarray(rpeak_indices) / ecg_fs
    rri = np.diff(r_peak_time)
    t_rri = r_peak_time[1:]

    # Drop physiologically implausible intervals before interpolating
    mask = (rri > 0.3) & (rri < 2.0)
    rri_clean = rri[mask]
    t_rri_clean = t_rri[mask]

    t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1 / rri_fs)
    f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate")
    rri_uniform = f(t_uniform)

    return rri_uniform, rri_fs

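The resampling idea behind `rpeak2rri_interpolation` can be exercised without scipy; the sketch below reproduces it with `np.interp` (linear, like the `interp1d` call above) on synthetic R-peak indices:

```python
import numpy as np

ecg_fs, rri_fs = 100, 4
rpeaks = np.array([0, 80, 170, 250, 340])  # R-peak sample indices

t = rpeaks / ecg_fs          # R-peak times in seconds
rri = np.diff(t)             # successive R-R intervals: [0.8, 0.9, 0.8, 0.9] s
t_rri = t[1:]                # each interval timestamped at its closing peak

# Resample onto a uniform 4 Hz time grid with linear interpolation
t_uniform = np.arange(t_rri[0], t_rri[-1], 1 / rri_fs)
rri_uniform = np.interp(t_uniform, t_rri, rri)
print(len(t_uniform), float(rri_uniform[0]))
```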
45
signal_method/time_metrics.py
Normal file
@@ -0,0 +1,45 @@
import numpy as np

from utils.operation_tools import calculate_by_slide_windows
from utils.operation_tools import timing_decorator


@timing_decorator()
def calc_mav_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10,
                              step_second=1, inner_window_second=2):
    if movement_mask is not None:
        assert len(movement_mask) * sampling_rate == len(signal_data), \
            f"movement_mask length does not match signal_data length, " \
            f"{len(movement_mask) * sampling_rate} != {len(signal_data)}"
        processed_mask = movement_mask.copy()
    else:
        processed_mask = None

    def mav_func(x):
        windows = x.reshape(-1, inner_window_second * sampling_rate)
        return np.mean(np.nanmax(windows, axis=1) - np.nanmin(windows, axis=1)) / 2

    mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
                                              window_second=window_second, step_second=step_second)
    return mav_nan, mav


@timing_decorator()
def calc_wavefactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    assert len(movement_mask) * sampling_rate == len(signal_data), \
        f"movement_mask length does not match signal_data length, " \
        f"{len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), \
        f"movement_mask and low_amp_mask lengths differ, {len(movement_mask)} != {len(low_amp_mask)}"

    processed_mask = movement_mask.copy()
    processed_mask = processed_mask.repeat(sampling_rate)
    # Wave factor: RMS divided by the mean absolute value
    wavefactor_nan, wavefactor = calculate_by_slide_windows(lambda x: np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x)),
                                                           signal_data, processed_mask, sampling_rate=sampling_rate,
                                                           window_second=2, step_second=1)
    return wavefactor_nan, wavefactor


@timing_decorator()
def calc_peakfactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    assert len(movement_mask) * sampling_rate == len(signal_data), \
        f"movement_mask length does not match signal_data length, " \
        f"{len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), \
        f"movement_mask and low_amp_mask lengths differ, {len(movement_mask)} != {len(low_amp_mask)}"

    processed_mask = movement_mask.copy()
    processed_mask = processed_mask.repeat(sampling_rate)
    # Peak factor: peak absolute value divided by RMS
    peakfactor_nan, peakfactor = calculate_by_slide_windows(lambda x: np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2)),
                                                           signal_data, processed_mask, sampling_rate=sampling_rate,
                                                           window_second=2, step_second=1)
    return peakfactor_nan, peakfactor
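The wave factor (RMS over mean absolute value) and peak factor (peak over RMS) have known analytic values for a pure sine, which makes a quick sanity check of these window metrics possible:

```python
import numpy as np

# A pure sine over whole periods: wave factor = pi/(2*sqrt(2)) ~ 1.111,
# peak factor = sqrt(2) ~ 1.414
t = np.linspace(0, 1, 100000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t)

wave_factor = np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x))
peak_factor = np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2))
print(round(wave_factor, 3), round(peak_factor, 3))
```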
354
utils/HYS_FileReader.py
Normal file
@@ -0,0 +1,354 @@
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd

import utils
from .event_map import N2Chn
from .operation_tools import event_mask_2_list

# Try to import Polars for faster CSV reading
try:
    import polars as pl

    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False


def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
    """
    Read a txt file and return the first column as a numpy array.

    Args:
        path: path to the txt file; the sampling rate is parsed from the file stem
        dtype: dtype of the returned array
        verbose: print file statistics
        is_peak: True if the file holds peak indices rather than a continuous signal
    Returns:
        tuple: (signal_data_raw, signal_length, signal_fs, signal_second)
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    if HAS_POLARS:
        df = pl.read_csv(path, has_header=False, infer_schema_length=0)
        signal_data_raw = df[:, 0].to_numpy().astype(dtype)
    else:
        df = pd.read_csv(path, header=None, dtype=dtype)
        signal_data_raw = df.iloc[:, 0].to_numpy()

    signal_original_length = len(signal_data_raw)
    signal_fs = int(path.stem.split("_")[-1])
    if is_peak:
        signal_second = None
        signal_length = None
    else:
        signal_second = signal_original_length // signal_fs
        # Truncate to a whole number of seconds
        signal_data_raw = signal_data_raw[:signal_second * signal_fs]
        signal_length = len(signal_data_raw)

    if verbose:
        print(f"Signal file read from {path}")
        print(f"signal_fs: {signal_fs}")
        print(f"signal_original_length: {signal_original_length}")
        print(f"signal_after_cut_off_length: {signal_length}")
        print(f"signal_second: {signal_second}")

    return signal_data_raw, signal_length, signal_fs, signal_second


def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
    """
    Read a label CSV file and return it as a pandas DataFrame.

    Args:
        path (str | Path): Path to the CSV file.
        verbose (bool): print labeling statistics.
    Returns:
        pd.DataFrame: The content of the CSV file as a pandas DataFrame.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # The file contains Chinese text, so read with the GBK encoding
    df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(df)}")

    # Labeling conventions:
    #   isLabeled == 1 means the row has been reviewed
    #   rows with a non-null "Event type" come from the PSG export
    #   rows with a null "Event type" were labeled manually
    #   score == 1: clear event, score == 2: disturbed event, score == 3: non-event to delete
    #   the confirmed event type is in correct_EventsType
    # Statistics are printed per event type as: total / from PSG / manual / deleted / unlabeled
    # Columns:
    #   Index, Event type, Stage, Time, Epoch, Date, Duration, HR bef., HR extr., HR delta,
    #   O2 bef., O2 min., O2 delta, Body Position, Validation, Start, End, score, remark,
    #   correct_Start, correct_End, correct_EventsType, isLabeled
    # Event type values: Hypopnea, Central apnea, Obstructive apnea, Mixed apnea

    num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3))

    num_psg_events = np.sum(df["Event type"].notna())
    num_manual_events = np.sum(df["Event type"].isna())

    num_deleted = np.sum(df["score"] == 3)

    num_unlabeled = np.sum(df["isLabeled"] == -1)

    num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
    num_psg_csa = np.sum(df["Event type"] == "Central apnea")
    num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
    num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")

    num_hyp = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] != 3))
    num_csa = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] != 3))
    num_osa = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] != 3))
    num_msa = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] != 3))

    num_manual_hyp = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Hypopnea"))
    num_manual_csa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Central apnea"))
    num_manual_osa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Obstructive apnea"))
    num_manual_msa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Mixed apnea"))

    num_deleted_hyp = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Hypopnea"))
    num_deleted_csa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Central apnea"))
    num_deleted_osa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Obstructive apnea"))
    num_deleted_msa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Mixed apnea"))

    num_unlabeled_hyp = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Hypopnea"))
    num_unlabeled_csa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Central apnea"))
    num_unlabeled_osa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Obstructive apnea"))
    num_unlabeled_msa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Mixed apnea"))

    num_hyp_1_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 1))
    num_csa_1_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 1))
    num_osa_1_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 1))
    num_msa_1_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 1))

    num_hyp_2_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 2))
    num_csa_2_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 2))
    num_osa_2_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 2))
    num_msa_2_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 2))

    num_hyp_3_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 3))
    num_csa_3_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 3))
    num_osa_3_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 3))
    num_msa_3_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 3))

    num_1_score = np.sum(df["score"] == 1)
    num_2_score = np.sum(df["score"] == 2)
    num_3_score = np.sum(df["score"] == 3)

    if verbose:
        print("Event Statistics:")
        # Fixed-width columns: total / from PSG / manual / deleted / unlabeled
        print(f"Type {'Total':^8s} / {'From PSG':^8s} / {'Manual':^8s} / {'Deleted':^8s} / {'Unlabeled':^8s}")
        print(
            f"Hyp: {num_hyp:^8d} / {num_psg_hyp:^8d} / {num_manual_hyp:^8d} / {num_deleted_hyp:^8d} / {num_unlabeled_hyp:^8d}")
        print(
            f"CSA: {num_csa:^8d} / {num_psg_csa:^8d} / {num_manual_csa:^8d} / {num_deleted_csa:^8d} / {num_unlabeled_csa:^8d}")
|
||||
print(
|
||||
f"OSA: {num_osa:^8d} / {num_psg_osa:^8d} / {num_manual_osa:^8d} / {num_deleted_osa:^8d} / {num_unlabeled_osa:^8d}")
|
||||
print(
|
||||
f"MSA: {num_msa:^8d} / {num_psg_msa:^8d} / {num_manual_msa:^8d} / {num_deleted_msa:^8d} / {num_unlabeled_msa:^8d}")
|
||||
print(
|
||||
f"All: {num_total:^8d} / {num_psg_events:^8d} / {num_manual_events:^8d} / {num_deleted:^8d} / {num_unlabeled:^8d}")
|
||||
|
||||
print("Score Statistics (only for non-deleted events and manual created events):")
|
||||
print(f"Type {'Total':^8s} / {'Score 1':^8s} / {'Score 2':^8s} / {'Score 3':^8s}")
|
||||
print(f"Hyp: {num_hyp:^8d} / {num_hyp_1_score:^8d} / {num_hyp_2_score:^8d} / {num_hyp_3_score:^8d}")
|
||||
print(f"CSA: {num_csa:^8d} / {num_csa_1_score:^8d} / {num_csa_2_score:^8d} / {num_csa_3_score:^8d}")
|
||||
print(f"OSA: {num_osa:^8d} / {num_osa_1_score:^8d} / {num_osa_2_score:^8d} / {num_osa_3_score:^8d}")
|
||||
print(f"MSA: {num_msa:^8d} / {num_msa_1_score:^8d} / {num_msa_2_score:^8d} / {num_msa_3_score:^8d}")
|
||||
print(f"All: {num_total:^8d} / {num_1_score:^8d} / {num_2_score:^8d} / {num_3_score:^8d}")
|
||||
|
||||
df["Start"] = df["Start"].astype(int)
|
||||
df["End"] = df["End"].astype(int)
|
||||
return df
|
||||
|
||||
|
||||
def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame:
    """
    Read a raw PSG label CSV file and return it as a pandas DataFrame.

    Args:
        path (str | Path): Path to the CSV file.
        verbose (bool): Whether to print read and statistics information.

    Returns:
        pd.DataFrame: The content of the CSV file as a pandas DataFrame.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # Read directly with pandas; the file contains Chinese text, so specify the encoding
    df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(df)}")

    num_psg_events = np.sum(df["Event type"].notna())
    # Event statistics
    num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
    num_psg_csa = np.sum(df["Event type"] == "Central apnea")
    num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
    num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")

    if verbose:
        print("Event Statistics:")
        # Formatted output with a fixed column width
        print(f"Type {'Total':^8s}")
        print(f"Hyp: {num_psg_hyp:^8d}")
        print(f"CSA: {num_psg_csa:^8d}")
        print(f"OSA: {num_psg_osa:^8d}")
        print(f"MSA: {num_psg_msa:^8d}")
        print(f"All: {num_psg_events:^8d}")

    df["Start"] = df["Start"].astype(int)
    df["End"] = df["End"].astype(int)
    return df

def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
    """
    Read an Excel file and return it as a pandas DataFrame.

    Args:
        path (str | Path): Path to the Excel file.

    Returns:
        pd.DataFrame: The content of the Excel file as a pandas DataFrame.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # Read directly with pandas
    df = pd.read_excel(path)
    df["id"] = df["id"].astype(int)
    df["start"] = df["start"].astype(int)
    df["end"] = df["end"].astype(int)
    return df

def read_mask_execl(path: Union[str, Path]):
    """
    Read an event-mask file and return the masks as numpy arrays.
    Note: despite the name, the file is read with ``pd.read_csv``.

    Args:
        path (str | Path): Path to the mask file.

    Returns:
        tuple[dict, dict]: ``event_mask`` mapping column names to numpy arrays,
        and ``event_list`` mapping segment names to ``[start, end]`` interval lists.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    df = pd.read_csv(path)
    event_mask = df.to_dict(orient="list")
    for key in event_mask:
        event_mask[key] = np.array(event_mask[key])

    event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
                  "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
                  "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
                  "DisableSegment": event_mask_2_list(event_mask["Disable_Label"])}

    return event_mask, event_list

def read_psg_mask_excel(path: Union[str, Path]):
    """Read a PSG mask CSV file and return each column as a numpy array."""
    df = pd.read_csv(path)
    event_mask = df.to_dict(orient="list")
    for key in event_mask:
        event_mask[key] = np.array(event_mask[key])

    return event_mask

def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
    """
    Read the data of specific channels from PSG files.

    Args:
        path_str (Union[str, Path]): Path to the directory containing the PSG files.
        channel_number (list[int]): Channel numbers to read (see ``N2Chn``).

    Returns:
        dict: Mapping from channel name to a dict with keys
        ``name``, ``path``, ``data``, ``length``, ``fs`` and ``second``.
    """
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"PSG Dir not found: {path}")

    if not path.is_dir():
        raise NotADirectoryError(f"PSG Dir is not a directory: {path}")
    channel_data = {}
    # Check that the file for each requested channel exists
    for ch_id in channel_number:
        ch_name = N2Chn[ch_id]
        ch_path = list(path.glob(f"{ch_name}*.txt"))

        if not any(ch_path):
            raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")

        if len(ch_path) > 1:
            print(f"Warning!!! More than one PSG channel file matched: {ch_path}")

        if ch_id == 8:
            # Special case: sleep stage, read as strings
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
            # Convert to an integer array
            for stage_str, stage_number in utils.Stage2N.items():
                np.place(ch_signal, ch_signal == stage_str, stage_number)
            ch_signal = ch_signal.astype(int)
        elif ch_id == 1:
            # Special case: R peaks, read as integers
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
        else:
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)
        channel_data[ch_name] = {
            "name": ch_name,
            "path": ch_path[0],
            "data": ch_signal,
            "length": ch_length,
            "fs": ch_fs,
            "second": ch_second
        }

    return channel_data

def read_psg_label(path: Union[str, Path], verbose=True):
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # Read directly with pandas; the file contains Chinese text, so specify the encoding
    df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(df)}")

    # Drop rows whose Event type is empty
    df = df.dropna(subset=["Event type"], how='all').reset_index(drop=True)

    return df

12  utils/__init__.py  Normal file
@ -0,0 +1,12 @@
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label, read_raw_psg_label, read_psg_mask_excel
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
from .operation_tools import merge_short_gaps, remove_short_durations
from .operation_tools import collect_values
from .operation_tools import save_process_label
from .operation_tools import none_to_nan_mask
from .operation_tools import get_wake_mask
from .operation_tools import fill_spo2_anomaly
from .split_method import resp_split
from .HYS_FileReader import read_mask_execl, read_psg_channel
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel, adjust_sample_rate
42  utils/event_map.py  Normal file
@ -0,0 +1,42 @@
# Apnea event type to number mapping
E2N = {
    "Hypopnea": 1,
    "Central apnea": 2,
    "Obstructive apnea": 3,
    "Mixed apnea": 4
}

# Channel number to channel name mapping
N2Chn = {
    1: "Rpeak",
    2: "ECG_Sync",
    3: "Effort Tho",
    4: "Effort Abd",
    5: "Flow P",
    6: "Flow T",
    7: "SpO2",
    8: "5_class"
}

Stage2N = {
    "W": 5,
    "N1": 3,
    "N2": 2,
    "N3": 1,
    "R": 4,
}

# Event codes and their plot colors
# event_code  color   event
# 0           black   background
# 1           pink    hypopnea
# 2           blue    central apnea
# 3           red     obstructive apnea
# 4           gray    mixed apnea
# 5           green   SpO2 desaturation
# 6           orange  large body movement
# 7           orange  small body movement
# 8           orange  deep breath
# 9           orange  pulse-like movement
# 10          orange  invalid segment
ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange",
              "orange"]
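The mappings above can be exercised with a small sketch; the dict and list literals below are copied from this file, and the example labels are made up:

```python
# Encode annotated event names into numeric labels via E2N,
# then map each numeric code to its plot color via ColorCycle.
E2N = {"Hypopnea": 1, "Central apnea": 2, "Obstructive apnea": 3, "Mixed apnea": 4}
ColorCycle = ["black", "pink", "blue", "red", "silver", "green",
              "orange", "orange", "orange", "orange", "orange"]

labels = ["Hypopnea", "Mixed apnea", "Obstructive apnea"]
codes = [E2N[name] for name in labels]
colors = [ColorCycle[code] for code in codes]

print(codes)   # [1, 4, 3]
print(colors)  # ['pink', 'silver', 'red']
```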
155  utils/filter_func.py  Normal file
@ -0,0 +1,155 @@
from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage


@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    if _type == "lowpass":  # low-pass filtering
        sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "bandpass":  # band-pass filtering
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "highpass":  # high-pass filtering
        sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    else:  # a filter type must be given
        raise ValueError("Please choose a type of filter")


@timing_decorator()
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
    if _type == "lowpass":  # low-pass filtering
        b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    elif _type == "bandpass":  # band-pass filtering
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    elif _type == "highpass":  # high-pass filtering
        b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
        return signal.filtfilt(b, a, np.array(data))
    else:  # a filter type must be given
        raise ValueError("Please choose a type of filter")


@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """
    Efficient integer-factor downsampling of a long signal, processed in chunks
    to balance memory use and speed.

    Args:
        original_signal : array-like, the raw signal
        original_fs : float, original sampling rate (Hz)
        target_fs : float, target sampling rate (Hz)
        chunk_size : int, number of samples per chunk, default 100000

    Returns:
        downsampled_signal : array-like, the downsampled signal
    """
    # Input validation
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("The target sampling rate must be lower than the original sampling rate")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("Sampling rates must be positive")

    # Compute the downsampling factor (must be an integer)
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("The downsampling factor must be an integer")
    downsample_factor = int(downsample_factor)

    # Total output length
    total_length = len(original_signal)
    output_length = total_length // downsample_factor

    # Initialize the output array
    downsampled_signal = np.zeros(output_length)

    # Process chunk by chunk
    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        # Integer-factor downsampling with decimate
        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        # Output position
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            chunk_downsampled = chunk_downsampled[:output_length - out_start]

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal


def upsample_signal(original_signal, original_fs, target_fs):
    """
    Upsample a signal.

    Args:
        original_signal : array-like, the raw signal
        original_fs : float, original sampling rate (Hz)
        target_fs : float, target sampling rate (Hz)

    Returns:
        upsampled_signal : array-like, the upsampled signal
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if target_fs <= original_fs:
        raise ValueError("The target sampling rate must be higher than the original sampling rate")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("Sampling rates must be positive")

    upsample_factor = target_fs / original_fs
    num_output_samples = int(len(original_signal) * upsample_factor)

    upsampled_signal = signal.resample(original_signal, num_output_samples)

    return upsampled_signal


def adjust_sample_rate(signal_data, original_fs, target_fs):
    """
    Automatically choose up- or downsampling based on the original and target sampling rates.

    Args:
        signal_data : array-like, the raw signal
        original_fs : float, original sampling rate (Hz)
        target_fs : float, target sampling rate (Hz)

    Returns:
        adjusted_signal : array-like, the signal at the adjusted sampling rate
    """
    if original_fs == target_fs:
        return signal_data
    elif original_fs > target_fs:
        return downsample_signal_fast(signal_data, original_fs, target_fs)
    else:
        return upsample_signal(signal_data, original_fs, target_fs)


@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
    # Subtract a moving-average baseline from the raw signal
    kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
    filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    convolve_filter_signal = raw_data - filtered
    return convolve_filter_signal


# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    nyquist = 0.5 * sample_rate
    norm_notch_freq = notch_freq / nyquist
    b, a = signal.iirnotch(norm_notch_freq, quality_factor)
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data
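As a quick sanity check of the Nyquist normalization used by `butterworth`, the sketch below applies the same `signal.butter` / `sosfiltfilt` calls directly to a synthetic mixture of a 0.3 Hz respiration-like component and 20 Hz interference; the signal and cutoffs are made up for illustration:

```python
import numpy as np
from scipy import signal

fs = 100  # Hz
t = np.arange(0, 10, 1 / fs)
# 0.3 Hz "respiration" plus a 20 Hz interference component
x = np.sin(2 * np.pi * 0.3 * t) + np.sin(2 * np.pi * 20 * t)

# Band-pass 0.1-1.0 Hz, normalized by the Nyquist frequency exactly as in butterworth()
sos = signal.butter(4, [0.1 / (fs * 0.5), 1.0 / (fs * 0.5)], btype="bandpass", output="sos")
y = signal.sosfiltfilt(sos, x)

# Zero-phase filtering preserves the length and removes most of the 20 Hz power
print(y.shape == x.shape)  # True
```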
380  utils/operation_tools.py  Normal file
@ -0,0 +1,380 @@
import time
from functools import wraps
from pathlib import Path

import numpy as np
import pandas as pd
import yaml
from matplotlib import pyplot as plt
from scipy import ndimage, signal
from scipy.interpolate import PchipInterpolator

from utils.event_map import E2N

plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # display minus signs correctly


# Global configuration
class Config:
    time_verbose = False


def timing_decorator():
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            elapsed_time = end_time - start_time
            if Config.time_verbose:  # check the global configuration at run time
                print(f"Function '{func.__name__}' took {elapsed_time:.4f} s")
            return result
        return wrapper
    return decorator


@timing_decorator()
def read_auto(file_path):
    if file_path.suffix == '.txt':
        # Read the txt with pandas read_csv
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    elif file_path.suffix == '.npy':
        return np.load(str(file_path))
    elif file_path.suffix == '.base64':
        with open(file_path) as f:
            files = f.readlines()

        data = np.array(files, dtype=int)
        return data
    else:
        raise ValueError('Unsupported file type; a custom reader is required')
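A minimal sketch of how the `timing_decorator` / `Config.time_verbose` pattern above behaves; the decorated `add` function is hypothetical:

```python
import time
from functools import wraps

class Config:
    time_verbose = False  # flip to True to print timings at run time

def timing_decorator():
    def decorator(func):
        @wraps(func)  # keeps the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            start = time.time()
            result = func(*args, **kwargs)
            if Config.time_verbose:  # checked on every call, not at decoration time
                print(f"Function '{func.__name__}' took {time.time() - start:.4f} s")
            return result
        return wrapper
    return decorator

@timing_decorator()
def add(a, b):
    return a + b

print(add(2, 3))      # 5
print(add.__name__)   # add
```

Because `Config.time_verbose` is read inside the wrapper, timing output can be toggled globally after all functions have already been decorated.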
def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """
    Merge segments of a 0/1 state sequence that are separated by gaps shorter
    than a given duration.

    Args:
        state_sequence: numpy array, 0/1 state sequence
        time_points: numpy array, time point of each state sample
        max_gap_sec: float, maximum gap (in seconds) to merge

    Returns:
        merged_sequence: numpy array, the merged state sequence
    """
    if len(state_sequence) <= 1:
        return state_sequence

    merged_sequence = state_sequence.copy()

    # Find state transitions
    transitions = np.diff(np.concatenate([[0], merged_sequence, [0]]))
    # Start and end positions of state-1 segments
    state_starts = np.where(transitions == 1)[0]
    state_ends = np.where(transitions == -1)[0] - 1

    # Check each pair of consecutive state-1 segments
    for i in range(len(state_starts) - 1):
        if state_ends[i] < len(time_points) and state_starts[i + 1] < len(time_points):
            # Gap duration
            gap_duration = time_points[state_starts[i + 1]] - time_points[state_ends[i]]
            # Merge if the gap is below the threshold
            if gap_duration <= max_gap_sec:
                merged_sequence[state_ends[i]:state_starts[i + 1]] = 1

    return merged_sequence


def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """
    Remove segments of a 0/1 state sequence that last shorter than a given threshold.

    Args:
        state_sequence: numpy array, 0/1 state sequence
        time_points: numpy array, time point of each state sample
        min_duration_sec: float, minimum duration (in seconds) to keep

    Returns:
        filtered_sequence: numpy array, the filtered state sequence
    """
    if len(state_sequence) <= 1:
        return state_sequence

    filtered_sequence = state_sequence.copy()

    # Find state transitions
    transitions = np.diff(np.concatenate([[0], filtered_sequence, [0]]))
    # Start and end positions of state-1 segments
    state_starts = np.where(transitions == 1)[0]
    state_ends = np.where(transitions == -1)[0] - 1

    # Check the duration of each state-1 segment
    for i in range(len(state_starts)):
        if state_starts[i] < len(time_points) and state_ends[i] < len(time_points):
            # Segment duration
            duration = time_points[state_ends[i]] - time_points[state_starts[i]]
            if state_ends[i] == len(time_points) - 1:
                # For the last window, add one window length
                duration += time_points[1] - time_points[0] if len(time_points) > 1 else 0

            # Remove segments shorter than the threshold
            if duration < min_duration_sec:
                filtered_sequence[state_starts[i]:state_ends[i] + 1] = 0

    return filtered_sequence
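Both helpers above rely on the same padded-`np.diff` idiom to locate segment boundaries; a minimal sketch on a hand-made sequence:

```python
import numpy as np

seq = np.array([1, 1, 0, 0, 0, 1, 1])  # two state-1 segments
# Pad with zeros so segments touching the edges still produce transitions
transitions = np.diff(np.concatenate([[0], seq, [0]]))
state_starts = np.where(transitions == 1)[0]       # first index of each segment
state_ends = np.where(transitions == -1)[0] - 1    # last index of each segment

print(state_starts.tolist())  # [0, 5]
print(state_ends.tolist())    # [1, 6]
```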
@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    # Align the mask length with signal_data
    if step_second is None:
        step_second = window_second

    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate

    origin_seconds = len(signal_data) // sampling_rate

    # Reflect padding
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)

    for i in range(num_segments):
        start = int(i * step_length)
        end = start + window_length
        segment = data[start:end]
        values_nan[i] = func(segment)

    # Expand one value per window to one value per second
    values_nan = values_nan.repeat(step_second)[:origin_seconds]

    if calc_mask is not None:
        # Windows containing body movement are masked out
        for i in range(len(values_nan)):
            if calc_mask[i]:
                values_nan[i] = np.nan

        values = values_nan.copy()

        # Interpolate NaN values inside movement regions
        def interpolate_nans(x, t):
            valid_mask = ~np.isnan(x)
            return np.interp(t, t[valid_mask], x[valid_mask])

        values = interpolate_nans(values, np.arange(len(values)))
    else:
        values = values_nan.copy()

    return values_nan, values


def load_dataset_conf(yaml_path):
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    return config
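The per-window results in `calculate_by_slide_windows` are expanded to per-second values with `repeat` and then truncated to the original duration; a minimal sketch of that step with made-up numbers:

```python
import numpy as np

step_second = 5      # one window value covers 5 seconds
origin_seconds = 12  # signal length in seconds (not a multiple of the step)
window_values = np.array([1.0, 2.0, 3.0])  # one value per window

# Each window value is repeated step_second times, then cut to the signal length
per_second = window_values.repeat(step_second)[:origin_seconds]
print(per_second.tolist())
# [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0]
```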
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    disable_mask = np.zeros(signal_second, dtype=int)

    for _, row in disable_df.iterrows():
        start = row["start"]
        end = row["end"]
        disable_mask[start:end] = 1
    return disable_mask
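A minimal sketch of the mask-filling loop above, with a hand-made `disable_df` (start/end in seconds; note the end index is exclusive here, unlike in `generate_event_mask`):

```python
import numpy as np
import pandas as pd

disable_df = pd.DataFrame({"start": [2, 7], "end": [4, 9]})
signal_second = 10

disable_mask = np.zeros(signal_second, dtype=int)
for _, row in disable_df.iterrows():
    disable_mask[row["start"]:row["end"]] = 1  # end exclusive

print(disable_mask.tolist())  # [0, 0, 1, 1, 0, 0, 0, 1, 1, 0]
```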
def generate_event_mask(signal_second: int, event_df, use_correct=True, with_score=True):
    event_mask = np.zeros(signal_second, dtype=int)
    if with_score:
        score_mask = np.zeros(signal_second, dtype=int)
    else:
        score_mask = None
    if use_correct:
        start_name = "correct_Start"
        end_name = "correct_End"
        event_type_name = "correct_EventsType"
    else:
        start_name = "Start"
        end_name = "End"
        event_type_name = "Event type"

    # Drop rows with start == -1
    event_df = event_df[event_df[start_name] >= 0]

    for _, row in event_df.iterrows():
        start = row[start_name]
        end = row[end_name] + 1
        event_mask[start:end] = E2N[row[event_type_name]]
        if with_score:
            score_mask[start:end] = row["score"]
    return event_mask, score_mask
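A minimal sketch of the event-mask fill above, with the `E2N` mapping copied from `event_map.py` and a made-up one-row label table; the `+1` reflects that `End` is inclusive in the label file:

```python
import numpy as np
import pandas as pd

E2N = {"Hypopnea": 1, "Central apnea": 2, "Obstructive apnea": 3, "Mixed apnea": 4}
event_df = pd.DataFrame({"Start": [3], "End": [5], "Event type": ["Hypopnea"]})

event_mask = np.zeros(10, dtype=int)
for _, row in event_df.iterrows():
    # End is inclusive in the label file, hence the +1
    event_mask[row["Start"]:row["End"] + 1] = E2N[row["Event type"]]

print(event_mask.tolist())  # [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
```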
def event_mask_2_list(mask, event_true=True):
    if event_true:
        event_2_normal = -1
        normal_2_event = 1
        _append = 0
    else:
        event_2_normal = 1
        normal_2_event = -1
        _append = 1
    mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0]
    mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0]
    event_list = [[start, end] for start, end in zip(mask_start, mask_end)]
    return event_list
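With `event_true=True` the function emits half-open `[start, end)` intervals for the 1-runs of the mask; a minimal sketch of the same padded-`np.diff` trick:

```python
import numpy as np

mask = np.array([0, 1, 1, 0, 1])
# Pad with 0 on both sides so runs touching the edges are still detected
starts = np.where(np.diff(mask, prepend=0, append=0) == 1)[0]
ends = np.where(np.diff(mask, prepend=0, append=0) == -1)[0]
intervals = [[s, e] for s, e in zip(starts, ends)]

print(intervals)  # [[1, 3], [4, 5]] -- half-open [start, end) pairs
```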
def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list:
    """Collect non-NaN values until the requested count or an array boundary is reached."""
    values = []
    count = 0
    mask = mask if mask is not None else arr
    while count < limit and 0 <= index < len(mask):
        if not np.isnan(mask[index]):
            values.append(arr[index])
            count += 1
        index += step
    return values
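A self-contained copy of the NaN-skipping collection loop above, run on a toy array:

```python
import numpy as np

def collect_values(arr, index, step, limit, mask=None):
    """Collect non-NaN values until `limit` values are found or a boundary is hit."""
    values, count = [], 0
    mask = mask if mask is not None else arr
    while count < limit and 0 <= index < len(mask):
        if not np.isnan(mask[index]):
            values.append(arr[index])
            count += 1
        index += step  # step can be negative to walk backwards
    return values

result = collect_values(np.array([1.0, np.nan, 3.0, 4.0]), index=0, step=1, limit=2)
print(result)  # [1.0, 3.0] -- the NaN at index 1 is skipped, not counted
```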
def save_process_label(save_path: Path, save_dict: dict):
    save_df = pd.DataFrame(save_dict)
    save_df.to_csv(save_path, index=False)


def none_to_nan_mask(mask, ref):
    """Convert None into a NaN mask with the same shape as ref."""
    if mask is None:
        return np.full_like(ref, np.nan)
    else:
        # Replace zeros in the mask with NaN, keep everything else
        mask = np.where(mask == 0, np.nan, mask)
        return mask


def get_wake_mask(sleep_stage_mask):
    # N1, N2, N3 and REM count as sleep (0), everything else as wake (1)
    # The input is a string array with values such as 'W', 'N1', 'N2', 'N3', 'R'
    wake_mask = np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)
    return wake_mask
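A minimal sketch of `get_wake_mask` on a hand-made stage array:

```python
import numpy as np

sleep_stage_mask = np.array(["W", "N2", "R", "W"])
# N1/N2/N3/REM count as sleep (0), everything else as wake (1)
wake_mask = np.where(np.isin(sleep_stage_mask, ["N1", "N2", "N3", "REM", "R"]), 0, 1)

print(wake_mask.tolist())  # [1, 0, 0, 1]
```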
def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
    anomaly = np.zeros(len(spo2), dtype=bool)

    # Physiological range
    anomaly |= (spo2 < 50) | (spo2 > 100)

    # Sudden jumps
    diff = np.abs(np.diff(spo2, prepend=spo2[0]))
    anomaly |= diff > diff_thresh

    # NaN
    anomaly |= np.isnan(spo2)

    return anomaly


def merge_close_anomalies(anomaly, fs, min_gap_duration):
    min_gap = int(min_gap_duration * fs)
    merged = anomaly.copy()

    i = 0
    n = len(anomaly)

    while i < n:
        if not anomaly[i]:
            i += 1
            continue

        # Current anomalous segment
        start = i
        while i < n and anomaly[i]:
            i += 1
        end = i

        # Look ahead for the gap
        j = end
        while j < n and not anomaly[j]:
            j += 1

        if j < n and (j - end) < min_gap:
            merged[end:j] = True

    return merged


def fill_spo2_anomaly(
        spo2_data,
        spo2_fs,
        max_fill_duration,
        min_gap_duration,
):
    spo2 = spo2_data.astype(float).copy()
    n = len(spo2)

    anomaly = detect_spo2_anomaly(spo2, spo2_fs)
    anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)

    max_len = int(max_fill_duration * spo2_fs)

    valid_mask = ~anomaly

    i = 0
    while i < n:
        if not anomaly[i]:
            i += 1
            continue

        start = i
        while i < n and anomaly[i]:
            i += 1
        end = i

        seg_len = end - start

        # Overlong anomalous segment: cannot be filled reliably
        if seg_len > max_len:
            spo2[start:end] = np.nan
            valid_mask[start:end] = False
            continue

        has_left = start > 0 and valid_mask[start - 1]
        has_right = end < n and valid_mask[end]

        # Anomaly at the beginning: one-sided fill
        if not has_left and has_right:
            spo2[start:end] = spo2[end]
            continue

        # Anomaly at the end: one-sided fill
        if has_left and not has_right:
            spo2[start:end] = spo2[start - 1]
            continue

        # Valid samples on both sides -> PCHIP interpolation
        if has_left and has_right:
            x = np.array([start - 1, end])
            y = np.array([spo2[start - 1], spo2[end]])

            interp = PchipInterpolator(x, y)
            spo2[start:end] = interp(np.arange(start, end))
            continue

        # No valid sample on either side (extreme case)
        spo2[start:end] = np.nan
        valid_mask[start:end] = False

    return spo2, valid_mask
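The two-sided branch above interpolates with `PchipInterpolator` anchored on the last valid sample before the gap and the first valid sample after it; a minimal sketch with made-up SpO2 values:

```python
import numpy as np
from scipy.interpolate import PchipInterpolator

spo2 = np.array([97.0, 96.0, 0.0, 0.0, 95.0])  # two anomalous zeros at indices 2-3
start, end = 2, 4  # anomalous segment [start, end)

# Anchor on the neighbors of the gap, as in fill_spo2_anomaly
x = np.array([start - 1, end])
y = np.array([spo2[start - 1], spo2[end]])
spo2[start:end] = PchipInterpolator(x, y)(np.arange(start, end))

# PCHIP is shape-preserving, so the fill stays between the anchors (95 and 96)
print(np.all((spo2[2:4] >= 95.0) & (spo2[2:4] <= 96.0)))  # True
```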
108  utils/signal_process.py  Normal file
@ -0,0 +1,108 @@
from utils.operation_tools import timing_decorator
|
||||
import numpy as np
|
||||
from scipy import signal, ndimage
|
||||
|
||||
|
||||
@timing_decorator()
|
||||
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
|
||||
if _type == "lowpass": # 低通滤波处理
|
||||
sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
|
||||
return signal.sosfiltfilt(sos, np.array(data))
|
||||
elif _type == "bandpass": # 带通滤波处理
|
||||
low = low_cut / (sample_rate * 0.5)
|
||||
high = high_cut / (sample_rate * 0.5)
|
||||
sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
|
||||
return signal.sosfiltfilt(sos, np.array(data))
|
||||
elif _type == "highpass": # 高通滤波处理
|
||||
sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
|
||||
return signal.sosfiltfilt(sos, np.array(data))
|
||||
else: # 警告,滤波器类型必须有
|
||||
raise ValueError("Please choose a type of fliter")
|
||||
|
||||
|
||||
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
|
||||
if _type == "lowpass": # 低通滤波处理
|
||||
b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
|
||||
return signal.filtfilt(b, a, np.array(data))
|
||||
elif _type == "bandpass": # 带通滤波处理
|
||||
low = low_cut / (sample_rate * 0.5)
|
||||
high = high_cut / (sample_rate * 0.5)
|
||||
b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
|
||||
return signal.filtfilt(b, a, np.array(data))
|
||||
elif _type == "highpass": # 高通滤波处理
|
||||
b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
|
||||
return signal.filtfilt(b, a, np.array(data))
|
||||
else: # 警告,滤波器类型必须有
|
||||
raise ValueError("Please choose a type of fliter")
|
||||
|
||||
|
||||
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """
    Efficient integer-factor downsampling of a long signal, processed in chunks
    to keep memory usage and run time low.

    Parameters:
        original_signal : array-like, the raw signal
        original_fs : float, original sampling rate (Hz)
        target_fs : float, target sampling rate (Hz)
        chunk_size : int, samples processed per chunk, default 100000

    Returns:
        downsampled_signal : np.ndarray, the downsampled signal
    """
    # Input validation
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("The target sampling rate must be lower than the original rate")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("Sampling rates must be positive")

    # The downsampling factor must be an integer
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("The downsampling factor must be an integer")
    downsample_factor = int(downsample_factor)

    # Align the chunk size to the factor so every chunk yields exactly
    # len(chunk) // factor output samples and the chunks tile the output array
    chunk_size = max(downsample_factor, (chunk_size // downsample_factor) * downsample_factor)

    # Total output length
    total_length = len(original_signal)
    output_length = total_length // downsample_factor

    # Pre-allocate the output array
    downsampled_signal = np.zeros(output_length)

    # Process chunk by chunk
    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        # Integer-factor decimation (IIR anti-aliasing filter, zero-phase)
        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        # Write the chunk into its slot in the output array
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            chunk_downsampled = chunk_downsampled[:output_length - out_start]
            out_end = output_length

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal
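For reference, the chunked decimation loop above can be exercised on a synthetic signal. The following is a standalone sketch that re-implements the chunk loop with `scipy.signal.decimate` directly rather than importing the repo function; the signal and rates are hypothetical:

```python
import numpy as np
from scipy import signal

original_fs, target_fs = 1000, 100           # 10x downsampling
factor = int(original_fs / target_fs)
t = np.arange(0, 10, 1 / original_fs)        # 10 s of signal
x = np.sin(2 * np.pi * 2 * t)                # 2 Hz sine, well below the new Nyquist (50 Hz)

chunk_size = 100000 // factor * factor       # keep chunks aligned to the factor
out = np.zeros(len(x) // factor)
for start in range(0, len(x), chunk_size):
    chunk = x[start:start + chunk_size]
    y = signal.decimate(chunk, factor, ftype='iir', zero_phase=True)
    out[start // factor : start // factor + len(y)] = y

print(len(x), len(out))                      # 10000 samples in, 1000 samples out
```

Because the tone sits far below the anti-aliasing cutoff, its amplitude survives essentially unchanged.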
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
    # Estimate the baseline with a moving-average (boxcar) convolution and
    # subtract it from the raw signal to remove slow drift
    kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
    filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    convolve_filter_signal = raw_data - filtered
    return convolve_filter_signal
# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    # Remove narrow-band interference (e.g. mains hum) at notch_freq,
    # normalized to the Nyquist frequency
    nyquist = 0.5 * sample_rate
    norm_notch_freq = notch_freq / nyquist
    b, a = signal.iirnotch(norm_notch_freq, quality_factor)
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data
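As a quick sanity check of the notch idea, a standalone sketch (not the repo function, and with hypothetical signal parameters) shows a 50 Hz tone being removed while a 5 Hz respiratory-band component passes through:

```python
import numpy as np
from scipy import signal

fs = 1000
t = np.arange(0, 5, 1 / fs)
clean = np.sin(2 * np.pi * 5 * t)                  # respiratory-band component
mains = 0.5 * np.sin(2 * np.pi * 50 * t)           # 50 Hz interference

b, a = signal.iirnotch(50 / (fs / 2), Q=30.0)      # notch at 50 Hz, normalized to Nyquist
filtered = signal.filtfilt(b, a, clean + mains)

# Interior error vs. the clean signal, edges excluded to ignore filter transients
residual = filtered - clean
print(round(float(np.max(np.abs(residual[fs:-fs]))), 3))
```

With `Q=30` the notch bandwidth is about 1.7 Hz around 50 Hz, so the 5 Hz component is left essentially untouched while the interference is suppressed.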
56
utils/split_method.py
Normal file
@@ -0,0 +1,56 @@
def check_split(event_mask, current_start, window_sec, verbose=False):
    # Check whether the window is dominated by movement or low-amplitude regions
    resp_movement_mask = event_mask["Resp_Movement_Label"][current_start : current_start + window_sec]
    resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start : current_start + window_sec]

    # Union (OR) of the movement mask and the low-amplitude mask
    low_move_mask = resp_movement_mask | resp_low_amp_mask
    if low_move_mask.sum() > 2 / 3 * window_sec:
        if verbose:
            print(f"{current_start}-{current_start + window_sec} rejected: movement/low-amplitude mask covers more than 2/3 of the window")
        return False
    return True
def resp_split(dataset_config, event_mask, event_list, verbose=False):
    # Enabled segments and disabled (movement / low-amplitude) segments
    enable_list = event_list["EnableSegment"]
    disable_list = event_list["DisableSegment"]

    # Read windowing parameters from the dataset config
    window_sec = dataset_config["window_sec"]
    stride_sec = dataset_config["stride_sec"]

    segment_list = []
    disable_segment_list = []

    # Walk through each enabled segment. If the leftover tail is shorter than
    # half a stride it is dropped; otherwise one extra window ending exactly at
    # enable_end is cut.
    for enable_start, enable_end in enable_list:
        current_start = enable_start
        while current_start + window_sec <= enable_end:
            if check_split(event_mask, current_start, window_sec, verbose):
                segment_list.append((current_start, current_start + window_sec))
            else:
                disable_segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # Tail window: the leftover must be at least half a stride, and the
        # segment itself must be at least one window long
        if (enable_end - current_start >= stride_sec / 2) and (enable_end - enable_start >= window_sec):
            if check_split(event_mask, enable_end - window_sec, window_sec, verbose):
                segment_list.append((enable_end - window_sec, enable_end))
            else:
                disable_segment_list.append((enable_end - window_sec, enable_end))

    # Same sweep over each disabled segment, with the same tail rule; every
    # window goes to the disabled list
    for disable_start, disable_end in disable_list:
        current_start = disable_start
        while current_start + window_sec <= disable_end:
            disable_segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # Tail window, same rule as above
        if (disable_end - current_start >= stride_sec / 2) and (disable_end - disable_start >= window_sec):
            disable_segment_list.append((disable_end - window_sec, disable_end))

    return segment_list, disable_segment_list
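The windowing rule can be walked through in isolation. A minimal standalone sketch with hypothetical numbers (30 s windows, 10 s stride, one 95 s enabled segment):

```python
# Hypothetical parameters: 30 s windows with a 10 s stride over a 0-95 s segment
window_sec, stride_sec = 30, 10
enable_start, enable_end = 0, 95

windows = []
current_start = enable_start
while current_start + window_sec <= enable_end:
    windows.append((current_start, current_start + window_sec))
    current_start += stride_sec
# Tail rule: keep one window ending at enable_end when the leftover is at least
# half a stride and the segment itself is at least one window long
if (enable_end - current_start >= stride_sec / 2) and (enable_end - enable_start >= window_sec):
    windows.append((enable_end - window_sec, enable_end))

print(windows)
# [(0, 30), (10, 40), (20, 50), (30, 60), (40, 70), (50, 80), (60, 90), (65, 95)]
```

The regular stride produces seven windows; the 25 s leftover (≥ stride/2) triggers one extra window anchored to the segment end, which overlaps the previous one.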
105
utils/statistics_metrics.py
Normal file
@@ -0,0 +1,105 @@
from utils.operation_tools import timing_decorator
import numpy as np
import pandas as pd

@timing_decorator()
def statistic_amplitude_metrics(data, aml_interval=None, time_interval=None):
    """
    Bin the recording by amplitude range and segment duration, and summarize the
    proportions and durations as a confusion-style matrix.

    Parameters:
        data: 1-D sequence sampled at 1 Hz, with movement regions filled with np.nan
        aml_interval: amplitude bin edges, default [200, 500, 1000, 2000]
        time_interval: duration bin edges in seconds, default [60, 300, 1800, 3600]

    Returns:
        summary: aggregate statistics
        (confusion_matrix, segment_count_matrix, ...): amplitude-duration matrices
            and supporting values
    """
    if aml_interval is None:
        aml_interval = [200, 500, 1000, 2000]

    if time_interval is None:
        time_interval = [60, 300, 1800, 3600]
    # Input check
    if not isinstance(data, np.ndarray):
        data = np.array(data)

    # Total recording duration (NaN samples included)
    total_duration = len(data)

    # Build the amplitude and duration bin labels
    amp_labels = [f"0-{aml_interval[0]}"]
    for i in range(len(aml_interval) - 1):
        amp_labels.append(f"{aml_interval[i]}-{aml_interval[i + 1]}")
    amp_labels.append(f"{aml_interval[-1]}+")

    time_labels = [f"0-{time_interval[0]}"]
    for i in range(len(time_interval) - 1):
        time_labels.append(f"{time_interval[i]}-{time_interval[i + 1]}")
    time_labels.append(f"{time_interval[-1]}+")

    # Duration matrix and segment-count matrix
    result_matrix = np.zeros((len(amp_labels), len(time_labels)))        # durations
    segment_count_matrix = np.zeros((len(amp_labels), len(time_labels)))  # segment counts

    # Amount of valid signal (non-NaN samples)
    valid_signal_length = np.sum(~np.isnan(data))

    # Pad both ends with NaN so segments touching the boundaries are detected too
    signal_padded = np.concatenate(([np.nan], data, [np.nan]))
    diff = np.diff(np.isnan(signal_padded).astype(int))

    # Segment starts (NaN -> valid transitions)
    segment_starts = np.where(diff == -1)[0]
    # Segment ends (valid -> NaN transitions)
    segment_ends = np.where(diff == 1)[0]

    # Fill the matrices with each segment's duration and mean amplitude
    for start, end in zip(segment_starts, segment_ends):
        segment = data[start:end]
        duration = end - start  # duration in seconds
        mean_amplitude = np.nanmean(segment)  # segment mean amplitude

        # Pick the amplitude bin
        if mean_amplitude <= aml_interval[0]:
            amp_idx = 0
        elif mean_amplitude > aml_interval[-1]:
            amp_idx = len(aml_interval)
        else:
            amp_idx = np.searchsorted(aml_interval, mean_amplitude)

        # Pick the duration bin
        if duration <= time_interval[0]:
            time_idx = 0
        elif duration > time_interval[-1]:
            time_idx = len(time_interval)
        else:
            time_idx = np.searchsorted(time_interval, duration)

        # Accumulate this segment's duration and count at the matching cell
        result_matrix[amp_idx, time_idx] += duration
        segment_count_matrix[amp_idx, time_idx] += 1

    # DataFrame for display and downstream processing
    confusion_matrix = pd.DataFrame(result_matrix, index=amp_labels, columns=time_labels)

    # Row and column totals
    confusion_matrix['Total'] = confusion_matrix.sum(axis=1)
    row_totals = confusion_matrix['Total'].copy()

    # Percentages relative to the total recording duration
    confusion_matrix_percent = confusion_matrix.div(total_duration) * 100

    # Aggregate statistics
    summary = {
        'total_duration': total_duration,
        'total_valid_signal': valid_signal_length,
        'amplitude_distribution': row_totals.to_dict(),
        'amplitude_percent': row_totals.div(total_duration) * 100,
        'time_distribution': confusion_matrix.sum(axis=0).to_dict(),
        'time_percent': confusion_matrix.sum(axis=0).div(total_duration) * 100
    }

    return summary, (confusion_matrix, segment_count_matrix, confusion_matrix_percent, valid_signal_length,
                     total_duration, time_labels, amp_labels)
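The NaN-padding trick used above to find contiguous valid runs can be checked on a toy trace. This standalone sketch (hypothetical amplitude values) isolates just the segment-detection step:

```python
import numpy as np

# 1 Hz toy signal: two valid runs separated by a movement gap (NaN)
data = np.array([300., 320., np.nan, np.nan, 900., 950., 910.])

# Pad with NaN so runs touching either boundary are detected as well
padded = np.concatenate(([np.nan], data, [np.nan]))
diff = np.diff(np.isnan(padded).astype(int))
starts = np.where(diff == -1)[0]   # NaN -> valid transitions
ends = np.where(diff == 1)[0]      # valid -> NaN transitions

# (start, end, mean amplitude) per contiguous valid run
segments = [(int(s), int(e), float(np.nanmean(data[s:e]))) for s, e in zip(starts, ends)]
print(segments)  # [(0, 2, 310.0), (4, 7, 920.0)]
```

Because `diff` is taken on the padded array, the transition indices land directly in `data` coordinates, so no offset correction is needed.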