feat: Add utility functions for signal processing and event mapping

- Created a new module `utils/__init__.py` to consolidate utility imports.
- Added `event_map.py` for mapping apnea event types to numerical values and colors.
- Implemented various filtering functions in `filter_func.py`, including Butterworth, Bessel, downsampling, and notch filters.
- Developed `operation_tools.py` for dataset configuration loading, event mask generation, and signal processing utilities.
- Introduced `split_method.py` for segmenting data based on movement and amplitude criteria.
- Added `statistics_metrics.py` for calculating amplitude metrics and generating confusion matrices.
- Included a new Excel file for additional data storage.
This commit is contained in:
marques 2026-03-24 21:15:05 +08:00
commit 8ee5980906
35 changed files with 5251 additions and 0 deletions

AGENT_BACKGROUND_CN.md
# Project Background (Quick Onboarding for Agents)
## 1. Purpose of This Document
This document helps an agent new to the repository build a mental model of the project quickly. It focuses on answering:
- What does this repository do?
- What is the core processing pipeline?
- Where should I start reading the code?
- What do the input data and output artifacts look like?
- Which modules are usable, and which are placeholders or half-finished?
---
## 2. The Project in One Sentence
`DataPrepare` is a sleep-apnea data preprocessing repository: it turns raw BCG / PSG signals and apnea labels into trainable datasets.
The current main pipeline has two steps:
1. In `event_mask_process/`, convert raw annotations and rule-based detection results into per-second label masks.
2. In `dataset_builder/`, cut fixed-length windows based on those masks and save the signals and window indices as `npz` datasets.
---
## 3. Repository Overview
| Directory | Purpose | Notes |
| --- | --- | --- |
| `dataset_config/` | YAML configs per dataset | Paths, sampling rates, thresholds, windowing parameters |
| `event_mask_process/` | Generates per-second label masks | Main entry point #1 |
| `dataset_builder/` | Cuts signals into windows and saves datasets | Main entry point #2 |
| `signal_method/` | Signal preprocessing, rule-based detection, feature computation | Core of the rule logic |
| `utils/` | File reading, label conversion, windowing, filtering, general tools | Low-level support |
| `draw_tools/` | Whole-night plots, segment plots, statistics plots | Debugging and visualization |
| `dataset_tools/` | Helper scripts | Paired copying, SHHS annotation checks |
| `output/` | Sample intermediate results | The repo ships a few example processing results |
---
## 4. Main Pipelines
### 4.1 HYS (BCG pipeline)
Relevant files:
- `event_mask_process/HYS_process.py`
- `dataset_builder/HYS_dataset.py`
Processing steps:
1. Read the raw synchronized BCG signal from `root_path/OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt`.
2. Read the manually corrected apnea labels from `root_path/Label/<id>/SA Label_corrected.csv`.
3. Apply a 50 Hz notch filter to the raw BCG, then split it into:
   - a respiration component `resp_data`
   - a BCG component `bcg_data`
4. Run rule-based detection according to the config to produce per-second masks:
   - `Resp_LowAmp_Label`
   - `Resp_Movement_Label`
   - `Resp_AmpChange_Label`
   - `BCG_LowAmp_Label`
   - `BCG_Movement_Label`
   - `BCG_AmpChange_Label`
   - `Disable_Label` (from `排除区间.xlsx`, the exclusion-interval Excel file)
   - `SA_Label` / `SA_Score` (from the manual labels)
5. Save these results to `output/HYS/<id>/<id>_Processed_Labels.csv`.
6. During dataset building, read the per-second labels and cut windows according to `window_sec` and `stride_sec`.
7. Save the processed signals to `Signals/*.npz`, the window indices to `Segments_List/*.npz`, and a copy of the labels to `Labels/*.csv`.
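The rasterization in step 4 can be sketched as follows. This is a minimal stand-in for `generate_event_mask`, assuming a `(start_sec, end_sec, code)` event-tuple format rather than the repository's actual CSV schema:

```python
import numpy as np

# Hypothetical event table: (start_sec, end_sec, event_code), end exclusive.
# Codes follow utils/event_map.py: 1=Hypopnea, 2=Central, 3=Obstructive, 4=Mixed.
def events_to_mask(events, total_seconds):
    """Rasterize an event table into a per-second label array."""
    mask = np.zeros(total_seconds, dtype=np.int64)
    for start, end, code in events:
        mask[start:end] = code
    return mask

mask = events_to_mask([(5, 8, 1), (20, 24, 3)], total_seconds=30)
```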
### 4.2 HYS_PSG (PSG pipeline)
Relevant files:
- `event_mask_process/HYS_PSG_process.py`
- `dataset_builder/HYS_PSG_dataset.py`
Processing steps:
1. Read the PSG channels from `root_path/PSG_Aligned/<id>/`, including:
   - `Rpeak`
   - `ECG_Sync`
   - `Effort Tho`
   - `Effort Abd`
   - `Flow P`
   - `Flow T`
   - `SpO2`
   - `5_class`
2. Read `SA Label_Sync.csv` from the same directory.
3. In `HYS_PSG_process.py`, generate a simplified set of per-second labels:
   - `SA_Label`
   - `SA_Score`
   - `Disable_Label` (derived mainly from wake periods in the sleep staging)
   - the remaining `Resp_*` / `BCG_*` masks are currently all set to zero
4. In `HYS_PSG_dataset.py`, resample the effort belts, flows, SpO2, and RRI to a common rate and align their lengths.
5. Fill or blank out SpO2 anomalies, cut windows, and save the result as the PSG dataset.
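The resampling and alignment in step 4 can be sketched like this, using numpy-only linear interpolation as a stand-in (the repository's `utils.adjust_sample_rate` may use a different method):

```python
import numpy as np

def adjust_rate(x, fs, target_fs):
    """Resample x from fs to target_fs by linear interpolation
    (a simplified stand-in for utils.adjust_sample_rate)."""
    n_out = int(round(len(x) * target_fs / fs))
    t_in = np.arange(len(x)) / fs
    t_out = np.arange(n_out) / target_fs
    return np.interp(t_out, t_in, x)

def align_whole_seconds(channels, fs):
    """Truncate all channels to the shortest common length, in whole seconds."""
    n = min(len(c) for c in channels)
    n -= n % fs  # keep an integer number of seconds
    return [c[:n] for c in channels]

tho = adjust_rate(np.random.randn(32 * 60), 32, 100)  # 60 s of a 32 Hz effort belt
spo2 = adjust_rate(np.random.randn(59), 1, 100)       # 59 s of 1 Hz SpO2
tho, spo2 = align_whole_seconds([tho, spo2], fs=100)
```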
### 4.3 Shared Design
Whether HYS or HYS_PSG, the core design is the same:
1. First condense the whole-night recording into per-second masks.
2. Then extract training segments based on the masks and the windowing rules.
This means that future changes to the "usability decision" or the "windowing logic" should start from:
- `event_mask_process/`
- `utils/split_method.py`
---
## 5. Data and Directory Conventions
### 5.1 External Data Directories
The repository itself contains no raw data; the actual data directories are given as absolute paths in the YAML configs, for example:
- `/mnt/disk_wd/marques_dataset/DataCombine2023/HYS`
- `/mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1`
So on a new machine, the first step is usually not to change code but to update the absolute paths in `dataset_config/*.yaml`.
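A minimal sketch of the config-loading step; the key names below are taken from this document and should be verified against the real YAML files:

```python
from pathlib import Path
import os
import tempfile
import yaml

def load_dataset_conf(yaml_path):
    """Load a dataset YAML config (mirrors the role of utils.load_dataset_conf)."""
    with open(yaml_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

# Key names are assumptions based on this document, not the full real schema.
conf_text = """
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: ./output/HYS
select_ids: [220, 221]
"""
conf_path = os.path.join(tempfile.gettempdir(), "demo_HYS_config.yaml")
with open(conf_path, "w", encoding="utf-8") as f:
    f.write(conf_text)

conf = load_dataset_conf(conf_path)
root_path = Path(conf["root_path"])
```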
### 5.2 HYS Data Directory Conventions
The code assumes the external directory contains at least:
- `OrgBCG_Aligned/<id>/OrgBCG_Sync_<fs>.txt`
- `PSG_Aligned/<id>/...`
- `Label/<id>/SA Label_corrected.csv`
### 5.3 HYS_PSG Data Directory Conventions
The code assumes files inside `PSG_Aligned/<id>/` follow these naming patterns:
- `Rpeak*.txt`
- `ECG_Sync*.txt`
- `Effort Tho*.txt`
- `Effort Abd*.txt`
- `Flow P*.txt`
- `Flow T*.txt`
- `SpO2*.txt`
- `5_class*.txt`
- `SA Label_Sync.csv`
### 5.4 Intermediate Result Directory
`output/` inside the repo holds intermediate labels and plots, not the final training dataset.
The final dataset is normally written to the external directory pointed to by `dataset_save_path` in the YAML.
---
## 6. Core Entry Scripts
### 6.1 Main Entry Points
| Scenario | Script |
| --- | --- |
| HYS per-second label generation | `event_mask_process/HYS_process.py` |
| HYS dataset building | `dataset_builder/HYS_dataset.py` |
| HYS_PSG per-second label generation | `event_mask_process/HYS_PSG_process.py` |
| HYS_PSG dataset building | `dataset_builder/HYS_PSG_dataset.py` |
### 6.2 Helper Scripts
| Script | Purpose |
| --- | --- |
| `dataset_tools/resp_pair_copy.py` | Copies the raw HYS/ZD5Y BCG and PSG files into a paired dataset |
| `dataset_tools/shhs_annotations_check.py` | Tallies event-type combinations in the SHHS XML annotations |
| `event_mask_process/SHHS1_process.py` | Placeholder SHHS1 entry point; not yet implemented |
### 6.3 Common Traits of the Entry Scripts
- Most scripts have no command-line interface.
- Config file paths are usually hard-coded inside `if __name__ == '__main__':`.
- Execution depends on the absolute paths in the YAML configs.
So before "running a script", an agent usually needs to confirm:
1. The external data directory exists on the current machine.
2. The YAML paths match the current environment.
---
## 7. Key Modules
### 7.1 `utils/`
#### `utils/HYS_FileReader.py`
Handles all input file reading; one of the most important low-level modules:
- `read_signal_txt`: reads a single-channel txt and infers `fs` from the sampling rate in the file name
- `read_label_csv`: reads the manually corrected HYS labels
- `read_raw_psg_label`: reads the raw synchronized PSG labels
- `read_disable_excel`: reads `排除区间.xlsx`
- `read_mask_execl`: reads the processed per-second label CSV and builds event segment lists
- `read_psg_channel`: reads multi-channel data from a PSG folder by channel name
#### `utils/operation_tools.py`
Handles label conversion, segment extraction, and general processing:
- `load_dataset_conf`: reads the YAML config
- `generate_event_mask`: converts an event table into a per-second `SA_Label`
- `generate_disable_mask`: converts the exclusion intervals from Excel into a per-second `Disable_Label`
- `event_mask_2_list`: converts a 0/1 mask into a list of `[start, end]` pairs
- `merge_short_gaps` / `remove_short_durations`: duration-based post-processing of per-second masks
- `fill_spo2_anomaly`: repairs anomalous SpO2 segments
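The mask-to-segment helpers can be sketched as follows (simplified reimplementations for illustration, not the repository's exact code):

```python
import numpy as np

def event_mask_2_list(mask):
    """Convert a per-second 0/1 mask into [start, end) second pairs."""
    mask = np.asarray(mask).astype(np.int64)
    diff = np.diff(np.concatenate(([0], mask, [0])))
    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0]
    return [[int(s), int(e)] for s, e in zip(starts, ends)]

def merge_short_gaps(mask, max_gap):
    """Fill 0-gaps of at most max_gap seconds between 1-runs."""
    mask = np.asarray(mask).astype(np.int64).copy()
    for s, e in event_mask_2_list(1 - mask):
        # skip gaps touching the recording boundaries
        if 0 < s and e < len(mask) and (e - s) <= max_gap:
            mask[s:e] = 1
    return mask

m = np.array([0, 1, 1, 0, 0, 1, 1, 1, 0])
```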
#### `utils/split_method.py`
The actual windowing rules live here:
- By default, slides a window of `window_sec` with step `stride_sec`
- Only generates usable windows inside `EnableSegment`
- If `Resp_Movement_Label | Resp_LowAmp_Label` covers more than 2/3 of a window's duration, the window is moved to `disable_segment_list`
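The windowing rule above can be sketched like this (a simplified stand-in for the logic in `utils/split_method.py`; parameter names are assumptions):

```python
import numpy as np

def split_windows(enable_segments, window_sec, stride_sec, bad_mask, bad_ratio=2/3):
    """Slide fixed windows inside enable segments; windows whose per-second
    bad_mask coverage exceeds bad_ratio go to the disable list."""
    segment_list, disable_segment_list = [], []
    for seg_start, seg_end in enable_segments:
        for start in range(seg_start, seg_end - window_sec + 1, stride_sec):
            end = start + window_sec
            if bad_mask[start:end].sum() > window_sec * bad_ratio:
                disable_segment_list.append([start, end])
            else:
                segment_list.append([start, end])
    return segment_list, disable_segment_list

bad = np.zeros(60, dtype=int)
bad[10:30] = 1  # e.g. Resp_Movement_Label | Resp_LowAmp_Label
ok, skipped = split_windows([[0, 60]], window_sec=20, stride_sec=10, bad_mask=bad)
```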
#### `utils/filter_func.py`
Provides filtering and sampling-rate utilities:
- Butterworth / Bessel filters
- Notch filter
- Integer-factor downsampling
- Automatic up/down resampling
- Moving-average detrending
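A minimal sketch of these utilities using `scipy.signal` (cut-off values are illustrative; the repository's actual parameters live in the YAML configs):

```python
import numpy as np
from scipy.signal import butter, iirnotch, filtfilt, sosfiltfilt

def butter_bandpass(x, fs, low, high, order=4):
    """Zero-phase Butterworth band-pass (second-order sections for stability)."""
    sos = butter(order, [low, high], btype="bandpass", fs=fs, output="sos")
    return sosfiltfilt(sos, x)

def notch_50hz(x, fs, q=30.0):
    """Remove 50 Hz power-line interference."""
    b, a = iirnotch(50.0, q, fs=fs)
    return filtfilt(b, a, x)

fs = 200
t = np.arange(fs * 8) / fs
resp_wave = np.sin(2 * np.pi * 0.3 * t)            # a respiration-like 0.3 Hz wave
x = resp_wave + 0.5 * np.sin(2 * np.pi * 50 * t)   # plus 50 Hz interference
clean = notch_50hz(x, fs)
```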
### 7.2 `signal_method/`
#### `signal_method/signal_process.py`
Signal preprocessing:
- `signal_filter_split`: splits the raw OrgBCG signal into a respiration component and a BCG component
- `psg_effort_filter`: processes the PSG effort-belt / flow signals
- `rpeak2hr` / `rpeak2rri_interpolation`: derive HR / RRI from R peaks
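A simplified sketch of the R-peak-to-HR conversion, holding the rate constant between consecutive peaks (the repository's `rpeak2hr` may interpolate differently):

```python
import numpy as np

def rpeak_to_hr(rpeak_indices, n_samples, fs):
    """Instantaneous HR (bpm) from R-peak sample indices,
    held constant between consecutive peaks."""
    hr = np.zeros(n_samples)
    rr = np.diff(rpeak_indices) / fs  # RR intervals in seconds
    for (start, end), rr_sec in zip(zip(rpeak_indices[:-1], rpeak_indices[1:]), rr):
        hr[start:end] = 60.0 / rr_sec
    return hr

fs = 100
peaks = np.array([0, 80, 160, 260])  # RR intervals: 0.8 s, 0.8 s, 1.0 s
hr = rpeak_to_hr(peaks, n_samples=300, fs=fs)
```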
#### `signal_method/rule_base_event.py`
Main rule-based detection logic:
- `detect_movement`: detects body movement from sliding-window standard deviation and local amplitude comparison
- `movement_revise`: second-pass correction of the movement mask
- `detect_low_amplitude_signal`: detects low-amplitude segments
- `position_based_sleep_recognition_v2/v3`: marks posture / amplitude-change segments from amplitude changes around movements
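The movement-detection idea can be sketched as a per-second rolling-std threshold (illustrative only; the real `detect_movement` also compares local amplitudes and uses configured thresholds):

```python
import numpy as np

def detect_movement(x, fs, ratio=3.0):
    """Per-second movement mask: a second is flagged when its windowed
    standard deviation exceeds ratio x the night's median std."""
    n_sec = len(x) // fs
    stds = np.array([x[i * fs:(i + 1) * fs].std() for i in range(n_sec)])
    threshold = ratio * np.median(stds)
    return (stds > threshold).astype(np.int64)

rng = np.random.default_rng(0)
fs = 50
x = np.sin(2 * np.pi * 0.3 * np.arange(fs * 20) / fs)  # 20 s of quiet breathing
x[5 * fs:7 * fs] += 10 * rng.standard_normal(2 * fs)   # a 2 s movement burst
mask = detect_movement(x, fs)
```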
#### `signal_method/normalize_method.py`
`normalize_resp_signal_by_segment` applies segment-wise z-score normalization over "usable segments".
For HYS, normalization is usually done over the complement of `Resp_AmpChange_Label`, to reduce the effect of whole-night amplitude drift.
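Segment-wise normalization can be sketched as follows (a minimal stand-in for `normalize_resp_signal_by_segment`):

```python
import numpy as np

def normalize_by_segment(x, fs, segments):
    """Z-score each [start_sec, end_sec) segment of x independently."""
    out = x.astype(float).copy()
    for start_sec, end_sec in segments:
        seg = slice(start_sec * fs, end_sec * fs)
        mu, sigma = out[seg].mean(), out[seg].std()
        if sigma > 0:
            out[seg] = (out[seg] - mu) / sigma
    return out

fs = 10
# two segments with very different baselines and amplitudes
x = np.concatenate([5 + np.sin(np.arange(30)), 50 + 10 * np.sin(np.arange(30))])
y = normalize_by_segment(x, fs, [[0, 3], [3, 6]])
```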
### 7.3 `draw_tools/`
Mainly for manual quality checks:
- `draw_signal_with_mask`: plots the whole-night raw HYS signal plus the rule masks
- `draw_psg_signal`: plots the whole-night HYS_PSG signals
- `draw_psg_label` / `draw_psg_bcg_label`: exports per-window segment plots
### 7.4 `dataset_builder/`
Its core job is turning "whole-night recording + per-second masks" into training samples:
- Save the processed multi-channel signals as `npz`
- Save the window start/end lists as `npz`
- Copy the label CSVs into the dataset directory
---
## 8. Label, Channel, and Encoding Conventions
### 8.1 Apnea Event Encoding
Defined in `utils/event_map.py`:
| Event | Code |
| --- | --- |
| `Hypopnea` | 1 |
| `Central apnea` | 2 |
| `Obstructive apnea` | 3 |
| `Mixed apnea` | 4 |
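In Python form, this mapping might look like the following (codes from the table above; the repository's file also stores per-event colors, which are omitted here):

```python
# Sketch of the event mapping in utils/event_map.py (codes from this document).
EVENT_TO_CODE = {
    "Hypopnea": 1,
    "Central apnea": 2,
    "Obstructive apnea": 3,
    "Mixed apnea": 4,
}
# Reverse lookup for decoding per-second labels back into event names.
CODE_TO_EVENT = {v: k for k, v in EVENT_TO_CODE.items()}
```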
### 8.2 PSG Channel Number Mapping
Also defined in `utils/event_map.py`:
| Number | Channel |
| --- | --- |
| 1 | `Rpeak` |
| 2 | `ECG_Sync` |
| 3 | `Effort Tho` |
| 4 | `Effort Abd` |
| 5 | `Flow P` |
| 6 | `Flow T` |
| 7 | `SpO2` |
| 8 | `5_class` |
### 8.3 Sleep Stage Encoding
`5_class` is converted to integers on read:
| Stage | Code |
| --- | --- |
| `N3` | 1 |
| `N2` | 2 |
| `N1` | 3 |
| `R` | 4 |
| `W` | 5 |
### 8.4 Per-Second Label CSV Fields
The core columns of `Processed_Labels.csv` are:
- `Second`
- `SA_Label`
- `SA_Score`
- `Disable_Label`
- `Resp_LowAmp_Label`
- `Resp_Movement_Label`
- `Resp_AmpChange_Label`
- `BCG_LowAmp_Label`
- `BCG_Movement_Label`
- `BCG_AmpChange_Label`
Examples in the repo:
- `output/HYS/220/220_Processed_Labels.csv`
- `output/HYS_PSG/220/220_Processed_Labels.csv`
---
## 9. Main Config Files
### 9.1 HYS
`dataset_config/HYS_config.yaml` mainly controls:
- the sample ID list
- the raw data root directory
- the intermediate label save directory
- filtering and downsampling parameters for respiration / BCG
- thresholds for low-amplitude, movement, and amplitude-change detection
- dataset window length and stride
### 9.2 HYS_PSG
`dataset_config/HYS_PSG_config.yaml` mainly controls:
- the sample ID list
- the common target sampling rate `target_fs`
- effort-belt / flow filtering parameters
- SpO2 anomaly-filling parameters
- the dataset output path
### 9.3 Other Configs
- `dataset_config/ZD5Y_config.yaml`: another set of BCG rule configs
- `dataset_config/SHHS1_config.yaml`: reserved SHHS1 config
- `dataset_config/RESP_PAIR_HYS_config.yaml`: HYS paired raw-data copying
- `dataset_config/RESP_PAIR_ZD5Y_config.yaml`: ZD5Y paired raw-data copying
---
## 10. Output Artifacts
### 10.1 Intermediate Outputs
Typical locations:
- `output/HYS/<id>/<id>_Processed_Labels.csv`
- `output/HYS/<id>/<id>_Signal_Plots.png`
- `output/HYS_PSG/<id>/<id>_Processed_Labels.csv`
- `output/HYS_PSG/<id>/<id>_Signal_Plots.png`
- `output/HYS_PSG/<id>/<id>_Signal_Plots_fill.png`
### 10.2 Final Dataset Outputs
`dataset_builder/*` writes to the external directory given in the YAML, typically structured as:
- `Signals/`
- `Segments_List/`
- `Labels/`
The HYS `Signals/*.npz` files mainly contain:
- `bcg_signal_notch`
- `bcg_signal`
- `resp_signal`
The HYS_PSG `Signals/*.npz` files mainly contain:
- `Effort Tho`
- `Effort Abd`
- `Effort`
- `Flow P`
- `Flow T`
- `SpO2`
- `HR`
- `RRI`
- `5_class`
`Segments_List/*.npz` mainly contains:
- `segment_list`
- `disable_segment_list`
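Reading these files back might look like this (a toy round-trip written to a temp file; real files come from `dataset_builder/*`):

```python
import os
import tempfile
import numpy as np

# Write a toy file with the same keys to demonstrate the read side.
path = os.path.join(tempfile.gettempdir(), "demo_Segment_List.npz")
np.savez_compressed(path,
                    segment_list=np.array([[0, 30], [30, 60]]),
                    disable_segment_list=np.array([[60, 90]]))

with np.load(path) as f:
    segment_list = f["segment_list"]
    disable_segment_list = f["disable_segment_list"]
```

Note that the builders write `Signals/*.npz` via `np.savez_compressed(path, **signal_dict)` where each value is itself a dict, so reading those back likely requires `np.load(path, allow_pickle=True)` and calling `.item()` on each entry; verify against the builder scripts.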
---
## 11. Recommended Reading Order
For a first pass through the repo, read in this order:
1. `dataset_config/HYS_config.yaml`
2. `event_mask_process/HYS_process.py`
3. `utils/operation_tools.py`
4. `utils/split_method.py`
5. `dataset_builder/HYS_dataset.py`
6. `signal_method/signal_process.py`
7. `signal_method/rule_base_event.py`
8. `utils/HYS_FileReader.py`
For the PSG pipeline, continue with:
1. `dataset_config/HYS_PSG_config.yaml`
2. `event_mask_process/HYS_PSG_process.py`
3. `dataset_builder/HYS_PSG_dataset.py`
4. `draw_tools/draw_label.py`
---
## 12. Common Modification Entry Points
Depending on what you want to change, start from these files:
| Need | Where to look first |
| --- | --- |
| Adjust thresholds, sampling rates, window length | `dataset_config/*.yaml` |
| Change per-second label generation | `event_mask_process/*.py` |
| Change movement / low-amplitude / amplitude-change rules | `signal_method/rule_base_event.py` |
| Change filtering and resampling | `signal_method/signal_process.py`, `utils/filter_func.py` |
| Change windowing rules | `utils/split_method.py` |
| Change input file parsing | `utils/HYS_FileReader.py` |
| Change event encoding or channel mapping | `utils/event_map.py` |
| Change plot styles | `draw_tools/*.py` |
---
## 13. Current Implementation Status and Caveats
These points matter for agents:
1. There is no dependency manifest (`requirements.txt` / `pyproject.toml`); dependencies must be inferred from imports, currently at least `numpy`, `pandas`, `scipy`, `matplotlib`, `seaborn`, `yaml`, `tqdm`, `rich`, `polars`, `lxml`, `mne`.
2. Most scripts rely on absolute paths; when migrating environments, update the YAML first.
3. `event_mask_process/SHHS1_process.py` is essentially empty; more placeholder than implementation.
4. `event_mask_process/HYS_PSG_process.py` is a "simplified label generator": it only uses the SA labels and sleep staging, and the `Resp_*` / `BCG_*` masks are not actually implemented yet.
5. `dataset_builder/HYS_dataset.py` keeps the segment-visualization entry point, but the actual plot calls are commented out, so segment plots are not exported by default.
6. `utils/signal_process.py` and `utils/filter_func.py` contain partially duplicated implementations; the versions in `filter_func.py` are the ones exported by `utils/__init__.py` and widely used.
7. `README.md` is currently very brief; the real project logic has to be understood from the source.
---
## 14. Key Takeaways for Agents
If you remember only a few things, remember these:
1. This is a data-preparation repo that "first builds per-second masks, then cuts fixed-length windows".
2. The HYS pipeline is more complete than HYS_PSG; the rule-based detection mainly serves HYS.
3. The windowing logic is centralized in `utils/split_method.py`, not scattered across the builders.
4. The raw data is not in the repo; the repo holds only code and a few intermediate result samples.
5. Most runtime failures are not logic bugs but mismatches in paths, file naming, sampling rates, or the external data layout.

README.md
# DataPrepare
`DataPrepare` is a preprocessing repository for sleep-apnea data: it turns raw BCG / PSG signals and event labels into trainable datasets.
The current main pipeline is:
1. Generate per-second label masks
2. Cut the data into fixed-length windows
3. Export visualizations for manual inspection
See [AGENT_BACKGROUND_CN.md](./AGENT_BACKGROUND_CN.md) for a more detailed project background.
## What the Project Does
The repository serves two kinds of data:
- `HYS`: centered on `OrgBCG`; combines manually corrected apnea labels, extracts the respiration and BCG components, and generates per-second usability masks
- `HYS_PSG`: centered on multi-channel PSG signals; organizes the effort belts, flows, SpO2, RRI, sleep staging, and synchronized apnea labels
Both pipelines share one design:
1. First condense the whole-night recording into per-second labels
2. Then cut fixed-length windows based on those labels
## Repository Structure
| Directory | Purpose |
| --- | --- |
| `dataset_config/` | Dataset configs: paths, sampling rates, thresholds, windowing parameters |
| `event_mask_process/` | Per-second label mask generation scripts |
| `dataset_builder/` | Dataset slicing and saving scripts |
| `signal_method/` | Signal preprocessing, rule-based detection, feature computation |
| `utils/` | General tools: file reading, label conversion, windowing, filtering |
| `draw_tools/` | Whole-night, segment, and statistics plotting |
| `dataset_tools/` | Helper scripts, e.g. paired data copying, SHHS annotation checks |
| `output/` | Sample intermediate results shipped with the repo |
## Core Pipelines
### HYS
Scripts:
- `event_mask_process/HYS_process.py`
- `dataset_builder/HYS_dataset.py`
Outline:
1. Read `OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt`
2. Read `Label/<id>/SA Label_corrected.csv`
3. Notch-filter the raw signal and extract the respiration and BCG components
4. Detect low amplitude, movement, and amplitude changes, and combine them with `排除区间.xlsx` to generate per-second labels
5. Save `Processed_Labels.csv`
6. Cut training windows according to `window_sec` and `stride_sec` and save them as `npz`
### HYS_PSG
Scripts:
- `event_mask_process/HYS_PSG_process.py`
- `dataset_builder/HYS_PSG_dataset.py`
Outline:
1. Read the multi-channel signals under `PSG_Aligned/<id>/`
2. Read `SA Label_Sync.csv`
3. Generate per-second labels from the synchronized labels and sleep staging
4. Resample the effort belts, flows, SpO2, and RRI to a common rate and align lengths
5. Cut windows and save the PSG dataset
## Quick Start
### 1. Check the Configs First
All main scripts depend on absolute paths in `dataset_config/*.yaml`.
On a new machine or with a new data directory, update these config files first:
- `dataset_config/HYS_config.yaml`
- `dataset_config/HYS_PSG_config.yaml`
- `dataset_config/ZD5Y_config.yaml`
- `dataset_config/SHHS1_config.yaml`
Pay particular attention to:
- `root_path`
- `mask_save_path` / `save_path`
- `dataset_save_path`
- `dataset_visual_path`
- `select_ids`
### 2. Generate Per-Second Labels
HYS:
```bash
python event_mask_process/HYS_process.py
```
HYS_PSG:
```bash
python event_mask_process/HYS_PSG_process.py
```
After running, `output/` usually contains:
- `*_Processed_Labels.csv`
- `*_Signal_Plots.png`
### 3. Build the Datasets
HYS:
```bash
python dataset_builder/HYS_dataset.py
```
HYS_PSG:
```bash
python dataset_builder/HYS_PSG_dataset.py
```
The output directory is set by `dataset_save_path` in the corresponding YAML and usually contains:
- `Signals/`
- `Segments_List/`
- `Labels/`
### 4. Visualization and Helper Tools
Copy paired raw data:
```bash
python dataset_tools/resp_pair_copy.py
```
Check the SHHS XML annotations:
```bash
python dataset_tools/shhs_annotations_check.py
```
## Input and Output Conventions
### Input
The repository contains no raw data; the raw data directory is given in the YAML. The code assumes the external data directory contains at least:
- `OrgBCG_Aligned/<id>/OrgBCG_Sync_*.txt`
- `PSG_Aligned/<id>/...`
- `Label/<id>/SA Label_corrected.csv`
- `PSG_Aligned/<id>/SA Label_Sync.csv`
### Intermediate Output
`output/` inside the repo holds sample intermediate labels and plots, not the final training dataset. For example:
- `output/HYS/<id>/<id>_Processed_Labels.csv`
- `output/HYS_PSG/<id>/<id>_Processed_Labels.csv`
### Final Output
`dataset_builder/` writes to the external directory configured in the YAML. Typical structure:
- `Signals/*.npz`
- `Segments_List/*.npz`
- `Labels/*.csv`
## Key Files
For a first read of the repo, start with:
1. `dataset_config/HYS_config.yaml`
2. `event_mask_process/HYS_process.py`
3. `utils/operation_tools.py`
4. `utils/split_method.py`
5. `dataset_builder/HYS_dataset.py`
6. `signal_method/rule_base_event.py`
## Current Status and Caveats
- There is no dependency manifest; common dependencies include `numpy`, `pandas`, `scipy`, `matplotlib`, `seaborn`, `yaml`, `tqdm`, `rich`, `polars`, `lxml`, `mne`
- Most scripts have no command-line interface; config paths are hard-coded in each script's `__main__`
- `event_mask_process/SHHS1_process.py` is still mostly a placeholder
- `event_mask_process/HYS_PSG_process.py` is a simplified implementation; the `Resp_*` / `BCG_*` masks have not been fleshed out yet
- Files under `output/` are best used to understand formats and results; they are not a complete dataset
## Related Documentation
- Detailed project background: [AGENT_BACKGROUND_CN.md](./AGENT_BACKGROUND_CN.md)

import multiprocessing
import sys
from pathlib import Path
import os
import numpy as np
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
from utils import N2Chn
import signal_method
import draw_tools
import shutil
import gc
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
total_seconds = min(
psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
)
for i in N2Chn.values():
if i == "Rpeak":
continue
length = int(total_seconds * psg_data[i]["fs"])
psg_data[i]["data"] = psg_data[i]["data"][:length]
psg_data[i]["length"] = length
psg_data[i]["second"] = total_seconds
psg_data["HR"] = {
"name": "HR",
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
psg_data["Rpeak"]["fs"]),
"fs": psg_data["ECG_Sync"]["fs"],
"length": psg_data["ECG_Sync"]["length"],
"second": psg_data["ECG_Sync"]["second"]
}
# Preprocessing and filtering
tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])
rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
if verbose:
print(f"mask_excel_path: {mask_excel_path}")
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
enable_list = [[0, psg_data["Effort Tho"]["second"]]]
normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)
# Resample all channels to the common target rate (e.g. 100 Hz)
target_fs = conf["target_fs"]
normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)
# Truncate all channels to the same length
min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal), len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal), len(rri))
min_length = min_length - min_length % target_fs  # keep a whole number of seconds
normalized_tho_signal = normalized_tho_signal[:min_length]
normalized_abd_signal = normalized_abd_signal[:min_length]
normalized_flowp_signal = normalized_flowp_signal[:min_length]
normalized_flowt_signal = normalized_flowt_signal[:min_length]
spo2_data_filt = spo2_data_filt[:min_length]
normalized_effort_signal = normalized_effort_signal[:min_length]
rri = rri[:min_length]
tho_second = min_length / target_fs
for i in event_mask.keys():
event_mask[i] = event_mask[i][:int(tho_second)]
spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
spo2_fs=target_fs,
max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"],
min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"])
draw_tools.draw_psg_signal(
samp_id=samp_id,
tho_signal=normalized_tho_signal,
abd_signal=normalized_abd_signal,
flowp_signal=normalized_flowp_signal,
flowt_signal=normalized_flowt_signal,
spo2_signal=spo2_data_filt,
effort_signal=normalized_effort_signal,
rri_signal = rri,
fs=target_fs,
event_mask=event_mask["SA_Label"],
save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
show=show
)
draw_tools.draw_psg_signal(
samp_id=samp_id,
tho_signal=normalized_tho_signal,
abd_signal=normalized_abd_signal,
flowp_signal=normalized_flowp_signal,
flowt_signal=normalized_flowt_signal,
spo2_signal=spo2_data_filt_fill,
effort_signal=normalized_effort_signal,
rri_signal = rri,
fs=target_fs,
event_mask=event_mask["SA_Label"],
save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
show=show
)
spo2_disable_mask = spo2_disable_mask[::target_fs]
min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
print(f"Warning: Data length mismatch! Truncating to {min_len}.")
event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]
event_list = {
"EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
"DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=conf["spo2_fill__anomaly"]["nan_to_num_value"])
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
if verbose:
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
# Copy the mask CSV into the Labels folder
save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
shutil.copyfile(mask_excel_path, save_mask_excel_path)
# Copy the SA label CSV into the Labels folder
sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
if sa_label_corrected_path.exists():
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
else:
if verbose:
print(f"Warning: {sa_label_corrected_path} does not exist.")
# Save the processed signals and the extracted segment lists
save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
# Rebuild psg_data with the processed signals
# The "name" fields replace spaces in the keys with underscores
psg_data = {
"Effort Tho": {
"name": "Effort_Tho",
"data": normalized_tho_signal,
"fs": target_fs,
"length": len(normalized_tho_signal),
"second": len(normalized_tho_signal) / target_fs
},
"Effort Abd": {
"name": "Effort_Abd",
"data": normalized_abd_signal,
"fs": target_fs,
"length": len(normalized_abd_signal),
"second": len(normalized_abd_signal) / target_fs
},
"Effort": {
"name": "Effort",
"data": normalized_effort_signal,
"fs": target_fs,
"length": len(normalized_effort_signal),
"second": len(normalized_effort_signal) / target_fs
},
"Flow P": {
"name": "Flow_P",
"data": normalized_flowp_signal,
"fs": target_fs,
"length": len(normalized_flowp_signal),
"second": len(normalized_flowp_signal) / target_fs
},
"Flow T": {
"name": "Flow_T",
"data": normalized_flowt_signal,
"fs": target_fs,
"length": len(normalized_flowt_signal),
"second": len(normalized_flowt_signal) / target_fs
},
"SpO2": {
"name": "SpO2",
"data": spo2_data_filt_fill,
"fs": target_fs,
"length": len(spo2_data_filt_fill),
"second": len(spo2_data_filt_fill) / target_fs
},
"HR": {
"name": "HR",
"data": psg_data["HR"]["data"],
"fs": psg_data["HR"]["fs"],
"length": psg_data["HR"]["length"],
"second": psg_data["HR"]["second"]
},
"RRI": {
"name": "RRI",
"data": rri,
"fs": target_fs,
"length": len(rri),
"second": len(rri) / target_fs
},
"5_class": {
"name": "Stage",
"data": psg_data["5_class"]["data"],
"fs": psg_data["5_class"]["fs"],
"length": psg_data["5_class"]["length"],
"second": psg_data["5_class"]["second"]
}
}
np.savez_compressed(save_signal_path, **psg_data)
np.savez_compressed(save_segment_path,
segment_list=segment_list,
disable_segment_list=disable_segment_list)
if verbose:
print(f"Saved processed signals to: {save_signal_path}")
print(f"Saved segment list to: {save_segment_path}")
if draw_segment:
total_len = len(segment_list) + len(disable_segment_list)
if verbose:
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
draw_tools.draw_psg_label(
psg_data=psg_data,
psg_label=event_mask["SA_Label"],
segment_list=segment_list,
save_path=visual_path / f"{samp_id}" / "enable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
draw_tools.draw_psg_label(
psg_data=psg_data,
psg_label=event_mask["SA_Label"],
segment_list=disable_segment_list,
save_path=visual_path / f"{samp_id}" / "disable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
# Explicitly delete large objects
try:
del psg_data
del normalized_tho_signal, normalized_abd_signal
del normalized_flowp_signal, normalized_flowt_signal
del normalized_effort_signal
del spo2_data_filt, spo2_data_filt_fill
del rri
del event_mask, event_list
del segment_list, disable_segment_list
except Exception:
pass
# Force garbage collection
gc.collect()
def multiprocess_entry(_progress, task_id, _id):
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def multiprocess_with_tqdm(args_list, n_processes):
from concurrent.futures import ProcessPoolExecutor
from rich import progress
with progress.Progress(
"[progress.description]{task.description}",
progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
progress.MofNCompleteColumn(),
progress.TimeRemainingColumn(),
progress.TimeElapsedColumn(),
refresh_per_second=1, # bit slower updates
transient=False
) as progress:
futures = []
with multiprocessing.Manager() as manager:
_progress = manager.dict()
overall_progress_task = progress.add_task("[green]All jobs progress:")
with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
for i_args in range(len(args_list)):
args = args_list[i_args]
task_id = progress.add_task(f"task {i_args}", visible=True)
futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
# monitor the progress:
while (n_finished := sum([future.done() for future in futures])) < len(
futures
):
progress.update(
overall_progress_task, completed=n_finished, total=len(futures)
)
for task_id, update_data in _progress.items():
desc = update_data.get("desc", "")
# update the progress bar for this task:
progress.update(
task_id,
completed=update_data.get("progress", 0),
total=update_data.get("total", 0),
description=desc
)
# raise any errors:
for future in futures:
future.result()
def multiprocess_with_pool(args_list, n_processes):
"""Use a Pool; each worker restarts after a fixed number of tasks."""
from multiprocessing import Pool
# maxtasksperchild: restart each worker after this many tasks to release memory
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
results = []
for samp_id in args_list:
result = pool.apply_async(
build_HYS_dataset_segment,
args=(samp_id, False, True, False, None, None)
)
results.append(result)
# Wait for all tasks to finish
for i, result in enumerate(results):
try:
result.get()
print(f"Completed {i + 1}/{len(args_list)}")
except Exception as e:
print(f"Error processing task {i}: {e}")
pool.close()
pool.join()
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
save_segment_list_path = save_path / "Segments_List"
save_segment_list_path.mkdir(parents=True, exist_ok=True)
save_processed_label_path = save_path / "Labels"
save_processed_label_path.mkdir(parents=True, exist_ok=True)
# print(f"select_ids: {select_ids}")
# print(f"root_path: {root_path}")
# print(f"save_path: {save_path}")
# print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
print(select_ids)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
multiprocess_with_pool(args_list=select_ids, n_processes=8)

import multiprocessing
import sys
from pathlib import Path
import os
import numpy as np
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import signal_method
import draw_tools
import shutil
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
if not signal_path:
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
signal_path = signal_path[0]
if verbose:
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
if verbose:
print(f"mask_excel_path: {mask_excel_path}")
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
# Downsample if the raw sampling rate is too high
if signal_fs == 1000:
bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs,
target_fs=100)
signal_fs = 100
if bcg_fs == 1000:
bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
bcg_fs = 100
# draw_tools.draw_signal_with_mask(samp_id=samp_id,
# signal_data=resp_signal,
# signal_fs=resp_fs,
# resp_data=normalized_resp_signal,
# resp_fs=resp_fs,
# bcg_data=bcg_signal,
# bcg_fs=bcg_fs,
# signal_disable_mask=event_mask["Disable_Label"],
# resp_low_amp_mask=event_mask["Resp_LowAmp_Label"],
# resp_movement_mask=event_mask["Resp_Movement_Label"],
# resp_change_mask=event_mask["Resp_AmpChange_Label"],
# resp_sa_mask=event_mask["SA_Label"],
# bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"],
# bcg_movement_mask=event_mask["BCG_Movement_Label"],
# bcg_change_mask=event_mask["BCG_AmpChange_Label"],
# show=show,
# save_path=None)
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
if verbose:
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
# Copy the mask CSV into the Labels folder
save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
shutil.copyfile(mask_excel_path, save_mask_excel_path)
# Copy SA Label_corrected.csv into the Labels folder
sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
if sa_label_corrected_path.exists():
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
else:
if verbose:
print(f"Warning: {sa_label_corrected_path} does not exist.")
# Save the processed signals and the extracted segment lists
save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
bcg_data = {
"bcg_signal_notch": {
"name": "BCG_Signal_Notch",
"data": bcg_signal_notch,
"fs": signal_fs,
"length": len(bcg_signal_notch),
"second": len(bcg_signal_notch) // signal_fs
},
"bcg_signal":{
"name": "BCG_Signal_Raw",
"data": bcg_signal,
"fs": bcg_fs,
"length": len(bcg_signal),
"second": len(bcg_signal) // bcg_fs
},
"resp_signal": {
"name": "Resp_Signal",
"data": normalized_resp_signal,
"fs": resp_fs,
"length": len(normalized_resp_signal),
"second": len(normalized_resp_signal) // resp_fs
}
}
np.savez_compressed(save_signal_path, **bcg_data)
np.savez_compressed(save_segment_path,
segment_list=segment_list,
disable_segment_list=disable_segment_list)
if verbose:
print(f"Saved processed signals to: {save_signal_path}")
print(f"Saved segment list to: {save_segment_path}")
if draw_segment:
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
psg_data["HR"] = {
"name": "HR",
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]),
"fs": psg_data["ECG_Sync"]["fs"],
"length": psg_data["ECG_Sync"]["length"],
"second": psg_data["ECG_Sync"]["second"]
}
psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
total_len = len(segment_list) + len(disable_segment_list)
if verbose:
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
# draw_tools.draw_psg_bcg_label(psg_data=psg_data,
# psg_label=psg_event_mask,
# bcg_data=bcg_data,
# event_mask=event_mask,
# segment_list=segment_list,
# save_path=visual_path / f"{samp_id}" / "enable",
# verbose=verbose,
# multi_p=multi_p,
# multi_task_id=multi_task_id
# )
#
# draw_tools.draw_psg_bcg_label(
# psg_data=psg_data,
# psg_label=psg_event_mask,
# bcg_data=bcg_data,
# event_mask=event_mask,
# segment_list=disable_segment_list,
# save_path=visual_path / f"{samp_id}" / "disable",
# verbose=verbose,
# multi_p=multi_p,
# multi_task_id=multi_task_id
# )
def multiprocess_entry(_progress, task_id, _id):
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def multiprocess_with_pool(args_list, n_processes):
"""Process tasks with a Pool; each worker restarts after a fixed number of tasks."""
from multiprocessing import Pool
# maxtasksperchild: restart each worker after this many tasks to release memory
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
results = []
for samp_id in args_list:
result = pool.apply_async(
build_HYS_dataset_segment,
args=(samp_id, False, True, False, None, None)
)
results.append(result)
# Wait for all tasks to complete
for i, result in enumerate(results):
try:
result.get()
print(f"Completed {i + 1}/{len(args_list)}")
except Exception as e:
print(f"Error processing task {i}: {e}")
pool.close()
pool.join()
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
save_segment_list_path = save_path / "Segments_List"
save_segment_list_path.mkdir(parents=True, exist_ok=True)
save_processed_label_path = save_path / "Labels"
save_processed_label_path.mkdir(parents=True, exist_ok=True)
# print(f"select_ids: {select_ids}")
# print(f"root_path: {root_path}")
# print(f"save_path: {save_path}")
# print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
print(select_ids)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
multiprocess_with_pool(args_list=select_ids, n_processes=16)
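For downstream use, the `npz` files written by `build_HYS_dataset_segment` can be read back with `numpy.load`; because each signal entry is a dict, it comes back as a 0-d object array that `.item()` restores. A minimal round-trip sketch with dummy data (the real files additionally contain `bcg_signal_notch` and `bcg_signal` entries, and the segment file holds `segment_list` / `disable_segment_list` arrays):

```python
import numpy as np

# Dummy stand-in for one entry of the bcg_data dict saved above.
fs = 100
sig = np.zeros(fs * 10)
bcg_data = {
    "resp_signal": {"name": "Resp_Signal", "data": sig, "fs": fs,
                    "length": len(sig), "second": len(sig) // fs},
}
np.savez_compressed("demo_signals.npz", **bcg_data)

# Dict values come back as 0-d object arrays; .item() restores the dict.
loaded = np.load("demo_signals.npz", allow_pickle=True)
resp = loaded["resp_signal"].item()
print(resp["name"], resp["second"])  # Resp_Signal 10
```

Note that `allow_pickle=True` is required because the saved values are Python dicts, not plain arrays.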


@ -0,0 +1,139 @@
select_ids:
- 54
- 88
- 220
- 221
- 229
- 282
- 286
- 541
- 579
- 582
- 670
- 671
- 683
- 684
- 735
- 933
- 935
- 950
- 952
- 960
- 962
- 967
- 1302
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
target_fs: 100
dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization
effort:
downsample_fs: 10
effort_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
average_filter:
window_size_sec: 20
flow:
downsample_fs: 10
flow_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
spo2_fill__anomaly:
max_fill_duration: 30
min_gap_duration: 10
nan_to_num_value: 95
#ecg:
# downsample_fs: 100
#
#ecg_filter:
# filter_type: bandpass
# low_cut: 0.5
# high_cut: 40
# order: 5
#resp:
# downsample_fs_1: None
# downsample_fs_2: 10
#
#resp_filter:
# filter_type: bandpass
# low_cut: 0.05
# high_cut: 0.5
# order: 3
#
#resp_low_amp:
# window_size_sec: 30
# stride_sec:
# amplitude_threshold: 3
# merge_gap_sec: 60
# min_duration_sec: 60
#
#resp_movement:
# window_size_sec: 20
# stride_sec: 1
# std_median_multiplier: 4
# compare_intervals_sec:
# - 60
# - 120
## - 180
# interval_multiplier: 3
# merge_gap_sec: 30
# min_duration_sec: 1
#
#resp_movement_revise:
# up_interval_multiplier: 3
# down_interval_multiplier: 2
# compare_intervals_sec: 30
# merge_gap_sec: 10
# min_duration_sec: 1
#
#resp_amp_change:
# mav_calc_window_sec: 4
# threshold_amplitude: 0.25
# threshold_energy: 0.4
#
#
#bcg:
# downsample_fs: 100
#
#bcg_filter:
# filter_type: bandpass
# low_cut: 1
# high_cut: 10
# order: 10
#
#bcg_low_amp:
# window_size_sec: 1
# stride_sec:
# amplitude_threshold: 8
# merge_gap_sec: 20
# min_duration_sec: 3
#
#
#bcg_movement:
# window_size_sec: 2
# stride_sec:
# merge_gap_sec: 20
# min_duration_sec: 4
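These YAML files are consumed through `utils.load_dataset_conf`; the commit message places the loader in `utils/operation_tools.py`. Assuming it is essentially a thin wrapper over PyYAML's `safe_load`, a minimal sketch (the stand-in file name and the wrapper itself are illustrative, not the repo's actual code):

```python
import yaml  # PyYAML

def load_dataset_conf(yaml_path):
    # Assumed behaviour of utils.load_dataset_conf: parse YAML into a plain dict.
    with open(yaml_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

# Tiny stand-in config mirroring the structure above.
with open("demo_conf.yaml", "w", encoding="utf-8") as f:
    f.write("dataset_config:\n  window_sec: 180\n  stride_sec: 60\n")

conf = load_dataset_conf("demo_conf.yaml")
print(conf["dataset_config"]["window_sec"])  # 180
```

The nested keys then drive windowing exactly as in `HYS_dataset.py`, e.g. `conf["dataset_config"]["window_sec"]` and `["stride_sec"]`.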


@ -0,0 +1,86 @@
select_ids:
- 1302
- 286
- 950
- 220
- 229
- 541
- 582
- 670
- 684
- 960
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS
resp:
downsample_fs_1: 100
downsample_fs_2: 10
resp_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.6
order: 3
resp_low_amp:
window_size_sec: 30
stride_sec:
amplitude_threshold: 3
merge_gap_sec: 60
min_duration_sec: 60
resp_movement:
window_size_sec: 20
stride_sec: 1
std_median_multiplier: 4
compare_intervals_sec:
- 60
- 120
# - 180
interval_multiplier: 3
merge_gap_sec: 30
min_duration_sec: 1
resp_movement_revise:
up_interval_multiplier: 3
down_interval_multiplier: 2
compare_intervals_sec: 30
merge_gap_sec: 10
min_duration_sec: 1
resp_amp_change:
mav_calc_window_sec: 4
threshold_amplitude: 0.25
threshold_energy: 0.4
bcg:
downsample_fs: 100
bcg_filter:
filter_type: bandpass
low_cut: 1
high_cut: 10
order: 10
bcg_low_amp:
window_size_sec: 1
stride_sec:
amplitude_threshold: 8
merge_gap_sec: 20
min_duration_sec: 3
bcg_movement:
window_size_sec: 2
stride_sec:
merge_gap_sec: 20
min_duration_sec: 4
dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization


@ -0,0 +1,56 @@
select_ids:
- 1000
- 1004
- 1006
- 1009
- 1010
- 1300
- 1301
- 1302
- 1308
- 1314
- 1354
- 1374
- 1378
- 1478
- 220
- 221
- 229
- 282
- 285
- 286
- 54
- 541
- 579
- 582
- 670
- 671
- 683
- 684
- 686
- 703
- 704
- 726
- 735
- 736
- 88
- 893
- 933
- 935
- 939
- 950
- 952
- 954
- 955
- 956
- 960
- 961
- 962
- 967
- 969
- 971
- 972
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
pair_file_path: /mnt/disk_wd/marques_dataset/Resp_Pair_Dataset/HYS/Raw


@ -0,0 +1,32 @@
select_ids:
- 3103
- 3105
- 3106
- 3107
- 3108
- 3110
- 3203
- 3204
- 3205
- 3209
- 3211
- 3301
- 3303
- 3304
- 3307
- 3308
- 3309
- 3403
- 3405
- 3406
- 3407
- 3408
- 3504
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y
pair_file_path: /mnt/disk_wd/marques_dataset/Resp_Pair_Dataset/ZD5Y/Raw


@ -0,0 +1,41 @@
root_path: /mnt/disk_wd/marques_dataset/shhs/polysomnography/shhs1
mask_save_path: /mnt/disk_code/marques/dataprepare/output/shhs1
effort_target_fs: 10
ecg_target_fs: 100
dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/SHHS1_dataset/visualization
effort:
downsample_fs: 10
effort_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
average_filter:
window_size_sec: 20
flow:
downsample_fs: 10
flow_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
spo2_fill__anomaly:
max_fill_duration: 30
min_gap_duration: 10
nan_to_num_value: 95
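The `spo2_fill__anomaly` keys suggest a gap-filling pass over the SpO2 trace: NaN runs up to `max_fill_duration` seconds are interpolated, and anything longer is replaced by `nan_to_num_value`. The helper below is a sketch under that assumption (the real logic lives in the repo's signal utilities; `min_gap_duration`, presumably for merging nearby gaps, is omitted, and at least one valid sample is assumed):

```python
import numpy as np

def fill_spo2_anomalies(spo2, fs, max_fill_duration=30, nan_to_num_value=95):
    """Hypothetical helper mirroring the spo2_fill__anomaly config keys."""
    spo2 = np.asarray(spo2, dtype=float).copy()
    isnan = np.isnan(spo2)
    if not isnan.any():
        return spo2
    # Edges of contiguous NaN runs: starts at even positions, ends at odd ones.
    padded = np.concatenate(([0], isnan.astype(np.int8), [0]))
    edges = np.flatnonzero(np.diff(padded))
    idx = np.arange(len(spo2))
    valid = ~isnan  # assumes at least one valid sample exists
    for s, e in zip(edges[::2], edges[1::2]):
        if (e - s) <= max_fill_duration * fs:
            # Short gap: linear interpolation from the surrounding valid samples.
            spo2[s:e] = np.interp(idx[s:e], idx[valid], spo2[valid])
        else:
            # Long gap: fall back to the configured constant.
            spo2[s:e] = nan_to_num_value

    return spo2

x = np.array([96, 97, np.nan, np.nan, 95, np.nan, np.nan, np.nan, 94], dtype=float)
y = fill_spo2_anomalies(x, fs=1, max_fill_duration=2)
```

With `fs=1` and `max_fill_duration=2`, the two-sample gap is interpolated while the three-sample gap is set to 95.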


@ -0,0 +1,88 @@
select_ids:
- 3103
- 3105
- 3106
- 3107
- 3108
- 3110
- 3203
- 3204
- 3205
- 3207
- 3208
- 3209
- 3212
- 3301
- 3303
- 3307
- 3403
- 3504
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y
save_path: /mnt/disk_code/marques/dataprepare/output/ZD5Y
resp:
downsample_fs_1: 100
downsample_fs_2: 10
resp_filter:
filter_type: bandpass
low_cut: 0.01
high_cut: 0.7
order: 3
resp_low_amp:
window_size_sec: 30
stride_sec:
amplitude_threshold: 3
merge_gap_sec: 60
min_duration_sec: 60
resp_movement:
window_size_sec: 20
stride_sec: 1
std_median_multiplier: 5
compare_intervals_sec:
- 60
- 120
# - 180
interval_multiplier: 3.5
merge_gap_sec: 30
min_duration_sec: 1
resp_movement_revise:
up_interval_multiplier: 3
down_interval_multiplier: 2
compare_intervals_sec: 30
merge_gap_sec: 10
min_duration_sec: 1
resp_amp_change:
mav_calc_window_sec: 1
threshold_amplitude: 0.25
threshold_energy: 0.4
bcg:
downsample_fs: 100
bcg_filter:
filter_type: bandpass
low_cut: 1
high_cut: 10
order: 10
bcg_low_amp:
window_size_sec: 1
stride_sec:
amplitude_threshold: 8
merge_gap_sec: 20
min_duration_sec: 3
bcg_movement:
window_size_sec: 2
stride_sec:
merge_gap_sec: 20
min_duration_sec: 4
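The `*_filter` blocks in these configs describe Butterworth band-pass filters; the commit message says the actual wrappers live in `utils/filter_func.py`. A sketch of how such a block might be applied with scipy (a lower demo order than the BCG config's order 10 is used here, since high-order `(b, a)` filters can be numerically fragile):

```python
import numpy as np
from scipy.signal import butter, filtfilt

def apply_filter_conf(signal, fs, filter_conf):
    # filter_conf mirrors a *_filter block, e.g.
    # {"filter_type": "bandpass", "low_cut": 1, "high_cut": 10, "order": 4}
    assert filter_conf["filter_type"] == "bandpass"
    b, a = butter(filter_conf["order"],
                  [filter_conf["low_cut"], filter_conf["high_cut"]],
                  btype="bandpass", fs=fs)
    return filtfilt(b, a, signal)  # forward-backward pass -> zero phase shift

fs = 100
t = np.arange(0, 10, 1 / fs)
x = np.sin(2 * np.pi * 5 * t) + np.sin(2 * np.pi * 30 * t)  # 5 Hz kept, 30 Hz rejected
y = apply_filter_conf(x, fs, {"filter_type": "bandpass",
                              "low_cut": 1, "high_cut": 10, "order": 4})
```

`filtfilt` is a natural choice here because zero-phase filtering keeps the respiratory waveform aligned with the per-second event masks.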


@ -0,0 +1,67 @@
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import shutil
def copy_one_resp_pair(one_id):
sync_type = "Sync"
org_bcg_file_path = sync_bcg_path / f"{one_id}"
dest_bcg_file_path = pair_file_path / f"{one_id}"
dest_bcg_file_path.mkdir(parents=True, exist_ok=True)
if not list(org_bcg_file_path.glob("OrgBCG_Sync_*.txt")):
if not list(org_bcg_file_path.glob("OrgBCG_RoughCut_*.txt")):
print(f"No OrgBCG files found for ID {one_id}.")
return
else:
sync_type = "RoughCut"
print(f"Using RoughCut files for ID {one_id}.")
for file in org_bcg_file_path.glob(f"OrgBCG_{sync_type}_*.txt"):
shutil.copyfile(file, dest_bcg_file_path / f"{one_id}_{file.name}".replace("_RoughCut", "").replace("_Sync", ""))
psg_file_path = sync_psg_path / f"{one_id}"
dest_psg_file_path = pair_file_path / f"{one_id}"
dest_psg_file_path.mkdir(parents=True, exist_ok=True)
# Check that every expected PSG file exists before copying anything
psg_file_patterns = [
f"5_class_{sync_type}_*.txt",
f"Effort Abd_{sync_type}_*.txt",
f"Effort Tho_{sync_type}_*.txt",
f"Flow P_{sync_type}_*.txt",
f"Flow T_{sync_type}_*.txt",
"SA Label_Sync.csv",
f"SpO2_{sync_type}_*.txt"
]
for pattern in psg_file_patterns:
if not list(psg_file_path.glob(pattern)):
print(f"No PSG files found for ID {one_id} with pattern {pattern}.")
return
for pattern in psg_file_patterns:
for file in psg_file_path.glob(pattern):
shutil.copyfile(file, dest_psg_file_path / f"{one_id}_{file.name.replace('_RoughCut', '').replace('_Sync', '')}")
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/RESP_PAIR_ZD5Y_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
sync_bcg_path = root_path / "OrgBCG_Aligned"
sync_psg_path = root_path / "PSG_Aligned"
pair_file_path = Path(conf["pair_file_path"])
# copy_one_resp_pair(961)
for samp_id in select_ids:
print(f"Processing {samp_id}...")
copy_one_resp_pair(samp_id)


@ -0,0 +1,100 @@
from pathlib import Path
from lxml import etree
from tqdm import tqdm
from collections import Counter
def main():
# Set the target folder path: edit it here, or supply it when running the script.
# Defaults to the current directory '.'
# target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1"
target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2"
folder_path = Path(target_dir)
if not folder_path.exists():
print(f"Error: path '{folder_path}' does not exist.")
return
# 1. Collect all XML files (flat layout; subdirectories are not recursed into)
xml_files = list(folder_path.glob("*.xml"))
total_files = len(xml_files)
if total_files == 0:
print(f"No XML files found in '{folder_path}'.")
return
print(f"Found {total_files} XML files; starting processing...")
# Counter for (EventType, EventConcept) combinations
stats_counter = Counter()
# 2. Iterate over the files with a tqdm progress bar
for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"):
try:
# Parse with lxml
tree = etree.parse(str(xml_file))
root = tree.getroot()
# 3. Locate the ScoredEvent nodes
# SHHS XML layout is typically: PSGAnnotation -> ScoredEvents -> ScoredEvent,
# so query all ScoredEvent nodes directly
events = root.findall(".//ScoredEvent")
for event in events:
# Extract EventType
type_node = event.find("EventType")
# Guard against a missing node or empty text
e_type = type_node.text.strip() if (type_node is not None and type_node.text) else "N/A"
# Extract EventConcept
concept_node = event.find("EventConcept")
e_concept = concept_node.text.strip() if (concept_node is not None and concept_node.text) else "N/A"
# 4. Count the combination
# the key is the tuple (EventType, EventConcept)
key = (e_type, e_concept)
stats_counter[key] += 1
except etree.XMLSyntaxError:
print(f"\n[Warning] malformed XML, skipping: {xml_file.name}")
except Exception as e:
print(f"\n[Error] failed to process {xml_file.name}: {e}")
# 5. Print the results to the terminal
if stats_counter:
# --- Compute column widths dynamically ---
# Maximum EventType length, with a floor of 9
max_type_width = max((len(k[0]) for k in stats_counter.keys()), default=9)
max_type_width = max(max_type_width, 9)
# Maximum EventConcept length, with a floor of 12
max_conc_width = max((len(k[1]) for k in stats_counter.keys()), default=12)
max_conc_width = max(max_conc_width, 12)
# Total table width
total_line_width = max_type_width + max_conc_width + 10 + 6
print("\n" + "=" * total_line_width)
print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}")
print("-" * total_line_width)
# --- Sort by name ---
# sorted() orders the (EventType, EventConcept) tuples lexicographically:
# first by EventType A-Z, then by EventConcept A-Z
for (e_type, e_concept), count in sorted(stats_counter.items()):
print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}")
print("=" * total_line_width)
else:
print("\nNo event data extracted.")
print("=" * 90)
print(f"Done. Scanned {total_files} files.")
if __name__ == "__main__":
main()

draw_tools/__init__.py Normal file

@ -0,0 +1,2 @@
from .draw_statics import draw_signal_with_mask, draw_psg_signal
from .draw_label import draw_psg_bcg_label,draw_psg_label

draw_tools/draw_label.py Normal file

@ -0,0 +1,337 @@
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
from tqdm.rich import tqdm
import utils
import gc
# TODO: add a with_prediction parameter
psg_bcg_chn_name2ax = {
"SpO2": 0,
"Flow T": 1,
"Flow P": 2,
"Effort Tho": 3,
"Effort Abd": 4,
"HR": 5,
"resp": 6,
"bcg": 7,
"Stage": 8,
"resp_twinx": 9,
"bcg_twinx": 10,
}
psg_chn_name2ax = {
"SpO2": 0,
"Flow T": 1,
"Flow P": 2,
"Effort Tho": 3,
"Effort Abd": 4,
"Effort": 5,
"HR": 6,
"RRI": 7,
"Stage": 8,
}
resp_chn_name2ax = {
"resp": 0,
"bcg": 1,
}
def create_psg_bcg_figure():
fig = plt.figure(figsize=(12, 8), dpi=200)
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = []
for i in range(9):
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[psg_bcg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100))
axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Flow T"]].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Flow P"]].grid(True)
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True)
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True)
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["HR"]].grid(True)
axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["resp"]].grid(True)
axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx())
axes[psg_bcg_chn_name2ax["bcg"]].grid(True)
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_bcg_chn_name2ax["bcg"]].twinx())
axes[psg_bcg_chn_name2ax["Stage"]].grid(True)
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes
def create_psg_figure():
fig = plt.figure(figsize=(12, 8), dpi=200)
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = []
for i in range(9):
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[psg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Flow T"]].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Flow P"]].grid(True)
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort Tho"]].grid(True)
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort Abd"]].grid(True)
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort"]].grid(True)
axes[psg_chn_name2ax["Effort"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["HR"]].grid(True)
axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["RRI"]].grid(True)
axes[psg_chn_name2ax["RRI"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Stage"]].grid(True)
return fig, axes
def create_resp_figure():
fig = plt.figure(figsize=(12, 6), dpi=100)
gs = GridSpec(2, 1, height_ratios=[3, 2])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = []
for i in range(2):
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[0].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[0].tick_params(axis='x', colors="white")
axes[1].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[1].tick_params(axis='x', colors="white")
return fig, axes
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
signal_fs = signal_data["fs"]
chn_signal = signal_data["data"]
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black',
label=signal_data["name"])
if event_mask is not None:
if multi_labels is None and event_codes is not None:
for event_code in event_codes:
mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code
y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float)
np.place(y, y == 0, np.nan)
ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
elif multi_labels == "resp" and event_codes is not None:
ax.set_ylim(-6, 6)
# Plot the auxiliary masks on a twin y-axis
ax2.cla()
ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
color='blue', alpha=0.8, label='Low Amplitude Mask')
ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
color='orange', alpha=0.8, label='Movement Mask')
ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
color='green', alpha=0.8, label='Amplitude Change Mask')
for event_code in event_codes:
sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
score_mask = event_mask["SA_Score"][segment_start:segment_end].repeat(signal_fs)
# y = (sa_mask * score_mask).astype(float)
y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float)
np.place(y, y == 0, np.nan)
ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
ax2.plot(time_axis, score_mask, color="orange")
ax2.set_ylim(-4, 5)
elif multi_labels == "bcg" and event_codes is not None:
# Plot the auxiliary masks on a twin y-axis
ax2.cla()
ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
color='blue', alpha=0.8, label='Low Amplitude Mask')
ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
color='orange', alpha=0.8, label='Movement Mask')
ax2.plot(time_axis, event_mask["BCG_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
color='green', alpha=0.8, label='Amplitude Change Mask')
ax2.set_ylim(-4, 4)
ax.set_ylabel("Amplitude")
ax.legend(loc=1)
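The overlay pattern used in `plt_signal_label_on_ax` is worth isolating: a per-second label mask is upsampled to the signal rate with `repeat(fs)`, the signal is zeroed outside the event, and zeros become NaN so matplotlib breaks the line there (as in the original, this also hides genuinely zero-valued samples). A minimal sketch:

```python
import numpy as np

fs = 4
signal = np.arange(12, dtype=float) + 1   # 3 s of fake signal at 4 Hz: 1..12
event_mask = np.array([0, 2, 0])          # per-second labels; code 2 = one event type

mask = event_mask.repeat(fs) == 2         # upsample seconds -> samples
y = (signal * mask).astype(float)
np.place(y, y == 0, np.nan)               # NaNs create gaps in the plotted line

print(np.nanmin(y), np.nanmax(y))  # 5.0 8.0  (only second 1 survives)
```

Plotting `y` in a second color on top of the black trace then highlights exactly the samples covered by the event.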
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
stage_signal = stage_data["data"]
stage_fs = stage_data["fs"]
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * stage_fs)
ax.plot(time_axis, stage_signal[segment_start * stage_fs:segment_end * stage_fs], color='black', label=stage_data["name"])
ax.set_ylim(0, 6)
ax.set_yticks([1, 2, 3, 4, 5])
ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
ax.set_ylabel("Stage")
ax.legend(loc=1)
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
spo2_signal = spo2_data["data"]
spo2_fs = spo2_data["fs"]
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * spo2_fs)
ax.plot(time_axis, spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs], color='black', label=spo2_data["name"])
if spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() < 85:
ax.set_ylim((spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() - 5, 100))
else:
ax.set_ylim((85, 100))
ax.set_ylabel("SpO2 (%)")
ax.legend(loc=1)
def score_mask2alpha(score_mask):
alpha_mask = np.zeros_like(score_mask, dtype=float)
alpha_mask[score_mask <= 0] = 0
alpha_mask[score_mask == 1] = 0.9
alpha_mask[score_mask == 2] = 0.6
alpha_mask[score_mask == 3] = 0.1
return alpha_mask
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
multi_p=None, multi_task_id=None):
if save_path is not None:
save_path.mkdir(parents=True, exist_ok=True)
for mask in event_mask.keys():
if mask.startswith("Resp_") or mask.startswith("BCG_"):
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0)
# event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
# event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
fig, axes = create_psg_bcg_figure()
for i, (segment_start, segment_end) in enumerate(segment_list):
for ax in axes:
ax.cla()
plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]])
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
if multi_p is not None:
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
plt.close(fig)
plt.close('all')
gc.collect()
def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True,
multi_p=None, multi_task_id=None):
if save_path is not None:
save_path.mkdir(parents=True, exist_ok=True)
if multi_p is None:
# Print the length and sampling rate of every channel in psg_data
for i in range(len(psg_data.keys())):
chn_name = list(psg_data.keys())[i]
print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}")
# Length of psg_label
print(f"psg_label length: {len(psg_label)}")
fig, axes = create_psg_figure()
for i, (segment_start, segment_end) in enumerate(segment_list):
for ax in axes:
ax.cla()
plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end)
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
if multi_p is not None:
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
plt.close(fig)
plt.close('all')
gc.collect()

draw_tools/draw_statics.py Normal file

@ -0,0 +1,371 @@
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
# Build the text matrix used for heatmap cell annotations
text_matrix = np.empty((len(amp_labels), len(time_labels)), dtype=object)
percent_matrix = np.zeros((len(amp_labels), len(time_labels)))
# Fill the text and percentage matrices
for i in range(len(amp_labels)):
for j in range(len(time_labels)):
val = confusion_matrix.iloc[i, j]
segment_count = segment_count_matrix[i, j]
percent = confusion_matrix_percent.iloc[i, j]
text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
percent_matrix[i, j] = percent
# Draw the heatmap and adjust the colorbar placement
sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='',
xticklabels=time_labels, yticklabels=amp_labels,
cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
cbar_kws={'label': 'Percentage (%)', 'shrink': 0.6, 'pad': 0.15},
# annot_kws={'fontsize': 12}
)
# Row statistics (right-hand side); '总计' is the summary column name produced upstream
row_sums = confusion_matrix['总计']
row_percents = confusion_matrix_percent['总计']
ax.text(len(time_labels) + 1, -0.5, "Duration per amplitude bin\n(% of valid interval)", ha='center', va='center')
for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
ax.text(len(time_labels) + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
ha='center', va='center')
# Column statistics (bottom)
col_sums = segment_count_matrix.sum(axis=0)
col_percents = confusion_matrix.sum(axis=0) / total_duration * 100
ax.text(-1, len(amp_labels) + 0.5, "[segments per duration bin]\n(% of valid interval)", ha='center', va='center', rotation=0)
for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
ax.text(j + 0.5, len(amp_labels) + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
ha='center', va='center', rotation=0)
# Move the x-axis ticks and label to the top
ax.xaxis.set_label_position('top')
ax.xaxis.tick_top()
# Title and axis labels
ax.set_title('Amplitude-duration statistics matrix', pad=40)
ax.set_xlabel('Duration bin (s)', labelpad=10)
ax.set_ylabel('Amplitude bin')
# Keep the tick labels horizontal
ax.set_xticklabels(time_labels, rotation=0)
ax.set_yticklabels(amp_labels, rotation=0)
# Reposition the colorbar label
cbar = sns_heatmap.collections[0].colorbar
cbar.ax.yaxis.set_label_position('right')
# Legend explaining the cell format
ax.text(-2, -1, "Cell format:\n[segments] duration\n(% of valid interval)",
ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))
# Grand totals
# Total segment count
total_segments = segment_count_matrix.sum()
# Valid duration as a share of the total recording
total_percent = valid_signal_length / total_duration * 100
ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)",
ha='center', va='center')
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
# Plot the signal trace
ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)
# Shade movement (red) and low-amplitude (blue) intervals with axvspan
for start, end in movement_position_list:
if start < len(original_times) and end < len(original_times):
ax.axvspan(start, end, color='red', alpha=0.3)
for start, end in low_amp_position_list:
if start < len(original_times) and end < len(original_times):
ax.axvspan(start, end, color='blue', alpha=0.2)
# Draw horizontal amplitude reference lines when an AML list is given
if aml_list is not None:
color_map = ['red', 'orange', 'green']
for i, aml in enumerate(aml_list):
ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed', linewidth=2, alpha=0.5, label=f'{aml} aml')
ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values, color='blue', linewidth=2, alpha=0.4, label='2sMAV')
# Fix the y-axis range
ax.set_ylim((-2000, 2000))
# Legend patches for the shaded regions
red_patch = mpatches.Patch(color='red', alpha=0.2, label='Movement Area')
blue_patch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area')
# Append the new legend entries while keeping the existing ones
handles, labels = ax.get_legend_handles_labels()  # existing legend entries
ax.legend(handles=[red_patch, blue_patch] + handles,
labels=['Movement Area', 'Low Amplitude Area'] + labels,
loc='upper right',
bbox_to_anchor=(1, 1.4),
framealpha=0.5)
# Title and axis labels
ax.set_title(f'{signal_name} Signal')
ax.set_ylabel('Amplitude')
ax.set_xlabel('Time (s)')
# Enable the grid
ax.grid(True, linestyle='--', alpha=0.7)
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
resp_movement_position_list, resp_low_amp_position_list,
bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
signal_info, show=False, save_path=None):
# Create the figure
fig = plt.figure(figsize=(18, 10))
gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)
signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))
# Subplot 1: raw signal
ax1 = fig.add_subplot(gs[0])
draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list,
signal_second_length=signal_second_length, aml_list=[200, 500, 1000])
ax2 = fig.add_subplot(gs[1])
param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']
params = dict(zip(param_names, bcg_statistic_info))
params['ax'] = ax2
draw_ax_confusion_matrix(**params)
ax3 = fig.add_subplot(gs[2], sharex=ax1)
draw_ax_amp(ax=ax3, signal_name='RESP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list,
signal_second_length=signal_second_length, aml_list=[100, 300, 500])
ax4 = fig.add_subplot(gs[3])
params = dict(zip(param_names, resp_statistic_info))
params['ax'] = ax4
draw_ax_confusion_matrix(**params)
# Figure-level title
fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)
if save_path is not None:
# Save the figure
plt.savefig(save_path, dpi=300)
if show:
plt.show()
plt.close()
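The `dict(zip(param_names, bcg_statistic_info))` idiom used in `draw_signal_metrics` fans an ordered tuple of statistics into keyword arguments. A minimal self-contained sketch of the same idiom, with hypothetical names standing in for `draw_ax_confusion_matrix` and its inputs:

```python
def describe(confusion_matrix, total_duration, ax=None):
    # Toy stand-in for draw_ax_confusion_matrix: just format the inputs.
    return f"cm={confusion_matrix}, duration={total_duration}s, ax={ax}"

# The ordered names must match the tuple's element order exactly.
param_names = ['confusion_matrix', 'total_duration']
statistic_info = ([[5, 1], [2, 9]], 3600)

params = dict(zip(param_names, statistic_info))
params['ax'] = 'ax2'            # extra keyword appended after zipping
print(describe(**params))       # the dict expands into keyword arguments
```

The upside of this pattern is that adding a new statistic only requires extending `param_names` and the tuple in lockstep; the downside is that a reordering bug is silent, so keeping the two definitions adjacent matters.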
def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs,
signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask,
resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask, show=False, save_path=None
):
# Row 1: the notch-filtered raw signal; right axis marks unusable intervals, left axis is signal amplitude
# Row 2: the respiration component; right axis marks low-amplitude, movement, amplitude-change and SA labels, left axis is respiration amplitude
# Row 3: the BCG component; right axis marks low-amplitude, movement and amplitude-change labels, left axis is BCG amplitude
# If a mask is None, substitute an all-NaN mask
def _none_to_nan_mask(mask, ref):
if mask is None:
return np.full_like(ref, np.nan)
else:
# Replace 0 with NaN in the mask; keep the other values
mask = np.where(mask == 0, np.nan, mask)
return mask
signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data)
resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data)
resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data)
resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data)
bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data)
bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data)
bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data)
fig = plt.figure(figsize=(18, 10))
ax0 = fig.add_subplot(3, 1, 1)
ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
ax0.set_title(f'Sample {samp_id} - Raw Signal Data')
ax0.set_ylabel('Amplitude')
# ax0.set_xticklabels([])
ax0_twin = ax0.twinx()
ax0_twin.plot(np.linspace(0, len(signal_disable_mask), len(signal_disable_mask)), signal_disable_mask,
color='red', alpha=0.5)
ax0_twin.autoscale(enable=False, axis='y', tight=True)
ax0_twin.set_ylim((-2, 2))
ax0_twin.set_ylabel('Disable Mask')
ax0_twin.set_yticks([0, 1])
ax0_twin.set_yticklabels(['Enabled', 'Disabled'])
ax0_twin.grid(False)
ax0_twin.legend(['Disable Mask'], loc='upper right')
ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='gray', alpha=0.5)
resp_data_no_movement = resp_data.copy()
resp_data_no_movement[resp_movement_mask.repeat(int(resp_fs)) == 1] = np.nan
ax1.plot(np.linspace(0, len(resp_data_no_movement) // resp_fs, len(resp_data_no_movement)), resp_data_no_movement, color='orange')
ax1.set_ylabel('Amplitude')
# ax1.set_xticklabels([])
ax1_twin = ax1.twinx()
ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1,
color='blue', alpha=0.5, label='Low Amplitude Mask')
ax1_twin.plot(np.linspace(0, len(resp_movement_mask), len(resp_movement_mask)), resp_movement_mask*-2,
color='red', alpha=0.5, label='Movement Mask')
ax1_twin.plot(np.linspace(0, len(resp_change_mask), len(resp_change_mask)), resp_change_mask*-3,
color='green', alpha=0.5, label='Amplitude Change Mask')
ax1_twin.plot(np.linspace(0, len(resp_sa_mask), len(resp_sa_mask)), resp_sa_mask,
color='purple', alpha=0.5, label='SA Mask')
ax1_twin.autoscale(enable=False, axis='y', tight=True)
ax1_twin.set_ylim((-4, 5))
# ax1_twin.set_ylabel('Resp Masks')
# ax1_twin.set_yticks([0, 1])
# ax1_twin.set_yticklabels(['No', 'Yes'])
ax1_twin.grid(False)
ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
ax1.set_title(f'Sample {samp_id} - Respiration Component')
ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
ax2.set_ylabel('Amplitude')
ax2.set_xlabel('Time (s)')
ax2_twin = ax2.twinx()
ax2_twin.plot(np.linspace(0, len(bcg_low_amp_mask), len(bcg_low_amp_mask)), bcg_low_amp_mask*-1,
color='blue', alpha=0.5, label='Low Amplitude Mask')
ax2_twin.plot(np.linspace(0, len(bcg_movement_mask), len(bcg_movement_mask)), bcg_movement_mask*-2,
color='red', alpha=0.5, label='Movement Mask')
ax2_twin.plot(np.linspace(0, len(bcg_change_mask), len(bcg_change_mask)), bcg_change_mask*-3,
color='green', alpha=0.5, label='Amplitude Change Mask')
ax2_twin.autoscale(enable=False, axis='y', tight=True)
ax2_twin.set_ylim((-4, 2))
ax2_twin.set_ylabel('BCG Masks')
# ax2_twin.set_yticks([0, 1])
# ax2_twin.set_yticklabels(['No', 'Yes'])
ax2_twin.grid(False)
ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right')
# ax2.set_title(f'Sample {samp_id} - BCG Component')
ax0_twin._lim_lock = False
ax1_twin._lim_lock = False
ax2_twin._lim_lock = False
def on_lims_change(event_ax):
if getattr(event_ax, '_lim_lock', False):
return
try:
event_ax._lim_lock = True
if event_ax == ax0_twin:
# Re-lock the twin axis Y range after an interactive zoom/pan
ax0_twin.set_ylim(-2, 2)
elif event_ax == ax1_twin:
ax1_twin.set_ylim(-4, 5)
elif event_ax == ax2_twin:
ax2_twin.set_ylim(-4, 2)
finally:
event_ax._lim_lock = False
ax0_twin.callbacks.connect('ylim_changed', on_lims_change)
ax1_twin.callbacks.connect('ylim_changed', on_lims_change)
ax2_twin.callbacks.connect('ylim_changed', on_lims_change)
plt.tight_layout()
if save_path is not None:
plt.savefig(save_path, dpi=300)
if show:
plt.show()
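The NaN trick used by `_none_to_nan_mask` and the movement blanking above relies on matplotlib skipping NaN samples, which leaves visible gaps in a line plot. A minimal sketch with synthetic data (the sampling rate and arrays are made up for illustration):

```python
import numpy as np

def mask_to_nan(mask):
    # 0 -> NaN (not drawn), other values kept, as in _none_to_nan_mask.
    return np.where(mask == 0, np.nan, mask).astype(float)

fs = 4                                   # assumed sampling rate (Hz)
signal = np.arange(12, dtype=float)      # 3 seconds of fake samples
movement = np.array([0, 1, 0])           # per-second movement mask

# Expand the per-second mask to sample rate and blank movement samples.
blanked = signal.copy()
blanked[movement.repeat(fs) == 1] = np.nan

print(mask_to_nan(movement))             # NaN where the mask is 0
print(np.isnan(blanked).sum())           # 4 samples blanked (1 s at 4 Hz)
```

Plotting `blanked` on top of the gray original then shows the non-movement signal as a solid overlay with gaps where movement occurred, which is exactly how the orange trace is produced above.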
def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, spo2_signal, effort_signal, rri_signal, event_mask, fs,
show=False, save_path=None):
sa_mask = event_mask.repeat(fs)
fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True)
time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal))
axs[0].plot(time_axis, tho_signal, label='THO', color='black')
axs[0].set_title(f'Sample {samp_id} - PSG Signal Data')
axs[0].set_ylabel('THO Amplitude')
axs[0].legend(loc='upper right')
ax0_twin = axs[0].twinx()
ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax0_twin.autoscale(enable=False, axis='y', tight=True)
ax0_twin.set_ylim((-4, 5))
axs[1].plot(time_axis, abd_signal, label='ABD', color='black')
axs[1].set_ylabel('ABD Amplitude')
axs[1].legend(loc='upper right')
ax1_twin = axs[1].twinx()
ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax1_twin.autoscale(enable=False, axis='y', tight=True)
ax1_twin.set_ylim((-4, 5))
axs[2].plot(time_axis, effort_signal, label='EFFO', color='black')
axs[2].set_ylabel('EFFO Amplitude')
axs[2].legend(loc='upper right')
ax2_twin = axs[2].twinx()
ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax2_twin.autoscale(enable=False, axis='y', tight=True)
ax2_twin.set_ylim((-4, 5))
axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black')
axs[3].set_ylabel('FLOWP Amplitude')
axs[3].legend(loc='upper right')
ax3_twin = axs[3].twinx()
ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax3_twin.autoscale(enable=False, axis='y', tight=True)
ax3_twin.set_ylim((-4, 5))
axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black')
axs[4].set_ylabel('FLOWT Amplitude')
axs[4].legend(loc='upper right')
ax4_twin = axs[4].twinx()
ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax4_twin.autoscale(enable=False, axis='y', tight=True)
ax4_twin.set_ylim((-4, 5))
axs[5].plot(time_axis, rri_signal, label='RRI', color='black')
axs[5].set_ylabel('RRI Amplitude')
axs[5].legend(loc='upper right')
axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black')
axs[6].set_ylabel('SPO2 Amplitude')
axs[6].set_xlabel('Time (s)')
axs[6].legend(loc='upper right')
if save_path is not None:
plt.savefig(save_path, dpi=300)
if show:
plt.show()
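`draw_psg_signal` expands the per-second `event_mask` to sample resolution with `repeat(fs)`. When the signal length is not an exact multiple of `fs`, the expanded mask can be trimmed to match; the trimming step below is an assumption for illustration and is not shown in the function above:

```python
import numpy as np

fs = 25                                  # assumed PSG sampling rate (Hz)
event_mask = np.array([0, 0, 2, 2, 0])   # per-second labels (e.g. 2 = hypopnea)
signal = np.zeros(5 * fs - 3)            # signal slightly shorter than 5 s

sa_mask = event_mask.repeat(fs)          # one label value per sample
sa_mask = sa_mask[:len(signal)]          # align lengths before plotting

print(len(sa_mask))        # 122, same as the signal
print(sa_mask[2 * fs])     # 2: first sample of the labeled event
```

With both arrays the same length, a single shared time axis can be used for the signal and the overlaid SA mask, as in the seven subplots above.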

@@ -0,0 +1,183 @@
"""
本脚本完成对呼研所数据的处理包含以下功能
1. 数据读取与预处理
从传入路径中进行数据和标签的读取并进行初步的预处理
预处理包括为数据进行滤波去噪等操作
2. 数据清洗与异常值处理
3. 输出清晰后的统计信息
4. 数据保存
将处理后的数据保存到指定路径便于后续使用
主要是保存切分后的数据位置和标签
5. 可视化
提供数据处理前后的可视化对比帮助理解数据变化
绘制多条可用性趋势图展示数据的可用区间体动区间低幅值区间等
# 低幅值区间规则标定与剔除
# 高幅值连续体动规则标定与剔除
# 手动标定不可用区间提剔除
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id, show=False):
tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))
if not tho_signal_path:
raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
tho_signal_path = tho_signal_path[0]
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
if not abd_signal_path:
raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
abd_signal_path = abd_signal_path[0]
print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
if not flowp_signal_path:
raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
flowp_signal_path = flowp_signal_path[0]
print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
if not flowt_signal_path:
raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
flowt_signal_path = flowt_signal_path[0]
print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
if not spo2_signal_path:
raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
spo2_signal_path = spo2_signal_path[0]
print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
if not stage_signal_path:
raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
stage_signal_path = stage_signal_path[0]
print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
label_path = list((label_root_path / f"{samp_id}").glob("SA Label_Sync.csv"))
if not label_path:
raise FileNotFoundError(f"SA Label_Sync file not found for sample ID: {samp_id}")
label_path = label_path[0]
print(f"Processing SA Label_Sync file: {label_path}")
# Save the processed data and labels
save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True)
# Read the signal data
tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
# abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
# flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
# flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
# spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
# Preprocessing and filtering (currently disabled)
# tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs)
# abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs)
# flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs)
# flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs)
# Downsampling (currently disabled)
# old_tho_fs = tho_fs
# tho_fs = conf["effort"]["downsample_fs"]
# tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs)
# old_abd_fs = abd_fs
# abd_fs = conf["effort"]["downsample_fs"]
# abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs)
# old_flowp_fs = flowp_fs
# flowp_fs = conf["effort"]["downsample_fs"]
# flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs)
# old_flowt_fs = flowt_fs
# flowt_fs = conf["effort"]["downsample_fs"]
# flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs)
# SpO2 is not downsampled
# spo2_data_filt = spo2_data_raw
# spo2_fs = spo2_fs
label_data = utils.read_raw_psg_label(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
# score_mask: 1 where event_mask > 0, else 0
score_mask = np.where(event_mask > 0, 1, 0)
# Build the unusable mask from sleep staging
wake_mask = utils.get_wake_mask(stage_data_raw)
# Drop wake intervals shorter than 60 seconds
wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60)
# Merge non-wake gaps shorter than 60 seconds into the surrounding wake intervals
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
disable_label = wake_mask
disable_label = disable_label[:tho_second]
# Copy the event file to the save path
sa_label_save_name = f"{samp_id}_" + label_path.name
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
# Build a dataframe with the second index and the per-second labels
save_dict = {
"Second": np.arange(tho_second),
"SA_Label": event_mask,
"SA_Score": score_mask,
"Disable_Label": disable_label,
"Resp_LowAmp_Label": np.zeros_like(event_mask),
"Resp_Movement_Label": np.zeros_like(event_mask),
"Resp_AmpChange_Label": np.zeros_like(event_mask),
"BCG_LowAmp_Label": np.zeros_like(event_mask),
"BCG_Movement_Label": np.zeros_like(event_mask),
"BCG_AmpChange_Label": np.zeros_like(event_mask)
}
mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
# disable_df_path = project_root_path / "排除区间.xlsx"
#
conf = utils.load_dataset_conf(yaml_path)
root_path = Path(conf["root_path"])
save_path = Path(conf["mask_save_path"])
select_ids = conf["select_ids"]
#
print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
#
org_signal_root_path = root_path / "PSG_Aligned"
label_root_path = root_path / "PSG_Aligned"
#
# all_samp_disable_df = utils.read_disable_excel(disable_df_path)
#
# process_one_signal(select_ids[0], show=True)
# #
for samp_id in select_ids:
print(f"Processing sample ID: {samp_id}")
process_one_signal(samp_id, show=False)
print(f"Finished processing sample ID: {samp_id}\n\n")
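The internals of `utils.generate_event_mask` are not shown in this diff; under the assumption that each labeled event carries a start second, a duration and a numeric type code, its core could be sketched like this (the helper name and event layout are hypothetical):

```python
import numpy as np

def build_event_mask(signal_second, events):
    # Hypothetical stand-in for utils.generate_event_mask: each event is
    # (start_sec, duration_sec, type_code); later events overwrite earlier ones.
    mask = np.zeros(signal_second, dtype=int)
    for start, duration, code in events:
        end = min(start + duration, signal_second)  # clip to recording length
        mask[start:end] = code
    return mask

# 60 s recording with one apnea (code 1) and one hypopnea (code 2).
mask = build_event_mask(60, [(10, 15, 1), (40, 12, 2)])
score = np.where(mask > 0, 1, 0)       # binary "any event" mask, as in the script
print(mask[10], mask[24], mask[25])    # 1 1 0
print(score.sum())                     # 27 labeled seconds
```

This matches the shape contract the script relies on: one integer label per second of signal, from which the binary `score_mask` is derived with `np.where(event_mask > 0, 1, 0)`.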

@@ -0,0 +1,246 @@
"""
本脚本完成对呼研所数据的处理包含以下功能
1. 数据读取与预处理
从传入路径中进行数据和标签的读取并进行初步的预处理
预处理包括为数据进行滤波去噪等操作
2. 数据清洗与异常值处理
3. 输出清晰后的统计信息
4. 数据保存
将处理后的数据保存到指定路径便于后续使用
主要是保存切分后的数据位置和标签
5. 可视化
提供数据处理前后的可视化对比帮助理解数据变化
绘制多条可用性趋势图展示数据的可用区间体动区间低幅值区间等
# 低幅值区间规则标定与剔除
# 高幅值连续体动规则标定与剔除
# 手动标定不可用区间提剔除
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id, show=False):
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
if not signal_path:
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
signal_path = signal_path[0]
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
label_path = list((label_root_path / f"{samp_id}").glob("SA Label_corrected.csv"))
if not label_path:
raise FileNotFoundError(f"SA Label_corrected file not found for sample ID: {samp_id}")
label_path = label_path[0]
print(f"Processing Label_corrected file: {label_path}")
# Save the processed data and labels
save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True)
signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True)
signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)
# Downsample the respiration and BCG components
old_resp_fs = resp_fs
resp_fs = conf["resp"]["downsample_fs_2"]
resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs)
bcg_fs = conf["bcg"]["downsample_fs"]
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
label_data = utils.read_label_csv(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
all_samp_disable_df["id"] == samp_id])
print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")
# Detect low-amplitude intervals in the respiration signal
resp_low_amp_conf = conf.get("resp_low_amp", None)
if resp_low_amp_conf is not None:
resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=resp_data,
sampling_rate=resp_fs,
**resp_low_amp_conf
)
print(
f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}")
else:
resp_low_amp_mask, resp_low_amp_position_list = None, None
print("resp_low_amp_mask is None")
# Detect high-amplitude movement artifacts in the respiration signal
resp_movement_conf = conf.get("resp_movement", None)
if resp_movement_conf is not None:
raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
signal_data=resp_data,
sampling_rate=resp_fs,
**resp_movement_conf
)
print(
f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
else:
resp_movement_mask, resp_movement_position_list = None, None
print("resp_movement_mask is None")
resp_movement_revise_conf = conf.get("resp_movement_revise", None)
if resp_movement_mask is not None and resp_movement_revise_conf is not None:
resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
signal_data=resp_data,
movement_mask=resp_movement_mask,
movement_list=resp_movement_position_list,
sampling_rate=resp_fs,
**resp_movement_revise_conf,
verbose=False
)
print(
f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
else:
print("resp_movement_mask revise is skipped")
# Detect abrupt amplitude-change intervals in the respiration signal
resp_amp_change_conf = conf.get("resp_amp_change", None)
if resp_amp_change_conf is not None and resp_movement_mask is not None:
resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
signal_data=resp_data,
movement_mask=resp_movement_mask,
movement_list=resp_movement_position_list,
sampling_rate=resp_fs,
**resp_amp_change_conf,
verbose=False)
print(
f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
else:
resp_amp_change_mask = None
print("amp_change_mask is None")
# Detect low-amplitude intervals in the BCG signal
bcg_low_amp_conf = conf.get("bcg_low_amp", None)
if bcg_low_amp_conf is not None:
bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=bcg_data,
sampling_rate=bcg_fs,
**bcg_low_amp_conf
)
print(
f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}")
else:
bcg_low_amp_mask, bcg_low_amp_position_list = None, None
print("bcg_low_amp_mask is None")
# Detect high-amplitude movement artifacts in the BCG signal
bcg_movement_conf = conf.get("bcg_movement", None)
if bcg_movement_conf is not None:
raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
signal_data=bcg_data,
sampling_rate=bcg_fs,
**bcg_movement_conf
)
print(
f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}")
else:
bcg_movement_mask = None
print("bcg_movement_mask is None")
# Detect abrupt amplitude-change intervals in the BCG signal
if bcg_movement_mask is not None:
bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
signal_data=bcg_data,
movement_mask=bcg_movement_mask,
sampling_rate=bcg_fs)
print(
f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}")
else:
bcg_amp_change_mask = None
print("bcg_amp_change_mask is None")
# If the raw sampling rate is too high (1000 Hz), downsample signal_data to 100 Hz
if signal_fs == 1000:
signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs,
target_fs=100)
signal_fs = 100
draw_tools.draw_signal_with_mask(samp_id=samp_id,
signal_data=signal_data,
signal_fs=signal_fs,
resp_data=resp_data,
resp_fs=resp_fs,
bcg_data=bcg_data,
bcg_fs=bcg_fs,
signal_disable_mask=manual_disable_mask,
resp_low_amp_mask=resp_low_amp_mask,
resp_movement_mask=resp_movement_mask,
resp_change_mask=resp_amp_change_mask,
resp_sa_mask=event_mask,
bcg_low_amp_mask=bcg_low_amp_mask,
bcg_movement_mask=bcg_movement_mask,
bcg_change_mask=bcg_amp_change_mask,
show=show,
save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")
# Copy the event file to the save path
sa_label_save_name = f"{samp_id}_" + label_path.name
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
# Build a dataframe: second index, SA label, SA score, disable label, and the Resp/BCG low-amplitude, movement and amplitude-change labels
save_dict = {
"Second": np.arange(signal_second),
"SA_Label": event_mask,
"SA_Score": score_mask,
"Disable_Label": manual_disable_mask,
"Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
"Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second,
dtype=int),
"Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second,
dtype=int),
"BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
"BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second,
dtype=int),
"BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second,
dtype=int)
}
mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
disable_df_path = project_root_path / "排除区间.xlsx"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
save_path = Path(conf["mask_save_path"])
print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
label_root_path = root_path / "Label"
all_samp_disable_df = utils.read_disable_excel(disable_df_path)
# process_one_signal(select_ids[6], show=True)
#
for samp_id in select_ids:
print(f"Processing sample ID: {samp_id}")
process_one_signal(samp_id, show=False)
print(f"Finished processing sample ID: {samp_id}\n\n")
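`utils.downsample_signal_fast` is called throughout this script but its implementation is not part of this diff. A polyphase resampler in the spirit of that call might look like the following sketch; the function name and the assumption that it behaves like scipy's `resample_poly` are mine, not the repository's:

```python
import numpy as np
from fractions import Fraction
from scipy.signal import resample_poly

def downsample(signal, original_fs, target_fs):
    # Polyphase resampling; resample_poly applies an anti-aliasing FIR filter.
    ratio = Fraction(target_fs, original_fs).limit_denominator(1000)
    return resample_poly(signal, up=ratio.numerator, down=ratio.denominator)

fs_in, fs_out = 1000, 100
t = np.arange(10 * fs_in) / fs_in
x = np.sin(2 * np.pi * 0.3 * t)        # 0.3 Hz respiration-like tone
y = downsample(x, fs_in, fs_out)
print(len(y))                          # 1000 samples: 10 s at 100 Hz
```

Whatever the real implementation, the length contract matters here: downstream code divides signal lengths by the new sampling rate to recover the per-second duration, so the resampler must preserve `len(signal) / fs`.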

@@ -0,0 +1,42 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
import mne
from tqdm import tqdm
import xml.etree.ElementTree as ET
import re
# Extract staging and event labels, plus unusable intervals
def process_one_signal(samp_id, show=False):
pass  # TODO: SHHS processing not implemented yet
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/SHHS1_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
root_path = Path(conf["root_path"])
save_path = Path(conf["mask_save_path"])
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
label_root_path = root_path / "Label"

@@ -0,0 +1,6 @@
from .rule_base_event import detect_low_amplitude_signal, detect_movement
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
from .rule_base_event import movement_revise
from .time_metrics import calc_mav_by_slide_windows
from .signal_process import signal_filter_split, rpeak2hr, psg_effort_filter, rpeak2rri_interpolation
from .normalize_method import normalize_resp_signal_by_segment

@@ -0,0 +1,52 @@
import utils
import pandas as pd
import numpy as np
from scipy import signal
def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
# Z-score normalize each segment of the respiration signal, segmented by its amplitude-change (enable) intervals
normalized_resp_signal = np.zeros_like(resp_signal)
# Initialize everything to NaN
normalized_resp_signal[:] = np.nan
resp_signal_no_movement = resp_signal.copy()
resp_signal_no_movement[np.array(movement_mask == 1).repeat(resp_fs)] = np.nan
for i in range(len(enable_list)):
enable_start = enable_list[i][0] * resp_fs
enable_end = enable_list[i][1] * resp_fs
segment = resp_signal_no_movement[enable_start:enable_end]
# print(f"Normalizing segment {i+1}/{len(enable_list)}: start={enable_start}, end={enable_end}, length={len(segment)}")
segment_mean = np.nanmean(segment)
segment_std = np.nanstd(segment)
if segment_std == 0:
raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")
# Normalize together with the movement region up to the next enable interval
if i <= len(enable_list) - 2:
enable_end = enable_list[i + 1][0] * resp_fs
raw_segment = resp_signal[enable_start:enable_end]
normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std
# If the first enable interval does not start at 0, also normalize the leading part
if enable_list[0][0] > 0:
new_enable_start = 0
enable_start = enable_list[0][0] * resp_fs
enable_end = enable_list[0][1] * resp_fs
segment = resp_signal_no_movement[enable_start:enable_end]
segment_mean = np.nanmean(segment)
segment_std = np.nanstd(segment)
if segment_std == 0:
raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")
raw_segment = resp_signal[new_enable_start:enable_start]
normalized_resp_signal[new_enable_start:enable_start] = (raw_segment - segment_mean) / segment_std
return normalized_resp_signal
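The segment-wise z-score above uses `np.nanmean` / `np.nanstd` so movement samples, which were set to NaN, do not bias the statistics. A compact sketch of that core step on a toy array:

```python
import numpy as np

def zscore_segment(signal, lo, hi):
    # Statistics come only from the non-NaN part of signal[lo:hi];
    # NaN samples pass through the transform unchanged (as NaN).
    seg = signal[lo:hi]
    mean, std = np.nanmean(seg), np.nanstd(seg)
    if std == 0:
        raise ValueError("segment_std is zero!")
    return (signal[lo:hi] - mean) / std

x = np.array([1.0, 3.0, np.nan, 1.0, 3.0])  # NaN marks a movement sample
z = zscore_segment(x, 0, 5)
print(round(float(np.nanmean(z)), 6))  # 0.0: the segment is centered
print(round(float(np.nanstd(z)), 6))   # 1.0: and unit variance
```

Note that `normalize_resp_signal_by_segment` computes the statistics on the movement-free copy but applies them to the raw signal, so the movement samples are rescaled rather than dropped; the sketch above shows only the statistics side of that split.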

@@ -0,0 +1,688 @@
import utils
from utils.operation_tools import timing_decorator
import numpy as np
from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values
from signal_method.time_metrics import calc_mav_by_slide_windows
@timing_decorator()
def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=None,
std_median_multiplier=4.5, compare_intervals_sec=[30, 60],
interval_multiplier=2.5,
merge_gap_sec=10, min_duration_sec=5,
low_amplitude_periods=None):
"""
检测信号中的体动状态结合两种方法标准差比较和前后窗口幅值对比
使用反射填充处理信号边界
参数
- signal_data: numpy array输入的信号数据
- sampling_rate: int信号的采样率Hz
- window_size_sec: float分析窗口的时长默认值为 2
- stride_sec: float窗口滑动步长默认值为None等于window_size_sec无重叠
- std_median_multiplier: float标准差中位数的乘数阈值默认值为 4.5
- compare_intervals_sec: list用于比较的时间间隔列表默认为 [30, 60]
- interval_multiplier: float间隔中位数的上限乘数默认值为 2.5
- merge_gap_sec: float要合并的体动状态之间的最大间隔默认值为 10
- min_duration_sec: float要保留的体动状态的最小持续时间默认值为 5
- low_amplitude_periods: numpy array低幅值期间的掩码1表示低幅值期间默认为None
返回
- movement_mask: numpy array体动状态的掩码1表示体动0表示睡眠
"""
# Window size in samples
window_samples = int(window_size_sec * sampling_rate)
# If no stride is given, use the window size (no overlap)
if stride_sec is None:
stride_sec = window_size_sec
# Stride in samples
stride_samples = int(stride_sec * sampling_rate)
# Ensure the stride is at least 1 sample
stride_samples = max(1, stride_samples)
# Maximum padding needed, based on the largest comparison interval
max_interval_samples = int(max(compare_intervals_sec) * sampling_rate)
# Apply reflection padding equal to the largest comparison interval
# so the boundaries have enough context
pad_size = max_interval_samples
padded_signal = np.pad(signal_data, pad_size, mode='reflect')
# Number of windows after padding
num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)
# Standard deviation of each window
window_std = np.zeros(num_windows)
# Compute the per-window standard deviation
for i in range(num_windows):
start_idx = i * stride_samples
end_idx = min(start_idx + window_samples, len(padded_signal))
# Handle each window, including a possibly incomplete last window
window_data = padded_signal[start_idx:end_idx]
if len(window_data) > 0:
window_std[i] = np.std(window_data, ddof=1)
else:
window_std[i] = 0
# Window index range corresponding to the original signal
# (after padding, the original signal starts at pad_size)
orig_start_window = pad_size // stride_samples
if stride_sec == 1:
orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
else:
orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1
# Keep only the window stds that correspond to the original signal
original_window_std = window_std[orig_start_window:orig_end_window]
num_original_windows = len(original_window_std)
# Time points in seconds
time_points = np.arange(num_original_windows) * stride_sec
# Previously, windows inside low-amplitude periods were excluded from the
# global median; since 2025-04-18 all windows are used instead.
valid_std = original_window_std
# ---------------- Method 1: std-based movement detection ----------------#
# Median of all valid window stds
median_std = np.median(valid_std)
# Windows whose std exceeds a multiple of the median are marked as movement
std_movement = np.where((original_window_std > (median_std * std_median_multiplier)), 1, 0)
# ---------- Method 2: movement detection via surrounding amplitude change ----------#
amplitude_movement = np.zeros(num_original_windows, dtype=int)
# Comparison interval lengths expressed in window indices
compare_intervals_idx = [int(interval // stride_sec) for interval in compare_intervals_sec]
# Check each window
for win_idx in range(num_original_windows):
# Global index in the padded window array
global_win_idx = win_idx + orig_start_window
# Check every comparison interval
for interval_idx in compare_intervals_idx:
# End index of the comparison range (in the padded window array)
end_idx = min(global_win_idx + interval_idx, len(window_std))
# Extract the stds within this range
if global_win_idx < end_idx:
interval_std = window_std[global_win_idx:end_idx]
# Median of the interval
interval_median = np.median(interval_std)
# Upper threshold
upper_threshold = interval_median * interval_multiplier
# Mark the window as movement if it exceeds the threshold
if window_std[global_win_idx] > upper_threshold:
amplitude_movement[win_idx] = 1
break  # once flagged, no need to check the other intervals
# Combine both methods: movement if either method flags the window
movement_mask = np.logical_or(std_movement, amplitude_movement).astype(int)
raw_movement_mask = movement_mask
# Merge movement states separated by small gaps, if requested
if merge_gap_sec > 0:
movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec)
# Remove short movement states, if requested
if min_duration_sec > 0:
movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec)
# Convert raw_movement_mask and movement_mask back to seconds rather than window points
raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
# Re-check the removed movements: keep a removed movement if its region
# contains amplitudes above median_std * std_median_multiplier
removed_movement_mask = (raw_movement_mask - movement_mask) > 0
removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0]
removed_movement_end = np.where(np.diff(np.concatenate([removed_movement_mask, [0]])) == -1)[0]
for start, end in zip(removed_movement_start, removed_movement_end):
# Peak amplitude of the removed movement region
if np.nanmax(signal_data[start * sampling_rate:(end + 1) * sampling_rate]) > median_std * std_median_multiplier:
movement_mask[start:end + 1] = 1
# Raw movement start/end positions: [[start, end], [start, end], ...]
raw_movement_position_list = event_mask_2_list(raw_movement_mask)
# Merged movement start/end positions: [[start, end], [start, end], ...]
movement_position_list = event_mask_2_list(movement_mask)
return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list
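The function above relies on a helper `event_mask_2_list` to turn a 0/1 mask into `[start, end]` pairs. A minimal stand-in sketch of that conversion (hypothetical: the repo's real helper may use a different end-index convention; this one treats the end as inclusive):

```python
import numpy as np

def mask_to_intervals(mask):
    """Convert a 0/1 mask into [start, end] index pairs (end inclusive).

    Hypothetical stand-in for the repo's `event_mask_2_list`.
    """
    mask = np.asarray(mask).astype(int)
    # Rising edges mark starts, falling edges mark ends
    diff = np.diff(np.concatenate([[0], mask, [0]]))
    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0] - 1  # inclusive end index
    return [[int(s), int(e)] for s, e in zip(starts, ends)]

print(mask_to_intervals([0, 1, 1, 0, 0, 1]))  # → [[1, 2], [5, 5]]
```

Padding the mask with zeros on both sides makes `np.diff` see a rising and falling edge even for runs touching the array boundary.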
def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float,
down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec, verbose=False):
"""
基于标准差对已有体动掩码进行修正 用于大尺度的体动检测后的位置精细修正
参数
- signal_data: numpy array输入的信号数据
- sampling_rate: int信号的采样率Hz
- movement_mask: numpy array已有的体动掩码1表示体动0表示睡眠
返回
- revised_movement_mask: numpy array修正后的体动掩码
"""
window_size = sampling_rate
stride_size = sampling_rate
time_points = np.arange(len(signal_data))
compare_size = int(compare_intervals_sec // (stride_size / sampling_rate))
_, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate,
window_second=4, step_second=1,
inner_window_second=4)
# Take the median MAV over compare_size points on each side of the movement segment
for start, end in movement_list:
left_points = start - 20
right_points = end + 20
left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask)
right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask)
left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0
right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0
# If one side has no usable samples, fall back to the other side
if left_value_metrics == 0:
left_value_metrics = right_value_metrics
elif right_value_metrics == 0:
right_value_metrics = left_value_metrics
if verbose:
print(f"Revising movement from index {start} to {end}, left_metric: {left_value_metrics:.2f}, right_metric: {right_value_metrics:.2f}")
for i in range(left_points, right_points):
if i < 0 or i >= len(mav):
continue
if i < start:
value_metrics = left_value_metrics
elif i > end:
value_metrics = right_value_metrics
else:
value_metrics = (left_value_metrics + right_value_metrics) / 2
if mav[i] > (value_metrics * up_interval_multiplier) and movement_mask[i] == 0:
movement_mask[i] = 1
if verbose:
print(f"Normal revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * up_interval_multiplier:.2f}")
elif mav[i] < (value_metrics * down_interval_multiplier) and movement_mask[i] == 1:
movement_mask[i] = 0
if verbose:
print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * down_interval_multiplier:.2f}")
else:
if verbose:
print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_metrics * up_interval_multiplier:.2f}, down_threshold: {value_metrics * down_interval_multiplier:.2f}")
# Optionally merge movement segments separated by short gaps
if merge_gap_sec > 0:
movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec)
# Optionally remove short movement segments
if min_duration_sec > 0:
movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec)
movement_list = event_mask_2_list(movement_mask)
return movement_mask, movement_list
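`movement_revise` calls a helper `collect_values` to gather up to `limit` MAV samples walking left or right from an index while skipping movement-masked seconds. A hedged sketch of what such a helper might do (the actual signature and semantics in `utils` may differ):

```python
import numpy as np

def collect_values(arr, index, step, limit, mask):
    """Walk from `index` in direction `step` (+1 or -1), collecting up to
    `limit` values from `arr` while skipping positions flagged in `mask`.

    Hypothetical reimplementation for illustration only.
    """
    values = []
    i = index
    while 0 <= i < len(arr) and len(values) < limit:
        if mask is None or mask[i] == 0:  # skip movement-flagged seconds
            values.append(arr[i])
        i += step
    return np.asarray(values)

mav = np.arange(10, dtype=float)
mask = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
print(collect_values(mav, 4, -1, 3, mask))  # walks left, skipping masked index 2
```

Walking from index 4 leftwards collects indices 4, 3, then skips the masked index 2 and takes index 1.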
@timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):
"""
检测信号中的低幅值状态通过计算RMS值判断信号强度是否低于设定阈值
参数
- signal_data: numpy array输入的信号数据
- sampling_rate: int信号的采样率Hz
- window_size_sec: float分析窗口的时长默认值为 1
- stride_sec: float窗口滑动步长默认值为None等于window_size_sec无重叠
- amplitude_threshold: floatRMS阈值低于此值表示低幅值状态默认值为 50
- merge_gap_sec: float要合并的状态之间的最大间隔默认值为 10
- min_duration_sec: float要保留的状态的最小持续时间默认值为 5
返回
- low_amplitude_mask: numpy array低幅值状态的掩码1表示低幅值0表示正常幅值
"""
# Window size in samples
window_samples = int(window_size_sec * sampling_rate)
# If no stride is given, use the window size (no overlap)
if stride_sec is None:
stride_sec = window_size_sec
# Stride in samples
stride_samples = int(stride_sec * sampling_rate)
# Ensure the stride is at least one second of samples
stride_samples = max(sampling_rate, stride_samples)
# Pad the signal boundaries by reflection
pad_size = window_samples // 2
padded_signal = np.pad(signal_data, pad_size, mode='reflect')
# Number of windows over the padded signal
num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)
# RMS value per window
rms_values = np.zeros(num_windows)
# Compute the RMS of each window
for i in range(num_windows):
start_idx = i * stride_samples
end_idx = min(start_idx + window_samples, len(signal_data))
# Handle every window, including a possibly incomplete last one
window_data = signal_data[start_idx:end_idx]
if len(window_data) > 0:
rms_values[i] = np.sqrt(np.mean(window_data ** 2))
else:
rms_values[i] = 0
# Initial mask: windows with RMS at or below the threshold are 1 (low amplitude), others 0
low_amplitude_mask = np.where(rms_values <= amplitude_threshold, 1, 0)
# Window index range corresponding to the original (unpadded) signal
orig_start_window = pad_size // stride_samples
if stride_sec == 1:
orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
else:
orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1
# Keep only the windows that correspond to the original signal
low_amplitude_mask = low_amplitude_mask[orig_start_window:orig_end_window]
num_original_windows = len(low_amplitude_mask)
# Convert to a time-axis state sequence:
# time point (in seconds) of each window
time_points = np.arange(num_original_windows) * stride_sec
# Optionally merge segments separated by short gaps
if merge_gap_sec > 0:
low_amplitude_mask = merge_short_gaps(low_amplitude_mask, time_points, merge_gap_sec)
# Optionally remove short segments
if min_duration_sec > 0:
low_amplitude_mask = remove_short_durations(low_amplitude_mask, time_points, min_duration_sec)
low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
# Low-amplitude start/end positions: [[start, end], [start, end], ...]
low_amplitude_position_list = event_mask_2_list(low_amplitude_mask)
return low_amplitude_mask, low_amplitude_position_list
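When the stride equals the window size, the per-window RMS loop above reduces to a simple reshape. A simplified sketch of that core step (no reflection padding, no gap merging; not the full `detect_low_amplitude_signal`):

```python
import numpy as np

def windowed_rms(signal, fs, window_sec=1):
    """Per-window RMS over non-overlapping windows of `window_sec` seconds.

    Simplified sketch of the RMS step in `detect_low_amplitude_signal`.
    """
    win = int(window_sec * fs)
    n = len(signal) // win  # drop the trailing partial window
    frames = np.asarray(signal[:n * win], dtype=float).reshape(n, win)
    return np.sqrt(np.mean(frames ** 2, axis=1))

fs = 4
sig = np.array([3.0, -3.0, 3.0, -3.0, 0.0, 0.0, 0.0, 0.0])
print(windowed_rms(sig, fs))  # RMS is 3.0 for the first second, 0.0 for the second
```

Thresholding the returned array (`rms <= amplitude_threshold`) then yields the initial low-amplitude mask.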
def get_typical_segment_for_continues_signal(signal_data, sampling_rate=100, window_size=30, step_size=1):
"""
Extract ten typical segments (not yet implemented).
:param signal_data: signal data
:param sampling_rate: sampling rate
:param window_size: window size
:param step_size: stride
:return: list of typical segments
"""
pass
# Sleep-position recognition based on movement positions and amplitude.
# The idea: split the night into valid segments using the movement mask, compute an amplitude
# metric per segment, and flag a position change when adjacent segments differ significantly.
# Each segment's amplitude metric is computed over roughly 10 s of data.
# Only adjacent segments are compared: a significant difference means a position change occurred
# within the ~30 s around the separating movement; shorter segments are compared at their actual length.
@timing_decorator()
def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rate=100, window_size_sec=30,
interval_to_movement=10):
mav_calc_window_sec = 2  # window size for MAV computation, in seconds
# Start/end positions of valid (non-movement) segments
valid_mask = 1 - movement_mask
valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0]
valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0]
movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0]
movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0]
segment_left_average_amplitude = []
segment_right_average_amplitude = []
segment_left_average_energy = []
segment_right_average_energy = []
for start, end in zip(valid_starts, valid_ends):
start *= sampling_rate
end *= sampling_rate
# Skip overly short segments
if end - start <= sampling_rate:  # ignore segments shorter than 1 s
continue
# Pick the left/right comparison windows within the current segment
elif end - start < (window_size_sec * interval_to_movement + 1) * sampling_rate:
left_start = start
left_end = min(end, left_start + window_size_sec * sampling_rate)
right_start = max(start, end - window_size_sec * sampling_rate)
right_end = end
else:
left_start = start + interval_to_movement * sampling_rate
left_end = left_start + window_size_sec * sampling_rate
right_start = end - interval_to_movement * sampling_rate - window_size_sec * sampling_rate
right_end = end
# Trim so each window length is a multiple of the MAV sub-window (mav_calc_window_sec * sampling_rate samples)
if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0:
left_end = left_start + ((left_end - left_start) // (mav_calc_window_sec * sampling_rate)) * (
mav_calc_window_sec * sampling_rate)
if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0:
right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * (
mav_calc_window_sec * sampling_rate)
# Amplitude metric (mean peak-to-peak) of each comparison window
left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate),
axis=0)) - np.mean(
np.min(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
right_mav = np.mean(
np.max(signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate),
axis=0)) - np.mean(
np.min(signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
segment_left_average_amplitude.append(left_mav)
segment_right_average_amplitude.append(right_mav)
left_energy = np.sum(np.abs(signal_data[left_start:left_end] ** 2))
right_energy = np.sum(np.abs(signal_data[right_start:right_end] ** 2))
segment_left_average_energy.append(left_energy)
segment_right_average_energy.append(right_energy)
position_changes = []
position_change_times = []
# Decide whether a significant change exists (thresholds can be tuned)
threshold_amplitude = 0.1  # amplitude-change threshold
threshold_energy = 0.1  # energy-change threshold
for i in range(1, len(segment_left_average_amplitude)):
# Relative change of the amplitude metric
left_amplitude_change = abs(segment_left_average_amplitude[i] - segment_left_average_amplitude[i - 1]) / max(
segment_left_average_amplitude[i - 1], 1e-6)
right_amplitude_change = abs(segment_right_average_amplitude[i] - segment_right_average_amplitude[i - 1]) / max(
segment_right_average_amplitude[i - 1], 1e-6)
# Relative change of the energy metric
left_energy_change = abs(segment_left_average_energy[i] - segment_left_average_energy[i - 1]) / max(
segment_left_average_energy[i - 1], 1e-6)
right_energy_change = abs(segment_right_average_energy[i] - segment_right_average_energy[i - 1]) / max(
segment_right_average_energy[i - 1], 1e-6)
# A position change is flagged when either side exceeds both the amplitude and the energy threshold
left_significant_change = (left_amplitude_change > threshold_amplitude) and (
left_energy_change > threshold_energy)
right_significant_change = (right_amplitude_change > threshold_amplitude) and (
right_energy_change > threshold_energy)
if left_significant_change or right_significant_change:
# Record when the position change happened, expressed as the start/end of the separating movement
position_changes.append(1)
position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
else:
position_changes.append(0)  # 0 means no position change
return position_changes, position_change_times
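All three recognizers above decide "significant change" via the same guarded relative change rate. Isolated, with the `1e-6` epsilon guarding against division by zero:

```python
def change_rate(curr, prev, eps=1e-6):
    """Relative change |curr - prev| / max(prev, eps), as used for the
    amplitude and energy comparisons in the position recognizers."""
    return abs(curr - prev) / max(prev, eps)

print(change_rate(1.2, 1.0))  # ≈ 0.2, i.e. a 20% increase
```

With `prev == 0`, the epsilon makes the rate blow up rather than raise, so any nonzero `curr` is treated as a significant change.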
def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100):
"""
:param signal_data:
:param movement_mask: mask的采样率为1Hz
:param sampling_rate:
:param window_size_sec:
:return:
"""
mav_calc_window_sec = 2  # window size for MAV computation, in seconds
# Start/end positions of valid (non-movement) segments
valid_mask = 1 - movement_mask
valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0]
valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0]
movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0]
movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0]
segment_average_amplitude = []
segment_average_energy = []
signal_data_no_movement = signal_data.copy()
for start, end in zip(movement_start, movement_end):
signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan
for start, end in zip(valid_starts, valid_ends):
start *= sampling_rate
end *= sampling_rate
# Skip overly short segments
if end - start <= sampling_rate:  # ignore segments shorter than 1 s
continue
# Trim so the segment length is a multiple of the MAV sub-window (mav_calc_window_sec * sampling_rate samples)
if (end - start) % (mav_calc_window_sec * sampling_rate) != 0:
end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * (
mav_calc_window_sec * sampling_rate)
# Amplitude metric (mean peak-to-peak) of the segment
mav = np.nanmean(
np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate),
axis=0)) - np.nanmean(
np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
segment_average_amplitude.append(mav)
energy = np.nansum(np.abs(signal_data_no_movement[start:end] ** 2))
segment_average_energy.append(energy)
position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
position_change_times = []
# Decide whether a significant change exists (thresholds can be tuned)
threshold_amplitude = 0.1  # amplitude-change threshold
threshold_energy = 0.1  # energy-change threshold
for i in range(1, len(segment_average_amplitude)):
# Relative change of the amplitude metric
amplitude_change = abs(segment_average_amplitude[i] - segment_average_amplitude[i - 1]) / max(
segment_average_amplitude[i - 1], 1e-6)
# Relative change of the energy metric
energy_change = abs(segment_average_energy[i] - segment_average_energy[i - 1]) / max(
segment_average_energy[i - 1], 1e-6)
significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy)
if significant_change:
# Record when the position change happened, expressed as the start/end of the separating movement
position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1
position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
return position_changes, position_change_times
def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate, mav_calc_window_sec,
threshold_amplitude, threshold_energy, verbose=False):
"""
:param threshold_energy:
:param threshold_amplitude:
:param mav_calc_window_sec:
:param movement_list:
:param signal_data:
:param movement_mask: mask的采样率为1Hz
:param sampling_rate:
:param window_size_sec:
:return:
"""
# Start/end positions of valid (non-movement) segments
valid_list = utils.event_mask_2_list(movement_mask, event_true=False)
signal_data_no_movement = signal_data.copy()
for start, end in movement_list:
signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan
if len(valid_list) < 2:
return [], []
def clac_mav(data_segment):
# Ensure the segment length is a multiple of mav_calc_window_sec windows
if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0:
data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))]
mav = np.nanstd(
np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate),
axis=0) -
np.nanmin(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
return mav
def clac_energy(data_segment):
energy = np.nansum(np.abs(data_segment ** 2)) // (len(data_segment) // sampling_rate)
return energy
def calc_mav_by_quantiles(data_segment):
# Compute all per-window peak-to-peak (MAV) values first
if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0:
data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))]
mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=1) - np.nanmin(
data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=1)
# Keep only values between the 20th and 80th percentiles
q20 = np.nanpercentile(mav_values, 20)
q80 = np.nanpercentile(mav_values, 80)
mav_values = mav_values[(mav_values >= q20) & (mav_values <= q80)]
mav = np.nanmean(mav_values)
return mav
position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
position_change_list = []
pre_valid_start = valid_list[0][0] * sampling_rate
pre_valid_end = valid_list[0][1] * sampling_rate
if verbose:
print(f"Total movement segments to analyze: {len(movement_list)}")
print(f"Total valid segments available: {len(valid_list)}")
for i in range(len(movement_list)):
if verbose:
print(f"Analyzing movement segment {i + 1}/{len(movement_list)}")
if i + 1 >= len(valid_list):
if verbose:
print("No more valid segments to compare. Ending analysis.")
break
next_valid_start = valid_list[i + 1][0] * sampling_rate
next_valid_end = valid_list[i + 1][1] * sampling_rate
movement_start = movement_list[i][0]
movement_end = movement_list[i][1]
# Skip overly short movement segments
if movement_end - movement_start <= 1:  # ignore movements shorter than 1 s
if verbose:
print(
f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}")
# keep pre_valid_start; extend pre_valid_end over the next valid segment
pre_valid_end = next_valid_end
continue
# Quantile-trimmed amplitude of the valid segments before and after the movement
left_mav = calc_mav_by_quantiles(signal_data_no_movement[pre_valid_start:pre_valid_end])
right_mav = calc_mav_by_quantiles(signal_data_no_movement[next_valid_start:next_valid_end])
# Relative change of the amplitude metric (the energy criterion is currently disabled)
amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6)
significant_change = (amplitude_change > threshold_amplitude)
if significant_change:
if verbose:
print(f"Significant position change detected between {movement_start}s and {movement_end}s: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right:{next_valid_start}to{next_valid_end} right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}")
# Record when the position change happened, expressed as the start/end of the separating movement
position_changes[movement_start:movement_end] = 1
position_change_list.append(movement_list[i])
# Move on: the next valid segment becomes the new left segment
pre_valid_start = next_valid_start
pre_valid_end = next_valid_end
else:
if verbose:
print(f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}")
# Keep pre_valid_start and only extend pre_valid_end
pre_valid_end = next_valid_end
return position_changes, position_change_list
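The quantile trimming inside `calc_mav_by_quantiles` is the part that makes v3 robust to isolated spikes. A standalone sketch of that step (mirrors the inner helper, without the NaN handling used for movement gaps):

```python
import numpy as np

def trimmed_mav(segment, fs, win_sec=2, lo=20, hi=80):
    """Mean peak-to-peak amplitude per `win_sec` window, keeping only
    values between the `lo` and `hi` percentiles.

    Sketch of `calc_mav_by_quantiles` without NaN-aware operations.
    """
    win = int(win_sec * fs)
    n = len(segment) // win
    frames = np.asarray(segment[:n * win], dtype=float).reshape(n, win)
    ptp = frames.max(axis=1) - frames.min(axis=1)  # per-window peak-to-peak
    q_lo, q_hi = np.percentile(ptp, [lo, hi])
    kept = ptp[(ptp >= q_lo) & (ptp <= q_hi)]  # trim outlier windows
    return float(kept.mean())

fs = 2
# Four quiet windows (peak-to-peak 1) plus one spiky window (peak-to-peak 100)
seg = np.concatenate([np.tile([0.0, 1.0], 8), [0.0, 100.0, 0.0, 100.0]])
print(trimmed_mav(seg, fs))  # → 1.0, the outlier window is trimmed
```

Without the trimming, the single spiky window would pull the mean peak-to-peak up to roughly 20 and could mask a genuine posture change.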


@ -0,0 +1,62 @@
import xml.etree.ElementTree as ET
ANNOTATION_MAP = {
"Wake|0": 0,
"Stage 1 sleep|1": 1,
"Stage 2 sleep|2": 2,
"Stage 3 sleep|3": 3,
"Stage 4 sleep|4": 4,
"REM sleep|5": 5,
"Unscored|9": 9,
"Movement|6": 6
}
SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']
def parse_sleep_annotations(annotation_path):
"""Parse sleep-stage annotations from an NSRR-style XML file."""
try:
tree = ET.parse(annotation_path)
root = tree.getroot()
events = []
for scored_event in root.findall('.//ScoredEvent'):
event_type = scored_event.find('EventType').text
if event_type != "Stages|Stages":
continue
description = scored_event.find('EventConcept').text
start = float(scored_event.find('Start').text)
duration = float(scored_event.find('Duration').text)
if description not in ANNOTATION_MAP:
continue
events.append({
'onset': start,
'duration': duration,
'description': description,
'stage': ANNOTATION_MAP[description]
})
return events
except Exception:
return None
def extract_osa_events(annotation_path):
"""Extract sleep apnea events (Central apnea / Hypopnea / Obstructive apnea)."""
try:
tree = ET.parse(annotation_path)
root = tree.getroot()
events = []
for scored_event in root.findall('.//ScoredEvent'):
event_concept = scored_event.find('EventConcept').text
event_type = event_concept.split('|')[0].strip()
if event_type in SA_EVENTS:
start = float(scored_event.find('Start').text)
duration = float(scored_event.find('Duration').text)
if duration >= 10:
events.append({
'start': start,
'duration': duration,
'end': start + duration,
'type': event_type
})
return events
except Exception:
return []
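`extract_osa_events` expects NSRR-style XML, where each `ScoredEvent` carries `EventConcept`, `Start`, and `Duration` children. The same parsing logic can be exercised on a minimal synthetic annotation (hypothetical data, not from the dataset):

```python
import xml.etree.ElementTree as ET

SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']

xml_text = """<PSGAnnotation><ScoredEvents>
<ScoredEvent><EventConcept>Obstructive apnea|Obstructive Apnea</EventConcept>
<Start>120.0</Start><Duration>15.0</Duration></ScoredEvent>
<ScoredEvent><EventConcept>Hypopnea|Hypopnea</EventConcept>
<Start>300.0</Start><Duration>8.0</Duration></ScoredEvent>
</ScoredEvents></PSGAnnotation>"""

root = ET.fromstring(xml_text)
events = []
for se in root.findall('.//ScoredEvent'):
    # Event type is the part of EventConcept before the '|'
    etype = se.find('EventConcept').text.split('|')[0].strip()
    if etype in SA_EVENTS:
        start = float(se.find('Start').text)
        duration = float(se.find('Duration').text)
        if duration >= 10:  # same >= 10 s rule as extract_osa_events
            events.append({'start': start, 'duration': duration,
                           'end': start + duration, 'type': etype})

print(events)  # only the 15 s obstructive apnea survives the duration filter
```

The 8 s hypopnea is dropped by the 10 s minimum-duration rule, matching the AASM convention the function encodes.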


@ -0,0 +1,107 @@
import numpy as np
from scipy.interpolate import interp1d
import utils
def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
# Filtering: start with a 50 Hz notch filter
if verbose:
print("Applying 50Hz notch filter...")
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
resp_fs = conf["resp"]["downsample_fs_1"]
resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
low_cut=conf["resp_filter"]["low_cut"],
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
sample_rate=resp_fs)
bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
low_cut=conf["bcg_filter"]["low_cut"],
high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
sample_rate=signal_fs)
return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs
def psg_effort_filter(conf, effort_data_raw, effort_fs):
# Band-limit the effort signal
effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"],
low_cut=conf["effort_filter"]["low_cut"],
high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
sample_rate=effort_fs)
# Moving average
effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=conf["average_filter"]["window_size_sec"])
return effort_data_raw, effort_data_2, effort_fs
def rpeak2hr(rpeak_indices, signal_length, ecg_fs):
hr_signal = np.zeros(signal_length)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
if rri == 0:
continue
hr = 60 * ecg_fs / rri  # heart rate in bpm
# Clamp to a physiologic range
if hr > 120:
hr = 120
elif hr < 30:
hr = 30
hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = hr
# Fill the heart rate after the last R-peak
if len(rpeak_indices) > 1:
hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
return hr_signal
def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs):
rri_signal = np.zeros(signal_length)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri
# Fill the RRI after the last R-peak
if len(rpeak_indices) > 1:
rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]]
# Zero out physiologically implausible intervals (outside 0.3-2 s)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs:
rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0
return rri_signal
def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs):
r_peak_time = np.asarray(rpeak_indices) / ecg_fs
rri = np.diff(r_peak_time)
t_rri = r_peak_time[1:]
# Keep only physiologically plausible intervals (0.3-2 s)
mask = (rri > 0.3) & (rri < 2.0)
rri_clean = rri[mask]
t_rri_clean = t_rri[mask]
if len(rri_clean) < 2:
# Not enough clean intervals to interpolate
return np.array([]), rri_fs
t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1/rri_fs)
f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate")
rri_uniform = f(t_uniform)
return rri_uniform, rri_fs
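The resampling in `rpeak2rri_interpolation` can be checked on synthetic, evenly spaced R-peaks (one beat per second at `ecg_fs=100`, resampled to `rri_fs=4`), where the uniform RRI series must be constant:

```python
import numpy as np
from scipy.interpolate import interp1d

def rri_resample(rpeak_indices, ecg_fs, rri_fs):
    """Same steps as `rpeak2rri_interpolation`: RRI in seconds, physiologic
    band filter (0.3-2.0 s), then linear resampling onto a uniform grid."""
    t = np.asarray(rpeak_indices) / ecg_fs      # R-peak times in seconds
    rri = np.diff(t)
    t_rri = t[1:]                               # each RRI timestamped at its end
    keep = (rri > 0.3) & (rri < 2.0)
    f = interp1d(t_rri[keep], rri[keep], kind='linear', fill_value="extrapolate")
    grid = np.arange(t_rri[keep][0], t_rri[keep][-1], 1 / rri_fs)
    return f(grid)

peaks = np.arange(0, 1000, 100)  # one beat per second at ecg_fs=100
rri_uniform = rri_resample(peaks, ecg_fs=100, rri_fs=4)
print(rri_uniform[:4])  # constant 1.0 s intervals
```

Timestamping each interval at its ending R-peak (`t[1:]`) is the convention the original function uses; other pipelines sometimes use the interval midpoint instead.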


@ -0,0 +1,45 @@
from utils.operation_tools import calculate_by_slide_windows
from utils.operation_tools import timing_decorator
import numpy as np
@timing_decorator()
def calc_mav_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2):
if movement_mask is not None:
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask length is inconsistent with signal_data length, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
processed_mask = movement_mask.copy()
else:
processed_mask = None
def mav_func(x):
return np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2
mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
window_second=window_second, step_second=step_second)
return mav_nan, mav
@timing_decorator()
def calc_wavefactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask length is inconsistent with signal_data length, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
assert len(movement_mask) == len(low_amp_mask), f"movement_mask and low_amp_mask lengths differ, {len(movement_mask)} != {len(low_amp_mask)}"
processed_mask = movement_mask.copy()
processed_mask = processed_mask.repeat(sampling_rate)
# Wave factor: RMS divided by mean absolute value, over 2 s windows
wavefactor_nan, wavefactor = calculate_by_slide_windows(lambda x: np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x)),
signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1)
return wavefactor_nan, wavefactor
@timing_decorator()
def calc_peakfactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask length is inconsistent with signal_data length, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
assert len(movement_mask) == len(low_amp_mask), f"movement_mask and low_amp_mask lengths differ, {len(movement_mask)} != {len(low_amp_mask)}"
processed_mask = movement_mask.copy()
processed_mask = processed_mask.repeat(sampling_rate)
# Peak factor: max absolute value divided by RMS, over 2 s windows
peakfactor_nan, peakfactor = calculate_by_slide_windows(lambda x: np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2)),
signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1)
return peakfactor_nan, peakfactor
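All three metrics above delegate to `calculate_by_slide_windows`, imported from `utils.operation_tools` but not shown in this diff. A plausible minimal version, guessing at its contract (apply `func` per sliding window; return a NaN-masked series alongside the raw series; the real helper may pad edges or align masks differently):

```python
import numpy as np

def calculate_by_slide_windows(func, signal, mask, sampling_rate=100,
                               window_second=2, step_second=1):
    """Apply `func` to each sliding window; return (values_with_nan, values).

    Hypothetical reimplementation for illustration only.
    """
    win = window_second * sampling_rate
    step = step_second * sampling_rate
    n = max(0, (len(signal) - win) // step + 1)
    values = np.array([func(signal[i * step:i * step + win]) for i in range(n)])
    values_nan = values.copy()
    if mask is not None:
        for i in range(n):
            # NaN out windows that overlap masked (movement) samples
            if np.any(mask[i * step:i * step + win]):
                values_nan[i] = np.nan
    return values_nan, values

sig = np.arange(0, 8, dtype=float)
nan_vals, vals = calculate_by_slide_windows(np.max, sig, None,
                                            sampling_rate=2,
                                            window_second=2, step_second=1)
print(vals)  # per-window maxima: 3, 5, 7
```

With a 4-sample window sliding by 2 samples over 8 samples, three windows are produced; passing a mask would additionally blank any window touching a flagged sample.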

utils/HYS_FileReader.py Normal file

@ -0,0 +1,354 @@
from pathlib import Path
from typing import Union
import utils
from .event_map import N2Chn
import numpy as np
import pandas as pd
from .operation_tools import event_mask_2_list
# Try to import Polars (optional, faster CSV reading)
try:
import polars as pl
HAS_POLARS = True
except ImportError:
HAS_POLARS = False
def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
"""
Read a txt file and return the first column as a numpy array.
Args:
path: path to the txt file; the sampling rate is parsed from the file stem.
dtype: dtype for the returned array.
verbose: print summary information.
is_peak: if True, the file holds peak indices, so no per-second truncation is applied.
Returns:
(signal_data_raw, signal_length, signal_fs, signal_second)
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
if HAS_POLARS:
df = pl.read_csv(path, has_header=False, infer_schema_length=0)
signal_data_raw = df[:, 0].to_numpy().astype(dtype)
else:
df = pd.read_csv(path, header=None, dtype=dtype)
signal_data_raw = df.iloc[:, 0].to_numpy()
signal_original_length = len(signal_data_raw)
signal_fs = int(path.stem.split("_")[-1])
if is_peak:
signal_second = None
signal_length = None
else:
signal_second = signal_original_length // signal_fs
# Truncate to a whole number of seconds
signal_data_raw = signal_data_raw[:signal_second * signal_fs]
signal_length = len(signal_data_raw)
if verbose:
print(f"Signal file read from {path}")
print(f"signal_fs: {signal_fs}")
print(f"signal_original_length: {signal_original_length}")
print(f"signal_after_cut_off_length: {signal_length}")
print(f"signal_second: {signal_second}")
return signal_data_raw, signal_length, signal_fs, signal_second
def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
"""
Read a CSV file and return it as a pandas DataFrame.
Args:
path (str | Path): Path to the CSV file.
verbose (bool): print summary statistics.
Returns:
pd.DataFrame: The content of the CSV file as a pandas DataFrame.
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
# Read with pandas; the file contains Chinese text, so the encoding must be specified
df = pd.read_csv(path, encoding="gbk")
if verbose:
print(f"Label file read from {path}, number of rows: {len(df)}")
# Labeling conventions:
# isLabeled = 1 means the event has been reviewed
# Rows with a non-null "Event type" come from the PSG export
# Rows with a null "Event type" were labeled manually
# score = 1: clear event; score = 2: disturbed event; score = 3: non-significant event to delete
# The confirmed event type is stored in correct_EventsType
# Statistics are printed per type (total, Hypopnea, Central, Obstructive, Mixed)
# in the format: total / from PSG / manual / deleted / unlabeled
# Columns:
# Index Event type Stage Time Epoch Date Duration HR bef. HR extr. HR delta O2 bef. O2 min. O2 delta Body Position Validation Start End score remark correct_Start correct_End correct_EventsType isLabeled
# Event type values:
# Hypopnea
# Central apnea
# Obstructive apnea
# Mixed apnea
num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3))
num_psg_events = np.sum(df["Event type"].notna())
num_manual_events = np.sum(df["Event type"].isna())
num_deleted = np.sum(df["score"] == 3)
# Per-type event counts
num_unlabeled = np.sum(df["isLabeled"] == -1)
num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
num_psg_csa = np.sum(df["Event type"] == "Central apnea")
num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")
num_hyp = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] != 3))
num_csa = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] != 3))
num_osa = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] != 3))
num_msa = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] != 3))
num_manual_hyp = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Hypopnea"))
num_manual_csa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Central apnea"))
num_manual_osa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Obstructive apnea"))
num_manual_msa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Mixed apnea"))
num_deleted_hyp = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Hypopnea"))
num_deleted_csa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Central apnea"))
num_deleted_osa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Obstructive apnea"))
num_deleted_msa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Mixed apnea"))
num_unlabeled_hyp = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Hypopnea"))
num_unlabeled_csa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Central apnea"))
num_unlabeled_osa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Obstructive apnea"))
num_unlabeled_msa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Mixed apnea"))
num_hyp_1_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 1))
num_csa_1_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 1))
num_osa_1_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 1))
num_msa_1_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 1))
num_hyp_2_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 2))
num_csa_2_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 2))
num_osa_2_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 2))
num_msa_2_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 2))
num_hyp_3_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 3))
num_csa_3_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 3))
num_osa_3_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 3))
num_msa_3_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 3))
num_1_score = np.sum(df["score"] == 1)
num_2_score = np.sum(df["score"] == 2)
num_3_score = np.sum(df["score"] == 3)
if verbose:
print("Event Statistics:")
# Formatted output: Total / From PSG / Manual / Deleted / Unlabeled, fixed column width
print(f"Type {'Total':^8s} / {'From PSG':^8s} / {'Manual':^8s} / {'Deleted':^8s} / {'Unlabeled':^8s}")
print(
f"Hyp: {num_hyp:^8d} / {num_psg_hyp:^8d} / {num_manual_hyp:^8d} / {num_deleted_hyp:^8d} / {num_unlabeled_hyp:^8d}")
print(
f"CSA: {num_csa:^8d} / {num_psg_csa:^8d} / {num_manual_csa:^8d} / {num_deleted_csa:^8d} / {num_unlabeled_csa:^8d}")
print(
f"OSA: {num_osa:^8d} / {num_psg_osa:^8d} / {num_manual_osa:^8d} / {num_deleted_osa:^8d} / {num_unlabeled_osa:^8d}")
print(
f"MSA: {num_msa:^8d} / {num_psg_msa:^8d} / {num_manual_msa:^8d} / {num_deleted_msa:^8d} / {num_unlabeled_msa:^8d}")
print(
f"All: {num_total:^8d} / {num_psg_events:^8d} / {num_manual_events:^8d} / {num_deleted:^8d} / {num_unlabeled:^8d}")
print("Score Statistics (only for non-deleted events and manual created events):")
print(f"Type {'Total':^8s} / {'Score 1':^8s} / {'Score 2':^8s} / {'Score 3':^8s}")
print(f"Hyp: {num_hyp:^8d} / {num_hyp_1_score:^8d} / {num_hyp_2_score:^8d} / {num_hyp_3_score:^8d}")
print(f"CSA: {num_csa:^8d} / {num_csa_1_score:^8d} / {num_csa_2_score:^8d} / {num_csa_3_score:^8d}")
print(f"OSA: {num_osa:^8d} / {num_osa_1_score:^8d} / {num_osa_2_score:^8d} / {num_osa_3_score:^8d}")
print(f"MSA: {num_msa:^8d} / {num_msa_1_score:^8d} / {num_msa_2_score:^8d} / {num_msa_3_score:^8d}")
print(f"All: {num_total:^8d} / {num_1_score:^8d} / {num_2_score:^8d} / {num_3_score:^8d}")
df["Start"] = df["Start"].astype(int)
df["End"] = df["End"].astype(int)
return df
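The masking-and-counting pattern used above can be checked on a toy frame. This is a sketch: the column names follow the real label file, but the three rows are invented.

```python
import numpy as np
import pandas as pd

# Toy label frame: one PSG-exported OSA, one manual Hypopnea, one deleted event
df = pd.DataFrame({
    "Event type": ["Obstructive apnea", np.nan, "Hypopnea"],
    "correct_EventsType": ["Obstructive apnea", "Hypopnea", "Hypopnea"],
    "score": [1, 2, 3],
    "isLabeled": [1, 1, 1],
})

# Same boolean-mask counting as in read_label_csv
num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3))
num_psg_events = np.sum(df["Event type"].notna())
num_manual_events = np.sum(df["Event type"].isna())
num_hyp = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] != 3))
print(num_total, num_psg_events, num_manual_events, num_hyp)  # 2 2 1 1
```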
def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame:
"""
Read a CSV file and return it as a pandas DataFrame.
Args:
path (str | Path): Path to the CSV file.
verbose (bool):
Returns:
pd.DataFrame: The content of the CSV file as a pandas DataFrame.
:param path:
:param verbose:
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
# Read directly with pandas; the file contains Chinese text, so specify the encoding
df = pd.read_csv(path, encoding="gbk")
if verbose:
print(f"Label file read from {path}, number of rows: {len(df)}")
num_psg_events = np.sum(df["Event type"].notna())
# Per-type event counts
num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
num_psg_csa = np.sum(df["Event type"] == "Central apnea")
num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")
if verbose:
print("Event Statistics:")
# Formatted output: one total per event type, fixed column width
print(f"Type {'Total':^8s}")
print(f"Hyp: {num_psg_hyp:^8d}")
print(f"CSA: {num_psg_csa:^8d}")
print(f"OSA: {num_psg_osa:^8d}")
print(f"MSA: {num_psg_msa:^8d}")
print(f"All: {num_psg_events:^8d}")
df["Start"] = df["Start"].astype(int)
df["End"] = df["End"].astype(int)
return df
def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
"""
Read an Excel file and return it as a pandas DataFrame.
Args:
path (str | Path): Path to the Excel file.
Returns:
pd.DataFrame: The content of the Excel file as a pandas DataFrame.
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
# Read directly with pandas
df = pd.read_excel(path)
df["id"] = df["id"].astype(int)
df["start"] = df["start"].astype(int)
df["end"] = df["end"].astype(int)
return df
def read_mask_execl(path: Union[str, Path]):
"""
Read an Excel file and return the mask as a numpy array.
Args:
path (str | Path): Path to the Excel file.
Returns:
np.ndarray: The mask as a numpy array.
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
df = pd.read_csv(path)
event_mask = df.to_dict(orient="list")
for key in event_mask:
event_mask[key] = np.array(event_mask[key])
event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
"BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
"EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
"DisableSegment": event_mask_2_list(event_mask["Disable_Label"])}
return event_mask, event_list
def read_psg_mask_excel(path: Union[str, Path]):
df = pd.read_csv(path)
event_mask = df.to_dict(orient="list")
for key in event_mask:
event_mask[key] = np.array(event_mask[key])
return event_mask
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
"""
读取PSG文件中特定通道的数据
参数:
path_str (Union[str, Path]): 存放PSG文件的文件夹路径
channel_name (str): 需要读取的通道名称
返回:
np.ndarray: 指定通道的数据数组
"""
path = Path(path_str)
if not path.exists():
raise FileNotFoundError(f"PSG Dir not found: {path}")
if not path.is_dir():
raise NotADirectoryError(f"PSG Dir not found: {path}")
channel_data = {}
# Check that a file exists for each requested channel
for ch_id in channel_number:
ch_name = N2Chn[ch_id]
ch_path = list(path.glob(f"{ch_name}*.txt"))
if not any(ch_path):
raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")
if len(ch_path) > 1:
print(f"Warning!!! PSG Channel file more than one: {ch_path}")
if ch_id == 8:
# special case: sleep stages are read as strings, then mapped to integers
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
# convert to an integer array
for stage_str, stage_number in utils.Stage2N.items():
np.place(ch_signal, ch_signal == stage_str, stage_number)
ch_signal = ch_signal.astype(int)
elif ch_id == 1:
# special case: R-peak positions are read as integers
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
else:
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)
channel_data[ch_name] = {
"name": ch_name,
"path": ch_path[0],
"data": ch_signal,
"length": ch_length,
"fs": ch_fs,
"second": ch_second
}
return channel_data
def read_psg_label(path: Union[str, Path], verbose=True):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
# Read directly with pandas; the file contains Chinese text, so specify the encoding
df = pd.read_csv(path, encoding="gbk")
if verbose:
print(f"Label file read from {path}, number of rows: {len(df)}")
# Drop rows whose Event type is empty
df = df.dropna(subset=["Event type"], how='all').reset_index(drop=True)
return df

utils/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label, read_raw_psg_label, read_psg_mask_excel
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
from .operation_tools import merge_short_gaps, remove_short_durations
from .operation_tools import collect_values
from .operation_tools import save_process_label
from .operation_tools import none_to_nan_mask
from .operation_tools import get_wake_mask
from .operation_tools import fill_spo2_anomaly
from .split_method import resp_split
from .HYS_FileReader import read_mask_execl, read_psg_channel
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel, adjust_sample_rate

utils/event_map.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# apnea event type to number mapping
E2N = {
"Hypopnea": 1,
"Central apnea": 2,
"Obstructive apnea": 3,
"Mixed apnea": 4
}
N2Chn = {
1: "Rpeak",
2: "ECG_Sync",
3: "Effort Tho",
4: "Effort Abd",
5: "Flow P",
6: "Flow T",
7: "SpO2",
8: "5_class"
}
Stage2N = {
"W": 5,
"N1": 3,
"N2": 2,
"N3": 1,
"R": 4,
}
# Event codes and their plotting colors
# event_code  color   event
# 0           black   background
# 1           pink    hypopnea
# 2           blue    central apnea
# 3           red     obstructive apnea
# 4           silver  mixed apnea
# 5           green   SpO2 desaturation
# 6           orange  large body movement
# 7           orange  small body movement
# 8           orange  deep breath
# 9           orange  pulse-like body movement
# 10          orange  invalid segment
ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange",
"orange"]

utils/filter_func.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage
@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
if _type == "lowpass":  # low-pass filtering
sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
return signal.sosfiltfilt(sos, np.array(data))
elif _type == "bandpass":  # band-pass filtering
low = low_cut / (sample_rate * 0.5)
high = high_cut / (sample_rate * 0.5)
sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
return signal.sosfiltfilt(sos, np.array(data))
elif _type == "highpass":  # high-pass filtering
sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
return signal.sosfiltfilt(sos, np.array(data))
else:  # an explicit filter type is required
raise ValueError("Please choose a type of filter")
@timing_decorator()
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
if _type == "lowpass":  # low-pass filtering
b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
elif _type == "bandpass":  # band-pass filtering
low = low_cut / (sample_rate * 0.5)
high = high_cut / (sample_rate * 0.5)
b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
elif _type == "highpass":  # high-pass filtering
b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
else:  # an explicit filter type is required
raise ValueError("Please choose a type of filter")
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
"""
高效整数倍降采样长信号分段处理以优化内存和速度
参数:
original_signal : array-like, 原始信号数组
original_fs : float, 原始采样率 (Hz)
target_fs : float, 目标采样率 (Hz)
chunk_size : int, 每段处理的样本数默认100000
返回:
downsampled_signal : array-like, 降采样后的信号
"""
# 输入验证
if not isinstance(original_signal, np.ndarray):
original_signal = np.array(original_signal)
if original_fs <= target_fs:
raise ValueError("目标采样率必须小于原始采样率")
if target_fs <= 0 or original_fs <= 0:
raise ValueError("采样率必须为正数")
# 计算降采样因子(必须为整数)
downsample_factor = original_fs / target_fs
if not downsample_factor.is_integer():
raise ValueError("降采样因子必须为整数倍")
downsample_factor = int(downsample_factor)
# Total output length
total_length = len(original_signal)
output_length = total_length // downsample_factor
# Allocate the output array
downsampled_signal = np.zeros(output_length)
# Process chunk by chunk
for start in range(0, total_length, chunk_size):
end = min(start + chunk_size, total_length)
chunk = original_signal[start:end]
# Integer-factor downsampling via decimate
chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)
# Output position of this chunk
out_start = start // downsample_factor
out_end = out_start + len(chunk_downsampled)
if out_end > output_length:
chunk_downsampled = chunk_downsampled[:output_length - out_start]
downsampled_signal[out_start:out_end] = chunk_downsampled
return downsampled_signal
def upsample_signal(original_signal, original_fs, target_fs):
"""
信号升采样
参数:
original_signal : array-like, 原始信号数组
original_fs : float, 原始采样率 (Hz)
target_fs : float, 目标采样率 (Hz)
返回:
upsampled_signal : array-like, 升采样后的信号
"""
if not isinstance(original_signal, np.ndarray):
original_signal = np.array(original_signal)
if target_fs <= original_fs:
raise ValueError("目标采样率必须大于原始采样率")
if target_fs <= 0 or original_fs <= 0:
raise ValueError("采样率必须为正数")
upsample_factor = target_fs / original_fs
num_output_samples = int(len(original_signal) * upsample_factor)
upsampled_signal = signal.resample(original_signal, num_output_samples)
return upsampled_signal
def adjust_sample_rate(signal_data, original_fs, target_fs):
"""
根据信号的原始采样率和目标采样率自动选择升采样或降采样
参数:
signal_data : array-like, 原始信号数组
original_fs : float, 原始采样率 (Hz)
target_fs : float, 目标采样率 (Hz)
返回:
adjusted_signal : array-like, 调整采样率后的信号
"""
if original_fs == target_fs:
return signal_data
elif original_fs > target_fs:
return downsample_signal_fast(signal_data, original_fs, target_fs)
else:
return upsample_signal(signal_data, original_fs, target_fs)
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
convolve_filter_signal = raw_data - filtered
return convolve_filter_signal
# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
nyquist = 0.5 * sample_rate
norm_notch_freq = notch_freq / nyquist
b, a = signal.iirnotch(norm_notch_freq, quality_factor)
filtered_data = signal.filtfilt(b, a, data)
return filtered_data
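A standalone sketch of the zero-phase low-pass path, trimmed to the low-pass branch and stripped of the timing decorator. The sampling rate and the two sinusoids are invented; the idea is that a respiration-band component (0.3 Hz) survives while 10 Hz interference is removed.

```python
import numpy as np
from scipy import signal

def butterworth_lowpass(data, low_cut, order=10, sample_rate=1000):
    # Same SOS + sosfiltfilt (zero-phase) combination as in filter_func.py
    sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
    return signal.sosfiltfilt(sos, np.array(data))

fs = 100
t = np.arange(0, 20, 1 / fs)
resp = np.sin(2 * np.pi * 0.3 * t)        # 0.3 Hz respiration-like component
noise = 0.5 * np.sin(2 * np.pi * 10 * t)  # 10 Hz interference
filtered = butterworth_lowpass(resp + noise, low_cut=1.0, sample_rate=fs)
# The filtered trace should track the slow component almost perfectly
print(float(np.corrcoef(filtered, resp)[0, 1]))
```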

utils/operation_tools.py (new file, 380 lines)
@@ -0,0 +1,380 @@
import time
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import yaml
from scipy.interpolate import PchipInterpolator
from utils.event_map import E2N
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly
from scipy import ndimage, signal
from functools import wraps
# Global configuration
class Config:
time_verbose = False
def timing_decorator():
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
elapsed_time = end_time - start_time
if Config.time_verbose:  # check the global config at run time
print(f"Function '{func.__name__}' took {elapsed_time:.4f} s")
return result
return wrapper
return decorator
@timing_decorator()
def read_auto(file_path):
# print('suffix: ', file_path.suffix)
if file_path.suffix == '.txt':
# read the txt with pandas read_csv
return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
elif file_path.suffix == '.npy':
return np.load(str(file_path))
elif file_path.suffix == '.base64':
with open(file_path) as f:
files = f.readlines()
data = np.array(files, dtype=int)
return data
else:
raise ValueError('Unsupported file type; a dedicated reader is required')
def merge_short_gaps(state_sequence, time_points, max_gap_sec):
"""
Merge segments of a binary state sequence that are separated by gaps
shorter than the given duration.
Args:
- state_sequence: numpy array, binary state sequence (0/1)
- time_points: numpy array, the time point of each state sample
- max_gap_sec: float, maximum gap duration to merge across
Returns:
- merged_sequence: numpy array, the merged state sequence
"""
if len(state_sequence) <= 1:
return state_sequence
merged_sequence = state_sequence.copy()
# Locate state transitions
transitions = np.diff(np.concatenate([[0], merged_sequence, [0]]))
# Start and end positions of each run of state 1
state_starts = np.where(transitions == 1)[0]
state_ends = np.where(transitions == -1)[0] - 1
# Examine each pair of consecutive state-1 runs
for i in range(len(state_starts) - 1):
if state_ends[i] < len(time_points) and state_starts[i + 1] < len(time_points):
# Gap duration between the two runs
gap_duration = time_points[state_starts[i + 1]] - time_points[state_ends[i]]
# Merge if the gap is short enough
if gap_duration <= max_gap_sec:
merged_sequence[state_ends[i]:state_starts[i + 1]] = 1
return merged_sequence
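The gap-merging logic can be checked on a tiny sequence. This is a self-contained copy of the same diff-based logic, with an invented six-sample sequence sampled once per second.

```python
import numpy as np

def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    # Identical idea to the function above: bridge 0-gaps shorter than max_gap_sec
    merged = state_sequence.copy()
    transitions = np.diff(np.concatenate([[0], merged, [0]]))
    starts = np.where(transitions == 1)[0]
    ends = np.where(transitions == -1)[0] - 1
    for i in range(len(starts) - 1):
        gap = time_points[starts[i + 1]] - time_points[ends[i]]
        if gap <= max_gap_sec:
            merged[ends[i]:starts[i + 1]] = 1
    return merged

seq = np.array([1, 1, 0, 0, 1, 1])
t = np.arange(6)  # one sample per second
print(merge_short_gaps(seq, t, max_gap_sec=2))  # [1 1 0 0 1 1]  (gap of 3 s kept)
print(merge_short_gaps(seq, t, max_gap_sec=3))  # [1 1 1 1 1 1]  (gap of 3 s merged)
```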
def remove_short_durations(state_sequence, time_points, min_duration_sec):
"""
Remove segments of a binary state sequence that last shorter than the threshold.
Args:
- state_sequence: numpy array, binary state sequence (0/1)
- time_points: numpy array, the time point of each state sample
- min_duration_sec: float, minimum duration to keep
Returns:
- filtered_sequence: numpy array, the filtered state sequence
"""
if len(state_sequence) <= 1:
return state_sequence
filtered_sequence = state_sequence.copy()
# Locate state transitions
transitions = np.diff(np.concatenate([[0], filtered_sequence, [0]]))
# Start and end positions of each run of state 1
state_starts = np.where(transitions == 1)[0]
state_ends = np.where(transitions == -1)[0] - 1
# Check the duration of each state-1 run
for i in range(len(state_starts)):
if state_starts[i] < len(time_points) and state_ends[i] < len(time_points):
# Duration of this run
duration = time_points[state_ends[i]] - time_points[state_starts[i]]
if state_ends[i] == len(time_points) - 1:
# For the last window, add one window length
duration += time_points[1] - time_points[0] if len(time_points) > 1 else 0
# Drop the run if it is shorter than the threshold
if duration < min_duration_sec:
filtered_sequence[state_starts[i]:state_ends[i] + 1] = 0
return filtered_sequence
@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
# calc_mask is expected to be aligned (per second) with signal_data
# if calc_mask is None:
# calc_mask = np.zeros(len(signal_data), dtype=bool)
if step_second is None:
step_second = window_second
step_length = step_second * sampling_rate
window_length = window_second * sampling_rate
origin_seconds = len(signal_data) // sampling_rate
total_samples = len(signal_data)
# reflect padding
left_pad_size = int(window_length // 2)
right_pad_size = window_length - left_pad_size
data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')
num_segments = int(np.ceil(len(signal_data) / step_length))
values_nan = np.full(num_segments, np.nan)
# print(f"num_segments: {num_segments}, step_length: {step_length}, window_length: {window_length}")
for i in range(num_segments):
# seconds that contain body movement are masked to NaN after the per-window values are computed
start = int(i * step_length)
end = start + window_length
segment = data[start:end]
values_nan[i] = func(segment)
values_nan = values_nan.repeat(step_second)[:origin_seconds]
if calc_mask is not None:
for i in range(len(values_nan)):
if calc_mask[i]:
values_nan[i] = np.nan
values = values_nan.copy()
# Interpolate the NaN values in body-movement regions
def interpolate_nans(x, t):
valid_mask = ~np.isnan(x)
return np.interp(t, t[valid_mask], x[valid_mask])
values = interpolate_nans(values, np.arange(len(values)))
else:
values = values_nan.copy()
return values_nan, values
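The window-then-repeat shape of the function above can be sketched with a plain RMS. Everything here is invented (10 Hz sampling, 4 s window, 2 s step); the point is the reflect padding and the per-second repeat at the end.

```python
import numpy as np

fs = 10
window_sec, step_sec = 4, 2
x = np.concatenate([np.zeros(40), np.ones(40)])  # 8 s of signal at 10 Hz

# Reflect-pad by half a window so each output second sees a centred window
pad = window_sec * fs // 2
data = np.pad(x, (pad, window_sec * fs - pad), mode='reflect')

num_segments = int(np.ceil(len(x) / (step_sec * fs)))
vals = np.empty(num_segments)
for i in range(num_segments):
    start = i * step_sec * fs
    seg = data[start:start + window_sec * fs]
    vals[i] = np.sqrt(np.mean(seg ** 2))

# One value per step, repeated to one value per second
per_second = vals.repeat(step_sec)[: len(x) // fs]
print(per_second.shape)  # (8,)
```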
def load_dataset_conf(yaml_path):
with open(yaml_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# select_ids = config.get('select_id', [])
# root_path = config.get('root_path', None)
# data_path = Path(root_path)
return config
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
disable_mask = np.zeros(signal_second, dtype=int)
for _, row in disable_df.iterrows():
start = row["start"]
end = row["end"]
disable_mask[start:end] = 1
return disable_mask
def generate_event_mask(signal_second: int, event_df, use_correct=True, with_score=True):
event_mask = np.zeros(signal_second, dtype=int)
if with_score:
score_mask = np.zeros(signal_second, dtype=int)
else:
score_mask = None
if use_correct:
start_name = "correct_Start"
end_name = "correct_End"
event_type_name = "correct_EventsType"
else:
start_name = "Start"
end_name = "End"
event_type_name = "Event type"
# Drop rows with start == -1
event_df = event_df[event_df[start_name] >= 0]
for _, row in event_df.iterrows():
start = row[start_name]
end = row[end_name] + 1
event_mask[start:end] = E2N[row[event_type_name]]
if with_score:
score_mask[start:end] = row["score"]
return event_mask, score_mask
def event_mask_2_list(mask, event_true=True):
if event_true:
event_2_normal = -1
normal_2_event = 1
_append = 0
else:
event_2_normal = 1
normal_2_event = -1
_append = 1
mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0]
mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0]
event_list = [[start, end] for start, end in zip(mask_start, mask_end)]
return event_list
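A worked example of the diff-based conversion above, restated inline for the default `event_true=True` case with an invented six-second mask:

```python
import numpy as np

mask = np.array([0, 1, 1, 0, 0, 1])
# prepend/append 0 so runs touching either boundary are still detected
starts = np.where(np.diff(mask, prepend=0, append=0) == 1)[0]
ends = np.where(np.diff(mask, prepend=0, append=0) == -1)[0]
event_list = [[s, e] for s, e in zip(starts, ends)]
print(event_list)  # [[1, 3], [5, 6]]
```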
def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list:
"""收集非 NaN 值,直到达到指定数量或边界"""
values = []
count = 0
mask = mask if mask is not None else arr
while count < limit and 0 <= index < len(mask):
if not np.isnan(mask[index]):
values.append(arr[index])
count += 1
index += step
return values
def save_process_label(save_path: Path, save_dict: dict):
save_df = pd.DataFrame(save_dict)
save_df.to_csv(save_path, index=False)
def none_to_nan_mask(mask, ref):
"""Convert None into a NaN mask shaped like ref; otherwise replace 0 with NaN."""
if mask is None:
return np.full_like(ref, np.nan)
else:
# Replace 0 with NaN, keep all other values
mask = np.where(mask == 0, np.nan, mask)
return mask
def get_wake_mask(sleep_stage_mask):
# Treat N1, N2, N3 and REM as sleep (0), everything else as wake (1)
# The input contains strings such as 'W', 'N1', 'N2', 'N3', 'R'
wake_mask = np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)
return wake_mask
def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
anomaly = np.zeros(len(spo2), dtype=bool)
# Physiological range
anomaly |= (spo2 < 50) | (spo2 > 100)
# Abrupt jumps
diff = np.abs(np.diff(spo2, prepend=spo2[0]))
anomaly |= diff > diff_thresh
# NaN samples
anomaly |= np.isnan(spo2)
return anomaly
def merge_close_anomalies(anomaly, fs, min_gap_duration):
min_gap = int(min_gap_duration * fs)
merged = anomaly.copy()
i = 0
n = len(anomaly)
while i < n:
if not anomaly[i]:
i += 1
continue
# Current anomalous run
start = i
while i < n and anomaly[i]:
i += 1
end = i
# Look ahead at the gap after the run
j = end
while j < n and not anomaly[j]:
j += 1
if j < n and (j - end) < min_gap:
merged[end:j] = True
return merged
def fill_spo2_anomaly(
spo2_data,
spo2_fs,
max_fill_duration,
min_gap_duration,
):
spo2 = spo2_data.astype(float).copy()
n = len(spo2)
anomaly = detect_spo2_anomaly(spo2, spo2_fs)
anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)
max_len = int(max_fill_duration * spo2_fs)
valid_mask = ~anomaly
i = 0
while i < n:
if not anomaly[i]:
i += 1
continue
start = i
while i < n and anomaly[i]:
i += 1
end = i
seg_len = end - start
# Anomalous run too long to fill
if seg_len > max_len:
spo2[start:end] = np.nan
valid_mask[start:end] = False
continue
has_left = start > 0 and valid_mask[start - 1]
has_right = end < n and valid_mask[end]
# Anomaly at the start of the record: one-sided fill
if not has_left and has_right:
spo2[start:end] = spo2[end]
continue
# Anomaly at the end of the record: one-sided fill
if has_left and not has_right:
spo2[start:end] = spo2[start - 1]
continue
# Valid samples on both sides -> PCHIP interpolation
if has_left and has_right:
x = np.array([start - 1, end])
y = np.array([spo2[start - 1], spo2[end]])
interp = PchipInterpolator(x, y)
spo2[start:end] = interp(np.arange(start, end))
continue
# No valid samples on either side (extreme case)
spo2[start:end] = np.nan
valid_mask[start:end] = False
return spo2, valid_mask
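The three detection rules can be checked on a toy SpO2 trace. This sketch drops the unused `fs` parameter and keeps the default jump threshold; the five-sample trace is invented (a dip to 40% and the jump back should both flag).

```python
import numpy as np

def detect_spo2_anomaly(spo2, diff_thresh=7):
    # Same three rules as above: physiological range, abrupt jump, NaN
    anomaly = np.zeros(len(spo2), dtype=bool)
    anomaly |= (spo2 < 50) | (spo2 > 100)
    anomaly |= np.abs(np.diff(spo2, prepend=spo2[0])) > diff_thresh
    anomaly |= np.isnan(spo2)
    return anomaly

spo2 = np.array([98.0, 97.0, 40.0, 97.0, 96.0])
print(detect_spo2_anomaly(spo2))  # flags indices 2 (out of range) and 3 (jump back)
```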

utils/signal_process.py (new file, 108 lines)
@@ -0,0 +1,108 @@
from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage
@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
if _type == "lowpass":  # low-pass filtering
sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
return signal.sosfiltfilt(sos, np.array(data))
elif _type == "bandpass":  # band-pass filtering
low = low_cut / (sample_rate * 0.5)
high = high_cut / (sample_rate * 0.5)
sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
return signal.sosfiltfilt(sos, np.array(data))
elif _type == "highpass":  # high-pass filtering
sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
return signal.sosfiltfilt(sos, np.array(data))
else:  # an explicit filter type is required
raise ValueError("Please choose a type of filter")
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
if _type == "lowpass":  # low-pass filtering
b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
elif _type == "bandpass":  # band-pass filtering
low = low_cut / (sample_rate * 0.5)
high = high_cut / (sample_rate * 0.5)
b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
elif _type == "highpass":  # high-pass filtering
b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
else:  # an explicit filter type is required
raise ValueError("Please choose a type of filter")
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
"""
高效整数倍降采样长信号分段处理以优化内存和速度
参数:
original_signal : array-like, 原始信号数组
original_fs : float, 原始采样率 (Hz)
target_fs : float, 目标采样率 (Hz)
chunk_size : int, 每段处理的样本数默认100000
返回:
downsampled_signal : array-like, 降采样后的信号
"""
# 输入验证
if not isinstance(original_signal, np.ndarray):
original_signal = np.array(original_signal)
if original_fs <= target_fs:
raise ValueError("目标采样率必须小于原始采样率")
if target_fs <= 0 or original_fs <= 0:
raise ValueError("采样率必须为正数")
# 计算降采样因子(必须为整数)
downsample_factor = original_fs / target_fs
if not downsample_factor.is_integer():
raise ValueError("降采样因子必须为整数倍")
downsample_factor = int(downsample_factor)
# Total output length
total_length = len(original_signal)
output_length = total_length // downsample_factor
# Allocate the output array
downsampled_signal = np.zeros(output_length)
# Process chunk by chunk
for start in range(0, total_length, chunk_size):
end = min(start + chunk_size, total_length)
chunk = original_signal[start:end]
# Integer-factor downsampling via decimate
chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)
# Output position of this chunk
out_start = start // downsample_factor
out_end = out_start + len(chunk_downsampled)
if out_end > output_length:
chunk_downsampled = chunk_downsampled[:output_length - out_start]
downsampled_signal[out_start:out_end] = chunk_downsampled
return downsampled_signal
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
convolve_filter_signal = raw_data - filtered
return convolve_filter_signal
# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
nyquist = 0.5 * sample_rate
norm_notch_freq = notch_freq / nyquist
b, a = signal.iirnotch(norm_notch_freq, quality_factor)
filtered_data = signal.filtfilt(b, a, data)
return filtered_data

utils/split_method.py (new file, 56 lines)
@@ -0,0 +1,56 @@
def check_split(event_mask, current_start, window_sec, verbose=False):
# Check whether the current window overlaps disabled or low-amplitude regions
resp_movement_mask = event_mask["Resp_Movement_Label"][current_start:current_start + window_sec]
resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start:current_start + window_sec]
# Combine body movement and low amplitude (element-wise OR)
low_move_mask = resp_movement_mask | resp_low_amp_mask
if low_move_mask.sum() > 2/3 * window_sec:
if verbose:
print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3")
return False
return True
def resp_split(dataset_config, event_mask, event_list, verbose=False):
# Extract the enabled and disabled segment lists
enable_list = event_list["EnableSegment"]
disable_list = event_list["DisableSegment"]
# Read the dataset configuration
window_sec = dataset_config["window_sec"]
stride_sec = dataset_config["stride_sec"]
segment_list = []
disable_segment_list = []
# Walk each enable segment; a tail shorter than half a stride is discarded,
# otherwise one extra window ending at enable_end is taken
for enable_start, enable_end in enable_list:
current_start = enable_start
while current_start + window_sec <= enable_end:
if check_split(event_mask, current_start, window_sec, verbose):
segment_list.append((current_start, current_start + window_sec))
else:
disable_segment_list.append((current_start, current_start + window_sec))
current_start += stride_sec
# Add a final window ending at enable_end when the tail is at least half a
# stride and the segment itself is at least one window long
if (enable_end - current_start >= stride_sec / 2) and (enable_end - enable_start >= window_sec):
if check_split(event_mask, enable_end - window_sec, window_sec, verbose):
segment_list.append((enable_end - window_sec, enable_end))
else:
disable_segment_list.append((enable_end - window_sec, enable_end))
# Walk each disable segment with the same tail rule
for disable_start, disable_end in disable_list:
current_start = disable_start
while current_start + window_sec <= disable_end:
disable_segment_list.append((current_start, current_start + window_sec))
current_start += stride_sec
# Add a final window ending at disable_end under the same conditions
if (disable_end - current_start >= stride_sec / 2) and (disable_end - disable_start >= window_sec):
disable_segment_list.append((disable_end - window_sec, disable_end))
return segment_list, disable_segment_list
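The stride loop's window arithmetic can be checked in isolation. The 100 s enable segment, 30 s window and 15 s stride below are invented values:

```python
window_sec, stride_sec = 30, 15
enable_start, enable_end = 0, 100

# Same loop shape as resp_split: full windows every stride_sec
starts = []
current_start = enable_start
while current_start + window_sec <= enable_end:
    starts.append(current_start)
    current_start += stride_sec
print(starts)                      # [0, 15, 30, 45, 60]
print(enable_end - current_start)  # 25 s tail left for the end-of-segment check
```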

utils/statistics_metrics.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from utils.operation_tools import timing_decorator
import numpy as np
import pandas as pd
@timing_decorator()
def statistic_amplitude_metrics(data, aml_interval=None, time_interval=None):
"""
Compute how much time falls into each amplitude bin and each duration bin,
summarised as an amplitude-by-duration matrix.
Args:
data: 1-D series sampled at 1 Hz; body-movement regions are filled with np.nan
aml_interval: amplitude bin edges, default [200, 500, 1000, 2000]
time_interval: duration bin edges in seconds, default [60, 300, 1800, 3600]
Returns:
summary: aggregate statistics
(confusion_matrix, ...): the amplitude-duration matrices and their labels
"""
if aml_interval is None:
aml_interval = [200, 500, 1000, 2000]
if time_interval is None:
time_interval = [60, 300, 1800, 3600]
# Validate input
if not isinstance(data, np.ndarray):
data = np.array(data)
# Duration of the whole recording, NaN included
total_duration = len(data)
# Build amplitude and duration labels
amp_labels = [f"0-{aml_interval[0]}"]
for i in range(len(aml_interval) - 1):
amp_labels.append(f"{aml_interval[i]}-{aml_interval[i + 1]}")
amp_labels.append(f"{aml_interval[-1]}+")
time_labels = [f"0-{time_interval[0]}"]
for i in range(len(time_interval) - 1):
time_labels.append(f"{time_interval[i]}-{time_interval[i + 1]}")
time_labels.append(f"{time_interval[-1]}+")
# Initialise the duration matrix and the segment-count matrix
result_matrix = np.zeros((len(amp_labels), len(time_labels)))  # durations
segment_count_matrix = np.zeros((len(amp_labels), len(time_labels)))  # segment counts
# Amount of valid signal (number of non-NaN samples)
valid_signal_length = np.sum(~np.isnan(data))
# Pad the signal so runs at the boundaries are handled
signal_padded = np.concatenate(([np.nan], data, [np.nan]))
diff = np.diff(np.isnan(signal_padded).astype(int))
# Run starts (NaN -> non-NaN)
segment_starts = np.where(diff == -1)[0]
# Run ends (non-NaN -> NaN)
segment_ends = np.where(diff == 1)[0]
# Duration and mean amplitude of each run, accumulated into the matrices
for start, end in zip(segment_starts, segment_ends):
segment = data[start:end]
duration = end - start  # duration in seconds
mean_amplitude = np.nanmean(segment)  # mean amplitude of the run
# Amplitude bin
if mean_amplitude <= aml_interval[0]:
amp_idx = 0
elif mean_amplitude > aml_interval[-1]:
amp_idx = len(aml_interval)
else:
amp_idx = np.searchsorted(aml_interval, mean_amplitude)
# Duration bin
if duration <= time_interval[0]:
time_idx = 0
elif duration > time_interval[-1]:
time_idx = len(time_interval)
else:
time_idx = np.searchsorted(time_interval, duration)
# Accumulate duration and segment count in the matching cell
result_matrix[amp_idx, time_idx] += duration
segment_count_matrix[amp_idx, time_idx] += 1
# Wrap in a DataFrame for display and further processing
confusion_matrix = pd.DataFrame(result_matrix, index=amp_labels, columns=time_labels)
# Row totals (the '总计' column name is kept because downstream code uses it)
confusion_matrix['总计'] = confusion_matrix.sum(axis=1)
row_totals = confusion_matrix['总计'].copy()
# Percentages, relative to the total recording duration
confusion_matrix_percent = confusion_matrix.div(total_duration) * 100
# Aggregate statistics
summary = {
'total_duration': total_duration,
'total_valid_signal': valid_signal_length,
'amplitude_distribution': row_totals.to_dict(),
'amplitude_percent': row_totals.div(total_duration) * 100,
'time_distribution': confusion_matrix.sum(axis=0).to_dict(),
'time_percent': confusion_matrix.sum(axis=0).div(total_duration) * 100
}
return summary, (confusion_matrix, segment_count_matrix, confusion_matrix_percent, valid_signal_length,
total_duration, time_labels, amp_labels)
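The bin-index branch reduces to `np.searchsorted` with two edge guards. A sketch with the default amplitude edges and a few invented mean amplitudes:

```python
import numpy as np

aml_interval = [200, 500, 1000, 2000]

def amp_bin(mean_amplitude):
    # Mirrors the branch above: <= first edge -> 0, > last edge -> len(edges)
    if mean_amplitude <= aml_interval[0]:
        return 0
    if mean_amplitude > aml_interval[-1]:
        return len(aml_interval)
    return int(np.searchsorted(aml_interval, mean_amplitude))

print([amp_bin(v) for v in (150, 200, 350, 1000, 2500)])  # [0, 0, 1, 2, 4]
```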

排除区间.xlsx (new binary file, not shown)