From b3acc4e886be4363937766e431d803441dad7954 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 29 Sep 2025 14:13:13 +0800 Subject: [PATCH 01/28] Initial commit --- .gitignore | 255 +++++++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 2 + 2 files changed, 257 insertions(+) create mode 100644 .gitignore create mode 100644 README.md diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2429834 --- /dev/null +++ b/.gitignore @@ -0,0 +1,255 @@ +# ---> JetBrains +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +# ---> JupyterNotebooks +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +# ---> Python +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + diff --git a/README.md b/README.md new file mode 100644 index 0000000..a83c153 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# DataPrepare + From 35e5b6a202301c573338d0d8430a5903719924a1 Mon Sep 17 00:00:00 2001 From: marques <20172333133@m.scnu.edu.cn> Date: Mon, 29 Sep 2025 14:15:17 +0800 Subject: [PATCH 02/28] created --- README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 From 805f1dc7f832f98b0bf71bd60bd84a7406addc2a Mon Sep 17 00:00:00 2001 From: marques Date: Sat, 11 Oct 2025 15:32:00 +0800 Subject: [PATCH 03/28] Add data processing and visualization modules for signal analysis --- HYS_process.py | 21 +++ SHHS_process.py | 0 draw_tools/__init__.py | 0 draw_tools/draw_statics.py | 175 +++++++++++++++++++++ signal_method/__init__.py | 0 signal_method/rule_base_event.py | 207 +++++++++++++++++++++++++ signal_method/time_metrics.py | 41 +++++ utils/HYS_FileReader.py | 54 +++++++ utils/__init__.py | 0 utils/operation_tools.py | 256 
# ===== HYS_process.py =====
"""
Processing script for the HYS (respiratory research institute) recordings.

Planned functionality:
1. Data loading and preprocessing
   Read signals and labels from the given path and apply initial
   preprocessing (filtering, denoising, ...).
2. Data cleaning and outlier handling
3. Reporting statistics of the cleaned data
4. Saving
   Store the processed data at a target path for later use; mainly the
   positions and labels of the split segments.
5. Visualisation
   Before/after comparison plots, plus availability trend plots showing
   usable intervals, body-movement intervals, low-amplitude intervals, etc.

TODO:
- rule-based marking and removal of low-amplitude intervals
- rule-based marking and removal of sustained high-amplitude body movement
- removal of manually marked unusable intervals
"""

# ===== SHHS_process.py (empty) =====
# ===== draw_tools/__init__.py (empty) =====

# ===== draw_tools/draw_statics.py =====
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


def draw_ax_confusion_matrix(ax: Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Draw the amplitude x duration statistics matrix as an annotated heat map.

    Args:
        ax: Matplotlib axes to draw on.
        confusion_matrix: Table indexed positionally with ``.iloc`` and by the
            ``'总计'`` (total) column; holds durations per amplitude/duration cell.
        segment_count_matrix: 2-D array of segment counts per cell.
        confusion_matrix_percent: Same layout as ``confusion_matrix``, in percent.
        valid_signal_length: Total valid duration shown in the bottom-right corner.
        total_duration: Total record duration used as the percentage denominator.
        time_labels: Column (duration bin) labels.
        amp_labels: Row (amplitude bin) labels.
        signal_type: Unused; kept for interface compatibility.
    """
    n_rows, n_cols = len(amp_labels), len(time_labels)

    # Per-cell annotation text and the numeric values that drive the colours.
    text_matrix = np.empty((n_rows, n_cols), dtype=object)
    percent_matrix = np.zeros((n_rows, n_cols))
    for i in range(n_rows):
        for j in range(n_cols):
            val = confusion_matrix.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent

    # Heat map; PowerNorm(gamma=0.5) stretches the low-percentage end of the colour scale.
    sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='',
                              xticklabels=time_labels, yticklabels=amp_labels,
                              cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
                              norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
                              cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15},
                              )

    # Row totals, written to the right of the matrix.
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(n_cols + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(n_cols + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
                ha='center', va='center')

    # Column totals, written below the matrix.
    col_sums = segment_count_matrix.sum(axis=0)
    col_percents = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, n_rows + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, n_rows + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
                ha='center', va='center', rotation=0)

    # Put the x axis on top so the bottom row is free for the column totals.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    cbar = sns_heatmap.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Key explaining the cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Grand total: segment count, valid duration and valid share of the whole record.
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length / total_duration * 100
    ax.text(n_cols + 0.5, n_rows + 0.5,
            f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)",
            ha='center', va='center')


def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    """Plot a signal with movement / low-amplitude spans, amplitude guide lines and its MAV curve.

    Args:
        ax: Matplotlib axes to draw on.
        signal_name: Name used in the axes title.
        original_times: x coordinates for ``origin_signal``.
        origin_signal: Signal samples.
        no_movement_signal: Unused; kept for interface compatibility.
        mav_values: MAV curve, drawn over ``[0, len(mav_values)]``.
        movement_position_list: ``(start, end)`` pairs shaded red.
        low_amp_position_list: ``(start, end)`` pairs shaded blue.
        signal_second_length: Right end of the horizontal guide lines.
        aml_list: Optional amplitude levels drawn as dashed horizontal lines.
    """
    movement_alpha = 0.3
    low_amp_alpha = 0.2

    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # NOTE(review): the guard compares span coordinates with the number of samples;
    # kept as in the original -- confirm whether seconds were intended.
    n_times = len(original_times)
    for start, end in movement_position_list:
        if start < n_times and end < n_times:
            ax.axvspan(start, end, color='red', alpha=movement_alpha)

    for start, end in low_amp_position_list:
        if start < n_times and end < n_times:
            ax.axvspan(start, end, color='blue', alpha=low_amp_alpha)

    # Optional amplitude guide lines; levels beyond the third reuse the last colour.
    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed',
                      linewidth=2, alpha=0.5, label=f'{aml} aml')

    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values,
            color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    ax.set_ylim((-2000, 2000))

    # Legend swatches use the same alpha as the spans they describe
    # (the movement swatch previously used 0.2 while the span used 0.3).
    red_patch = mpatches.Patch(color='red', alpha=movement_alpha, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=low_amp_alpha, label='Low Amplitude Area')

    # Prepend the span swatches to the entries already registered on the axes.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right',
              bbox_to_anchor=(1, 1.4),
              framealpha=0.5)

    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')

    ax.grid(True, linestyle='--', alpha=0.7)


def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):
    """Render the 2x2 overview figure: BCG / RESP traces (left) and their statistics matrices (right).

    ``bcg_statistic_info`` and ``resp_statistic_info`` are sequences whose items are
    passed, in order, as ``confusion_matrix``, ``segment_count_matrix``,
    ``confusion_matrix_percent``, ``valid_signal_length``, ``total_duration``,
    ``time_labels`` and ``amp_labels`` to ``draw_ax_confusion_matrix``.

    Args:
        show: Display the figure interactively when true.
        save_path: If given, the figure is written there at 300 dpi.
    """
    fig = plt.figure(figsize=(18, 10))
    gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

    # Both traces share one time axis whose length is derived from the BCG signal.
    signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
    bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
    resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))

    param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                   'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']

    ax1 = fig.add_subplot(gs[0])
    draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[200, 500, 1000])

    ax2 = fig.add_subplot(gs[1])
    params = dict(zip(param_names, bcg_statistic_info))
    params['ax'] = ax2
    draw_ax_confusion_matrix(**params)

    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    # Title typo fixed: the respiration trace was labelled 'RSEP'.
    draw_ax_amp(ax=ax3, signal_name='RESP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[100, 300, 500])

    ax4 = fig.add_subplot(gs[3])
    params = dict(zip(param_names, resp_statistic_info))
    params['ax'] = ax4
    draw_ax_confusion_matrix(**params)

    fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

    # Act on this figure explicitly instead of whatever pyplot considers current.
    if save_path is not None:
        fig.savefig(save_path, dpi=300)

    if show:
        plt.show()

    plt.close(fig)


# ===== signal_method/__init__.py (empty) =====

# ===== signal_method/rule_base_event.py =====
from utils.operation_tools import merge_short_gaps, remove_short_durations, timing_decorator
import numpy as np


@timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
                                amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):
    """Flag low-amplitude stretches of a signal from windowed RMS values.

    One RMS value is computed per window; window ``i`` starts at ``i * stride`` and
    is clipped at the end of the signal. Windows whose RMS is at or below the
    threshold are flagged, nearby flagged runs are merged, short runs are dropped,
    and the result is expanded to one value per whole second of signal.

    The previous implementation padded the signal but then windowed the unpadded
    one and sliced the window mask with an offset derived from the pad, which
    shifted the mask whenever the stride was at most half the window. It also
    counted empty trailing windows as low amplitude and could not expand the mask
    for non-integer strides.

    Args:
        signal_data: 1-D array-like signal.
        sampling_rate: Sampling rate in Hz.
        window_size_sec: Analysis window length in seconds.
        stride_sec: Window step in seconds; defaults to ``window_size_sec``.
        amplitude_threshold: RMS threshold; values at or below it count as low amplitude.
        merge_gap_sec: Gap passed to ``merge_short_gaps`` (skipped when not positive).
        min_duration_sec: Duration passed to ``remove_short_durations`` (skipped when not positive).

    Returns:
        ``(low_amplitude_mask, low_amplitude_position_list)``: a 0/1 integer array with
        one entry per second, and a list of ``[start, end]`` second indices (``end`` is
        the last flagged second of the run).
    """
    # Float conversion avoids integer overflow when squaring narrow integer dtypes.
    signal_data = np.asarray(signal_data, dtype=float)

    if stride_sec is None:
        stride_sec = window_size_sec

    window_samples = max(1, int(window_size_sec * sampling_rate))
    stride_samples = max(1, int(stride_sec * sampling_rate))

    total_seconds = int(len(signal_data) // sampling_rate)
    covered_samples = int(total_seconds * sampling_rate)
    # Just enough windows to give every whole second a window that starts inside the signal.
    num_windows = -(-covered_samples // stride_samples)
    if num_windows == 0:
        return np.zeros(0, dtype=int), []

    rms_values = np.empty(num_windows)
    for i in range(num_windows):
        start_idx = i * stride_samples
        window_data = signal_data[start_idx:start_idx + window_samples]
        rms_values[i] = np.sqrt(np.mean(window_data ** 2))

    window_mask = np.where(rms_values <= amplitude_threshold, 1, 0)

    # Start time of every window, in seconds.
    time_points = np.arange(num_windows) * (stride_samples / sampling_rate)

    if merge_gap_sec > 0:
        window_mask = merge_short_gaps(window_mask, time_points, merge_gap_sec)

    if min_duration_sec > 0:
        window_mask = remove_short_durations(window_mask, time_points, min_duration_sec)

    # Each second takes the state of the window whose stride interval contains its start;
    # for integer strides this equals ``mask.repeat(stride_sec)[:total_seconds]``.
    second_to_window = (np.arange(total_seconds) * sampling_rate // stride_samples).astype(int)
    second_to_window = np.minimum(second_to_window, num_windows - 1)
    low_amplitude_mask = np.asarray(window_mask)[second_to_window]

    # Run boundaries: [[start, end], ...] with inclusive ends.
    low_amplitude_start = np.where(np.diff(np.concatenate([[0], low_amplitude_mask])) == 1)[0]
    low_amplitude_end = np.where(np.diff(np.concatenate([low_amplitude_mask, [0]])) == -1)[0]
    low_amplitude_position_list = [[start, end] for start, end in zip(low_amplitude_start, low_amplitude_end)]

    return low_amplitude_mask, low_amplitude_position_list


def get_typical_segment_for_continues_signal(signal_data, sampling_rate=100, window_size=30, step_size=1):
    """Placeholder: intended to pick ten typical segments from a continuous signal.

    Not implemented yet; currently returns ``None``.

    :param signal_data: signal samples
    :param sampling_rate: sampling rate
    :param window_size: window size in seconds
    :param step_size: step in seconds
    :return: list of typical segments (once implemented)
    """
    pass


def _mean_peak_to_peak(samples, block_samples):
    """Return mean(per-block max) - mean(per-block min) over consecutive blocks of ``block_samples``."""
    blocks = samples.reshape(-1, block_samples)
    return np.mean(np.max(blocks, axis=1)) - np.mean(np.min(blocks, axis=1))


def _relative_change(current, previous):
    """Return |current - previous| relative to ``previous`` (floored at 1e-6 to avoid division by zero)."""
    return abs(current - previous) / max(previous, 1e-6)


# Posture-change detection from body-movement positions and amplitude:
# the movement mask splits the night into valid (movement-free) segments. For each
# segment an amplitude and an energy measure are taken from a window near its left
# edge and one near its right edge; neighbouring segments are then compared, and a
# significant difference is reported as a posture change across the movement that
# separates them. Segments too short for the full window are compared on what they have.
@timing_decorator()
def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=100, window_size_sec=30,
                                     interval_to_movement=10):
    """Detect posture changes by comparing amplitude/energy of neighbouring movement-free segments.

    Fixes relative to the previous version:
    * the amplitude is reduced per 2-second block (``axis=1``, matching ``calc_mav``)
      instead of across blocks;
    * segments that cannot supply one whole 2-second block are skipped instead of
      raising on an empty reduction;
    * the reported interval is the gap between the two compared segments, so it stays
      aligned when segments were skipped or the record starts with movement;
    * energy is the mean of squares, so windows of different length are comparable
      (identical ratios when the windows have equal length).

    Args:
        signal_data: 1-D numpy array of samples.
        movement_mask: Per-second 0/1 array, 1 marking body movement.
        sampling_rate: Integer sampling rate in Hz.
        window_size_sec: Length in seconds of the comparison window at each segment edge.
        interval_to_movement: Seconds skipped next to a movement in long segments.

    Returns:
        ``(position_changes, position_change_times)``: one 0/1 flag per pair of
        consecutive analysed segments, and for every flagged pair the
        ``(start_second, end_second)`` of the interval separating them (inclusive).
    """
    movement_mask = np.asarray(movement_mask)

    # Valid (movement-free) runs, as inclusive second indices.
    valid_mask = 1 - movement_mask
    valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0]
    valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0]

    block_samples = 2 * sampling_rate
    window_samples = window_size_sec * sampling_rate
    margin_samples = interval_to_movement * sampling_rate

    analysed_bounds = []  # (first_second, last_second) of each analysed segment
    segment_left_average_amplitude = []
    segment_right_average_amplitude = []
    segment_left_average_energy = []
    segment_right_average_energy = []

    for first_second, last_second in zip(valid_starts, valid_ends):
        # NOTE(review): ``last_second`` is inclusive, so the final valid second is not
        # part of [start, end); kept as in the original.
        start = int(first_second) * sampling_rate
        end = int(last_second) * sampling_rate

        # Ignore segments of one second or less.
        if end - start <= sampling_rate:
            continue

        # NOTE(review): the threshold multiplies window and margin (30 * 10 s by default);
        # kept as in the original -- confirm whether a sum was intended.
        if end - start < (window_size_sec * interval_to_movement + 1) * sampling_rate:
            left_start = start
            left_end = min(end, left_start + window_samples)
            right_start = max(start, end - window_samples)
            right_end = end
        else:
            left_start = start + margin_samples
            left_end = left_start + window_samples
            right_start = end - margin_samples - window_samples
            # NOTE(review): the right window runs up to the movement itself; kept as in the original.
            right_end = end

        # Trim both windows to a whole number of 2-second blocks.
        left_end = left_start + ((left_end - left_start) // block_samples) * block_samples
        right_end = right_start + ((right_end - right_start) // block_samples) * block_samples
        if left_end <= left_start or right_end <= right_start:
            continue

        left_window = signal_data[left_start:left_end]
        right_window = signal_data[right_start:right_end]

        segment_left_average_amplitude.append(_mean_peak_to_peak(left_window, block_samples))
        segment_right_average_amplitude.append(_mean_peak_to_peak(right_window, block_samples))
        segment_left_average_energy.append(np.mean(np.abs(left_window ** 2)))
        segment_right_average_energy.append(np.mean(np.abs(right_window ** 2)))
        analysed_bounds.append((int(first_second), int(last_second)))

    # Relative-change thresholds (tune as needed).
    threshold_amplitude = 0.1
    threshold_energy = 0.1

    position_changes = []
    position_change_times = []
    for i in range(1, len(analysed_bounds)):
        left_amplitude_change = _relative_change(segment_left_average_amplitude[i],
                                                 segment_left_average_amplitude[i - 1])
        right_amplitude_change = _relative_change(segment_right_average_amplitude[i],
                                                  segment_right_average_amplitude[i - 1])
        left_energy_change = _relative_change(segment_left_average_energy[i],
                                              segment_left_average_energy[i - 1])
        right_energy_change = _relative_change(segment_right_average_energy[i],
                                               segment_right_average_energy[i - 1])

        # A side counts only when amplitude and energy both changed; either side suffices.
        left_significant_change = (left_amplitude_change > threshold_amplitude) and (
                left_energy_change > threshold_energy)
        right_significant_change = (right_amplitude_change > threshold_amplitude) and (
                right_energy_change > threshold_energy)

        if left_significant_change or right_significant_change:
            position_changes.append(1)
            # Interval between the two compared segments.
            position_change_times.append((analysed_bounds[i - 1][1] + 1, analysed_bounds[i][0] - 1))
        else:
            position_changes.append(0)

    return position_changes, position_change_times


# ===== signal_method/time_metrics.py =====
from utils.operation_tools import calculate_by_slide_windows, timing_decorator
import numpy as np


@timing_decorator()
def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1,
             inner_window_second=2):
    """Compute a per-second amplitude curve: half the mean peak-to-peak range of inner sub-windows.

    Each sliding window is split into sub-windows of ``inner_window_second`` seconds
    (the window length must be a multiple of it). Seconds flagged in ``movement_mask``
    become NaN in the first return value and are interpolated in the second.

    Args:
        signal_data: 1-D numpy array; its length must equal ``len(movement_mask) * sampling_rate``.
        movement_mask: Per-second 0/1 numpy array, 1 marking body movement.
        low_amp_mask: Per-second mask; only its length is checked here.
        sampling_rate: Sampling rate in Hz.
        window_second: Sliding-window length in seconds.
        step_second: Sliding-window step in seconds.
        inner_window_second: Sub-window length in seconds.

    Returns:
        ``(mav_nan, mav)`` as returned by ``calculate_by_slide_windows``.
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    processed_mask = movement_mask.copy()
    inner_samples = inner_window_second * sampling_rate

    def mav_func(x):
        blocks = x.reshape(-1, inner_samples)
        return np.mean(np.nanmax(blocks, axis=1) - np.nanmin(blocks, axis=1)) / 2

    mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
                                              window_second=window_second, step_second=step_second)

    return mav_nan, mav


@timing_decorator()
def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    """Compute the per-second waveform factor (RMS / mean absolute value) over 2 s windows, 1 s step.

    The mask is passed per second, as ``calc_mav`` does. It used to be repeated to
    per-sample length, which made the per-second lookup in
    ``calculate_by_slide_windows`` read the wrong entries.

    Returns:
        ``(wavefactor_nan, wavefactor)`` as returned by ``calculate_by_slide_windows``.
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    processed_mask = movement_mask.copy()
    wavefactor_nan, wavefactor = calculate_by_slide_windows(
        lambda x: np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x)),
        signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1)

    return wavefactor_nan, wavefactor


@timing_decorator()
def calc_peakfactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    """Compute the per-second peak factor (max absolute value / RMS) over 2 s windows, 1 s step.

    The mask is passed per second, as ``calc_mav`` does. It used to be repeated to
    per-sample length, which made the per-second lookup in
    ``calculate_by_slide_windows`` read the wrong entries.

    Returns:
        ``(peakfactor_nan, peakfactor)`` as returned by ``calculate_by_slide_windows``.
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    processed_mask = movement_mask.copy()
    peakfactor_nan, peakfactor = calculate_by_slide_windows(
        lambda x: np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2)),
        signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1)

    return peakfactor_nan, peakfactor
# ===== utils/HYS_FileReader.py =====
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd

# Polars is optional; pandas is the fallback reader.
try:
    import polars as pl
    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False


def read_signal_txt(path: Union[str, Path]) -> np.ndarray:
    """
    Read a txt file and return the first column as a float numpy array.

    With ``infer_schema_length=0`` polars reads every column as text, so the column
    is cast to Float64 to match what the pandas fallback returns.

    Args:
        path (str | Path): Path to the txt file.

    Returns:
        np.ndarray: The first column of the txt file as floats.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    if HAS_POLARS:
        df = pl.read_csv(path, has_header=False, infer_schema_length=0)
        return df[:, 0].cast(pl.Float64).to_numpy()

    df = pd.read_csv(path, header=None, dtype=float)
    return df.iloc[:, 0].to_numpy()


def read_laebl_csv(path: Union[str, Path]) -> pd.DataFrame:
    """
    Read a label CSV file and return it as a pandas DataFrame.

    The ``Start`` and ``End`` columns are converted to ``int``.

    Args:
        path (str | Path): Path to the CSV file.
    Returns:
        pd.DataFrame: The content of the CSV file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # The file contains Chinese text, hence the explicit GBK encoding.
    df = pd.read_csv(path, encoding="gbk")
    df["Start"] = df["Start"].astype(int)
    df["End"] = df["End"].astype(int)
    return df


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
read_label_csv = read_laebl_csv


# ===== utils/__init__.py (empty) =====

# ===== utils/operation_tools.py =====
import time

from pathlib import Path
import numpy as np
import pandas as pd

# Only the font setup below needs matplotlib; the numeric helpers work without it.
try:
    from matplotlib import pyplot as plt
except ImportError:
    plt = None

if plt is not None:
    plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese labels
    plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly

from scipy import ndimage, signal
from functools import wraps


# Global configuration
class Config:
    # When true, functions wrapped by timing_decorator print their run time.
    time_verbose = False


def timing_decorator():
    """Return a decorator that prints the wrapped function's run time when ``Config.time_verbose`` is set."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, unlike the wall clock.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_time = time.perf_counter() - start_time
            if Config.time_verbose:  # checked at call time, so it can be toggled at runtime
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒")
            return result
        return wrapper
    return decorator


@timing_decorator()
def read_auto(file_path):
    """Load a 1-D signal from a ``.txt``, ``.npy`` or ``.base64`` file, chosen by suffix.

    Args:
        file_path: ``str`` or ``Path`` of the file.

    Returns:
        numpy array with the file content (``.txt`` is flattened to 1-D;
        ``.base64`` is parsed as one integer per non-blank line).

    Raises:
        ValueError: If the suffix is not supported.
    """
    file_path = Path(file_path)
    suffix = file_path.suffix
    if suffix == '.txt':
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(str(file_path))
    if suffix == '.base64':
        with open(file_path) as f:
            return np.array([int(line) for line in f if line.strip()], dtype=int)
    raise ValueError('这个文件类型不支持,需要自己写读取程序')


@timing_decorator()
def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Apply a zero-phase Butterworth filter (second-order sections, forward-backward).

    Args:
        data: 1-D array-like signal.
        type: ``"lowpass"`` (cut-off ``low_cut``), ``"bandpass"`` (``low_cut``..``high_cut``)
            or ``"highpass"`` (cut-off ``high_cut``).
        low_cut: Frequency in Hz, see ``type``.
        high_cut: Frequency in Hz, see ``type``.
        order: Filter order.
        sample_rate: Sampling rate in Hz.

    Returns:
        Filtered signal as a numpy array.

    Raises:
        ValueError: If ``type`` is not one of the three supported names.
    """
    nyquist = sample_rate * 0.5
    samples = np.array(data)

    if type == "lowpass":
        sos = signal.butter(order, low_cut / nyquist, btype='lowpass', output='sos')
    elif type == "bandpass":
        sos = signal.butter(order, [low_cut / nyquist, high_cut / nyquist], btype='bandpass', output='sos')
    elif type == "highpass":
        sos = signal.butter(order, high_cut / nyquist, btype='highpass', output='sos')
    else:
        raise ValueError("Please choose a type of fliter")
    return signal.sosfiltfilt(sos, samples)


@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """
    Downsample a long signal by an integer factor, processing it chunk by chunk.

    Chunks are aligned to a multiple of the factor so each chunk's output lands at an
    exact position, and a tail shorter than half a chunk is processed together with
    the preceding chunk (a very short final chunk is too short for the zero-phase
    filter and used to raise).

    Args:
        original_signal: array-like, input signal
        original_fs: float, input sampling rate (Hz)
        target_fs: float, target sampling rate (Hz)
        chunk_size: int, approximate number of input samples per chunk

    Returns:
        numpy array of length ``len(original_signal) // factor``.

    Raises:
        ValueError: If the rates are not positive, the target rate is not lower than
            the input rate, or their ratio is not an integer.
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("目标采样率必须小于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")

    ratio = float(original_fs) / float(target_fs)
    if not ratio.is_integer():
        raise ValueError("降采样因子必须为整数倍")
    downsample_factor = int(ratio)

    total_length = len(original_signal)
    output_length = total_length // downsample_factor
    downsampled_signal = np.zeros(output_length)

    # Chunk length rounded down to a multiple of the factor (at least one factor).
    aligned_chunk = max(downsample_factor, (int(chunk_size) // downsample_factor) * downsample_factor)
    chunk_starts = list(range(0, total_length, aligned_chunk))
    if len(chunk_starts) > 1 and total_length - chunk_starts[-1] < aligned_chunk // 2:
        chunk_starts.pop()  # fold the short tail into the previous chunk

    for index, start in enumerate(chunk_starts):
        end = chunk_starts[index + 1] if index + 1 < len(chunk_starts) else total_length
        chunk_downsampled = signal.decimate(original_signal[start:end], downsample_factor,
                                            ftype='iir', zero_phase=True)

        out_start = start // downsample_factor
        # decimate rounds the output length up; never write past the output array.
        count = min(len(chunk_downsampled), output_length - out_start)
        if count > 0:
            downsampled_signal[out_start:out_start + count] = chunk_downsampled[:count]

    return downsampled_signal


@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
    """Remove the moving-average baseline (window of ``window_size`` seconds, reflected edges).

    Returns:
        ``raw_data`` minus its moving average.
    """
    kernel_length = int(window_size * sample_rate)
    kernel = np.ones(kernel_length) / kernel_length
    baseline = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    return raw_data - baseline


def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """
    Merge runs of state 1 that are separated by a short gap.

    The gap is measured from the time of the last point of one run to the time of
    the first point of the next; runs are merged when it is at most ``max_gap_sec``.

    Args:
        state_sequence: numpy array of 0/1 states
        time_points: numpy array with the time of every state point
        max_gap_sec: float, largest gap (seconds) that is merged

    Returns:
        numpy array with the merged states (the input itself when it has at most one element).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    merged_sequence = state_sequence.copy()

    # Run boundaries of state 1 (inclusive ends).
    transitions = np.diff(np.concatenate([[0], merged_sequence, [0]]))
    state_starts = np.where(transitions == 1)[0]
    state_ends = np.where(transitions == -1)[0] - 1

    for i in range(len(state_starts) - 1):
        run_end = state_ends[i]
        next_start = state_starts[i + 1]
        if run_end < len(time_points) and next_start < len(time_points):
            gap_duration = time_points[next_start] - time_points[run_end]
            if gap_duration <= max_gap_sec:
                merged_sequence[run_end:next_start] = 1

    return merged_sequence


def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """
    Clear runs of state 1 that last less than ``min_duration_sec``.

    A run of ``k`` points lasts ``k`` sampling steps. The previous version counted
    only ``k - 1`` steps except for a run ending at the last point, so runs of
    exactly the minimum duration were removed everywhere but at the end.

    Args:
        state_sequence: numpy array of 0/1 states
        time_points: numpy array with the time of every state point
        min_duration_sec: float, shortest duration (seconds) that is kept

    Returns:
        numpy array with the filtered states (the input itself when it has at most one element).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    filtered_sequence = state_sequence.copy()

    # Run boundaries of state 1 (inclusive ends).
    transitions = np.diff(np.concatenate([[0], filtered_sequence, [0]]))
    state_starts = np.where(transitions == 1)[0]
    state_ends = np.where(transitions == -1)[0] - 1

    # Duration covered by a single point.
    step = time_points[1] - time_points[0] if len(time_points) > 1 else 0

    for run_start, run_end in zip(state_starts, state_ends):
        if run_start < len(time_points) and run_end < len(time_points):
            duration = time_points[run_end] - time_points[run_start] + step
            if duration < min_duration_sec:
                filtered_sequence[run_start:run_end + 1] = 0

    return filtered_sequence


@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Evaluate ``func`` on sliding windows and return one value per second of signal.

    The signal is reflect-padded by half a window on each side, ``func`` is applied to
    every window, and each second takes the value of the window whose step interval
    contains it. Seconds flagged in ``calc_mask`` are set to NaN; the second return
    value has every NaN replaced by linear interpolation between valid neighbours.

    Args:
        func: Callable mapping a 1-D window to a scalar.
        signal_data: 1-D array-like signal.
        calc_mask: Per-second mask (truthy = exclude), or ``None`` for no exclusion.
            Only the first ``len(signal_data) // sampling_rate`` entries are used.
        sampling_rate: Sampling rate in Hz.
        window_second: Window length in seconds.
        step_second: Step in seconds; defaults to ``window_second``.

    Returns:
        ``(values_nan, values)``: per-second values with excluded seconds as NaN, and the
        interpolated version (left as all-NaN when no valid value exists).

    Raises:
        ValueError: If ``calc_mask`` has fewer entries than there are seconds.
    """
    signal_data = np.asarray(signal_data)

    if step_second is None:
        step_second = window_second

    step_length = max(1, int(round(step_second * sampling_rate)))
    window_length = max(1, int(round(window_second * sampling_rate)))

    total_samples = len(signal_data)
    origin_seconds = int(total_samples // sampling_rate)

    if total_samples == 0:
        return np.full(0, np.nan), np.full(0, np.nan)

    if calc_mask is None:
        mask = np.zeros(origin_seconds, dtype=bool)
    else:
        mask = np.asarray(calc_mask).astype(bool)
        if len(mask) < origin_seconds:
            raise ValueError(
                f"calc_mask has {len(mask)} entries but the signal covers {origin_seconds} seconds")
        mask = mask[:origin_seconds]

    # Reflect padding so windows near both ends are full length.
    left_pad_size = window_length // 2
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(total_samples / step_length))
    segment_values = np.full(num_segments, np.nan)
    for i in range(num_segments):
        start = i * step_length
        segment_values[i] = func(data[start:start + window_length])

    # Expand to one value per second; for integer steps this equals
    # ``segment_values.repeat(step_second)[:origin_seconds]``.
    second_to_segment = (np.arange(origin_seconds) * sampling_rate // step_length).astype(int)
    second_to_segment = np.minimum(second_to_segment, num_segments - 1)
    values_nan = segment_values[second_to_segment].astype(float)

    values_nan[mask] = np.nan

    # Fill NaN seconds by interpolation; with nothing valid there is nothing to interpolate from.
    values = values_nan.copy()
    valid = ~np.isnan(values_nan)
    if valid.any():
        seconds = np.arange(len(values_nan))
        values = np.interp(seconds, seconds[valid], values_nan[valid])

    return values_nan, values
file mode 100644 index 0000000..5c06b19 --- /dev/null +++ b/utils/statistics_metrics.py @@ -0,0 +1,105 @@ +from utils.operation_tools import timing_decorator +import numpy as np +import pandas as pd + +@timing_decorator() +def statistic_amplitude_metrics(data, aml_interval=None, time_interval=None): + """ + 计算不同幅值区间占比和时间,最后汇总成混淆矩阵 + + 参数: + data: 采样率为1秒的一维序列,其中体动所在的区域用np.nan填充 + aml_interval: 幅值区间的分界点列表,默认为[200, 500, 1000, 2000] + time_interval: 时间区间的分界点列表,单位为秒,默认为[60, 300, 1800, 3600] + + 返回: + confusion_matrix: 幅值-时长统计矩阵 + summary: 汇总统计信息 + """ + if aml_interval is None: + aml_interval = [200, 500, 1000, 2000] + + if time_interval is None: + time_interval = [60, 300, 1800, 3600] + # 检查输入 + if not isinstance(data, np.ndarray): + data = np.array(data) + + # 整个记录的时长(包括nan) + total_duration = len(data) + + # 创建幅值标签和时间标签 + amp_labels = [f"0-{aml_interval[0]}"] + for i in range(len(aml_interval) - 1): + amp_labels.append(f"{aml_interval[i]}-{aml_interval[i + 1]}") + amp_labels.append(f"{aml_interval[-1]}+") + + time_labels = [f"0-{time_interval[0]}"] + for i in range(len(time_interval) - 1): + time_labels.append(f"{time_interval[i]}-{time_interval[i + 1]}") + time_labels.append(f"{time_interval[-1]}+") + + # 初始化结果矩阵(时长)和片段数矩阵 + result_matrix = np.zeros((len(amp_labels), len(time_labels))) # 时长矩阵 + segment_count_matrix = np.zeros((len(amp_labels), len(time_labels))) # 片段数矩阵 + + # 有效信号总量(非NaN的数据点数量) + valid_signal_length = np.sum(~np.isnan(data)) + + # 添加信号开始和结束的边界条件 + signal_padded = np.concatenate(([np.nan], data, [np.nan])) + diff = np.diff(np.isnan(signal_padded).astype(int)) + + # 连续片段的起始位置(从 nan 变为非 nan) + segment_starts = np.where(diff == -1)[0] + # 连续片段的结束位置(从非 nan 变为 nan) + segment_ends = np.where(diff == 1)[0] + + # 计算每个片段的时长和平均幅值,并填充结果矩阵 + for start, end in zip(segment_starts, segment_ends): + segment = data[start:end] + duration = end - start # 时长(单位:秒) + mean_amplitude = np.nanmean(segment) # 片段平均幅值 + + # 确定幅值区间 + if mean_amplitude <= aml_interval[0]: + 
amp_idx = 0 + elif mean_amplitude > aml_interval[-1]: + amp_idx = len(aml_interval) + else: + amp_idx = np.searchsorted(aml_interval, mean_amplitude) + + # 确定时长区间 + if duration <= time_interval[0]: + time_idx = 0 + elif duration > time_interval[-1]: + time_idx = len(time_interval) + else: + time_idx = np.searchsorted(time_interval, duration) + + # 在对应位置累加该片段的时长和片段数 + result_matrix[amp_idx, time_idx] += duration + segment_count_matrix[amp_idx, time_idx] += 1 # 片段数加1 + + # 创建DataFrame以便于展示和后续处理 + confusion_matrix = pd.DataFrame(result_matrix, index=amp_labels, columns=time_labels) + + # 计算行和列的总和 + confusion_matrix['总计'] = confusion_matrix.sum(axis=1) + row_totals = confusion_matrix['总计'].copy() + + # 计算百分比(相对于有效记录时长) + confusion_matrix_percent = confusion_matrix.div(total_duration) * 100 + + # 汇总统计 + summary = { + 'total_duration': total_duration, + 'total_valid_signal': valid_signal_length, + 'amplitude_distribution': row_totals.to_dict(), + 'amplitude_percent': row_totals.div(total_duration) * 100, + 'time_distribution': confusion_matrix.sum(axis=0).to_dict(), + 'time_percent': confusion_matrix.sum(axis=0).div(total_duration) * 100 + } + + return summary, (confusion_matrix, segment_count_matrix, confusion_matrix_percent, valid_signal_length, + total_duration, time_labels, amp_labels) From d2ed6787d48cceac8a72440332b1dd212857d203 Mon Sep 17 00:00:00 2001 From: marques Date: Sun, 12 Oct 2025 18:42:29 +0800 Subject: [PATCH 04/28] Add utility functions for signal processing and configuration management --- HYS_process.py | 59 +++++++++++++++++++++- dataset_config/HYS_config.yaml | 13 +++++ utils/HYS_FileReader.py | 90 +++++++++++++++++++++++++++++++++- utils/__init__.py | 2 + utils/operation_tools.py | 31 +++++++++++- 5 files changed, 191 insertions(+), 4 deletions(-) create mode 100644 dataset_config/HYS_config.yaml diff --git a/HYS_process.py b/HYS_process.py index 1f07192..f3de6d8 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -12,10 +12,65 @@ 
提供数据处理前后的可视化对比,帮助理解数据变化 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 - +todo: 使用mask 屏蔽无用区间 # 低幅值区间规则标定与剔除 # 高幅值连续体动规则标定与剔除 # 手动标定不可用区间提剔除 -""" \ No newline at end of file +""" + +from pathlib import Path +from typing import Union +import utils +import numpy as np + + + + +def process_one_signal(samp_id): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv") + if not label_path: + raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") + label_path = list(label_path)[0] + print(f"Processing Label_corrected file: {label_path}") + + + signal_data = utils.read_signal_txt(signal_path) + signal_length = len(signal_data) + print(f"signal_length: {signal_length}") + signal_fs = int(signal_path.stem.split("_")[-1]) + print(f"signal_fs: {signal_fs}") + signal_second = signal_length // signal_fs + print(f"signal_second: {signal_second}") + + + label_data = utils.read_label_csv(label_path) + + manual_disable_mask = utils.generate_disable_mask(signal_second, all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) + print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") + + + + + +if __name__ == '__main__': + yaml_path = Path("./dataset_config/HYS_config.yaml") + disable_df_path = Path("./排除区间.xlsx") + + select_ids, root_path = utils.load_dataset_info(yaml_path) + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + label_root_path = root_path / "Label" + + all_samp_disable_df = utils.read_disable_excel(disable_df_path) + + process_one_signal(select_ids[0]) diff --git a/dataset_config/HYS_config.yaml 
b/dataset_config/HYS_config.yaml new file mode 100644 index 0000000..d30264f --- /dev/null +++ b/dataset_config/HYS_config.yaml @@ -0,0 +1,13 @@ +select_id: + - 1302 + - 286 + - 950 + - 220 + - 229 + - 541 + - 582 + - 670 + - 684 + - 960 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS \ No newline at end of file diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index 82e4584..d7c477a 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -34,7 +34,7 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray: return df.iloc[:, 0].to_numpy() -def read_laebl_csv(path: Union[str, Path]) -> pd.DataFrame: +def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: """ Read a CSV file and return it as a pandas DataFrame. @@ -49,6 +49,94 @@ def read_laebl_csv(path: Union[str, Path]) -> pd.DataFrame: # 直接用pandas读取 包含中文 故指定编码 df = pd.read_csv(path, encoding="gbk") + if verbose: + print(f"Label file read from {path}, number of rows: {len(df)}") + + # 统计打标情况 + # isLabeled=1 表示已打标 + # Event type 有值的为PSG导出的事件 + # Event type 为nan的为手动打标的事件 + # score=1 显著事件, score=2 为受干扰事件 score=3 为非显著应删除事件 + # 确认后的事件在correct_EventsType + # 输出事件信息 按照总计事件、低通气、中枢性、阻塞性、混合型按行输出 格式为 总计/来自PSG/手动/删除/未标注 + # Columns: + # Index Event type Stage Time Epoch Date Duration HR bef. HR extr. HR delta O2 bef. O2 min. 
O2 delta Body Position Validation Start End score remark correct_Start correct_End correct_EventsType isLabeled + # Event type: + # Hypopnea + # Central apnea + # Obstructive apnea + # Mixed apnea + + num_labeled = np.sum(df["isLabeled"] == 1) + num_psg_events = np.sum(df["Event type"].notna()) + num_manual_events = num_labeled - num_psg_events + num_deleted = np.sum(df["score"] == 3) + + # 统计事件 + num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3)) + num_unlabeled = num_total - num_labeled + + num_psg_hyp = np.sum(df["Event type"] == "Hypopnea") + num_psg_csa = np.sum(df["Event type"] == "Central apnea") + num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea") + num_psg_msa = np.sum(df["Event type"] == "Mixed apnea") + + num_hyp = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] != 3)) + num_csa = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] != 3)) + num_osa = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] != 3)) + num_msa = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] != 3)) + + num_manual_hyp = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Hypopnea")) + num_manual_csa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Central apnea")) + num_manual_osa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Obstructive apnea")) + num_manual_msa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Mixed apnea")) + + num_deleted_hyp = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Hypopnea")) + num_deleted_csa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Central apnea")) + num_deleted_osa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Obstructive apnea")) + num_deleted_msa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Mixed apnea")) + + num_unlabeled_hyp = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Hypopnea")) + num_unlabeled_csa = 
np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Central apnea")) + num_unlabeled_osa = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Obstructive apnea")) + num_unlabeled_msa = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Mixed apnea")) + + + + if verbose: + print("Event Statistics:") + # 格式化输出 总计/来自PSG/手动/删除/未标注 指定宽度 + print("Type Total / PSG / Manual / Deleted / Unlabeled") + print(f"Hypopnea: {num_hyp:4d} / {num_psg_hyp:4d} / {num_manual_hyp:4d} / {num_deleted_hyp:4d} / {num_unlabeled_hyp:4d}") + print(f"Central apnea: {num_csa:4d} / {num_psg_csa:4d} / {num_manual_csa:4d} / {num_deleted_csa:4d} / {num_unlabeled_csa:4d}") + print(f"Obstructive ap: {num_osa:4d} / {num_psg_osa:4d} / {num_manual_osa:4d} / {num_deleted_osa:4d} / {num_unlabeled_osa:4d}") + print(f"Mixed apnea: {num_msa:4d} / {num_psg_msa:4d} / {num_manual_msa:4d} / {num_deleted_msa:4d} / {num_unlabeled_msa:4d}") + print(f"Total events: {num_total:4d} / {num_psg_events:4d} / {num_manual_events:4d} / {num_deleted:4d} / {num_unlabeled:4d}") + + + + df["Start"] = df["Start"].astype(int) df["End"] = df["End"].astype(int) + return df + + +def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame: + """ + Read an Excel file and return it as a pandas DataFrame. + + Args: + path (str | Path): Path to the Excel file. + Returns: + pd.DataFrame: The content of the Excel file as a pandas DataFrame. 
+ """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # 直接用pandas读取 + df = pd.read_excel(path) + df["id"] = df["id"].astype(int) + df["start"] = df["start"].astype(int) + df["end"] = df["end"].astype(int) return df \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py index e69de29..d2c0727 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -0,0 +1,2 @@ +from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel +from utils.operation_tools import load_dataset_info, generate_disable_mask \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index bcc5062..a4118b9 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -4,7 +4,7 @@ from pathlib import Path import numpy as np import pandas as pd from matplotlib import pyplot as plt - +import yaml plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 @@ -252,5 +252,34 @@ def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, return values_nan, values +def load_dataset_info(yaml_path): + with open(yaml_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + select_ids = config.get('select_id', []) + root_path = config.get('root_path', None) + data_path = Path(root_path) + return select_ids, data_path + + +def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: + disable_mask = np.ones(signal_second, dtype=int) + + for _, row in disable_df.iterrows(): + start = row["start"] + end = row["end"] + disable_mask[start:end] = 0 + return disable_mask + + +def generate_event_mask(signal_second: int, event_df) -> np.ndarray: + event_mask = np.zeros(signal_second, dtype=int) + + for _, row in event_df.iterrows(): + start = row["start"] + end = row["end"] + event_mask[start:end] = 1 + return event_mask + From 180d872cd7ce6aacf609921d3118e7d67c16a7c6 Mon Sep 
17 00:00:00 2001 From: marques Date: Sun, 12 Oct 2025 20:30:46 +0800 Subject: [PATCH 05/28] Refactor event statistics calculations and improve output formatting --- utils/HYS_FileReader.py | 69 +++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index d7c477a..5812da8 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -7,6 +7,7 @@ import pandas as pd # 尝试导入 Polars try: import polars as pl + HAS_POLARS = True except ImportError: HAS_POLARS = False @@ -67,14 +68,15 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: # Obstructive apnea # Mixed apnea - num_labeled = np.sum(df["isLabeled"] == 1) + num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3)) + num_psg_events = np.sum(df["Event type"].notna()) - num_manual_events = num_labeled - num_psg_events + num_manual_events = np.sum(df["Event type"].isna()) + num_deleted = np.sum(df["score"] == 3) # 统计事件 - num_total = np.sum((df["isLabeled"] == 1) & (df["score"] != 3)) - num_unlabeled = num_total - num_labeled + num_unlabeled = np.sum(df["isLabeled"] == -1) num_psg_hyp = np.sum(df["Event type"] == "Hypopnea") num_psg_csa = np.sum(df["Event type"] == "Central apnea") @@ -82,9 +84,9 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: num_psg_msa = np.sum(df["Event type"] == "Mixed apnea") num_hyp = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] != 3)) - num_csa = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] != 3)) - num_osa = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] != 3)) - num_msa = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] != 3)) + num_csa = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] != 3)) + num_osa = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] != 3)) + num_msa = np.sum((df["correct_EventsType"] 
== "Mixed apnea") & (df["score"] != 3)) num_manual_hyp = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Hypopnea")) num_manual_csa = np.sum((df["Event type"].isna()) & (df["correct_EventsType"] == "Central apnea")) @@ -96,25 +98,52 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: num_deleted_osa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Obstructive apnea")) num_deleted_msa = np.sum((df["score"] == 3) & (df["correct_EventsType"] == "Mixed apnea")) - num_unlabeled_hyp = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Hypopnea")) - num_unlabeled_csa = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Central apnea")) - num_unlabeled_osa = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Obstructive apnea")) - num_unlabeled_msa = np.sum((df["isLabeled"] == 0) & (df["correct_EventsType"] == "Mixed apnea")) + num_unlabeled_hyp = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Hypopnea")) + num_unlabeled_csa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Central apnea")) + num_unlabeled_osa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Obstructive apnea")) + num_unlabeled_msa = np.sum((df["isLabeled"] == -1) & (df["Event type"] == "Mixed apnea")) + num_hyp_1_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 1)) + num_csa_1_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 1)) + num_osa_1_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 1)) + num_msa_1_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 1)) + num_hyp_2_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 2)) + num_csa_2_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 2)) + num_osa_2_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 2)) + num_msa_2_score = 
np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 2)) + + num_hyp_3_score = np.sum((df["correct_EventsType"] == "Hypopnea") & (df["score"] == 3)) + num_csa_3_score = np.sum((df["correct_EventsType"] == "Central apnea") & (df["score"] == 3)) + num_osa_3_score = np.sum((df["correct_EventsType"] == "Obstructive apnea") & (df["score"] == 3)) + num_msa_3_score = np.sum((df["correct_EventsType"] == "Mixed apnea") & (df["score"] == 3)) + + num_1_score = np.sum(df["score"] == 1) + num_2_score = np.sum(df["score"] == 2) + num_3_score = np.sum(df["score"] == 3) if verbose: print("Event Statistics:") # 格式化输出 总计/来自PSG/手动/删除/未标注 指定宽度 - print("Type Total / PSG / Manual / Deleted / Unlabeled") - print(f"Hypopnea: {num_hyp:4d} / {num_psg_hyp:4d} / {num_manual_hyp:4d} / {num_deleted_hyp:4d} / {num_unlabeled_hyp:4d}") - print(f"Central apnea: {num_csa:4d} / {num_psg_csa:4d} / {num_manual_csa:4d} / {num_deleted_csa:4d} / {num_unlabeled_csa:4d}") - print(f"Obstructive ap: {num_osa:4d} / {num_psg_osa:4d} / {num_manual_osa:4d} / {num_deleted_osa:4d} / {num_unlabeled_osa:4d}") - print(f"Mixed apnea: {num_msa:4d} / {num_psg_msa:4d} / {num_manual_msa:4d} / {num_deleted_msa:4d} / {num_unlabeled_msa:4d}") - print(f"Total events: {num_total:4d} / {num_psg_events:4d} / {num_manual_events:4d} / {num_deleted:4d} / {num_unlabeled:4d}") - - + print(f"Type {'Total':^8s} / {'From PSG':^8s} / {'Manual':^8s} / {'Deleted':^8s} / {'Unlabeled':^8s}") + print( + f"Hyp: {num_hyp:^8d} / {num_psg_hyp:^8d} / {num_manual_hyp:^8d} / {num_deleted_hyp:^8d} / {num_unlabeled_hyp:^8d}") + print( + f"CSA: {num_csa:^8d} / {num_psg_csa:^8d} / {num_manual_csa:^8d} / {num_deleted_csa:^8d} / {num_unlabeled_csa:^8d}") + print( + f"OSA: {num_osa:^8d} / {num_psg_osa:^8d} / {num_manual_osa:^8d} / {num_deleted_osa:^8d} / {num_unlabeled_osa:^8d}") + print( + f"MSA: {num_msa:^8d} / {num_psg_msa:^8d} / {num_manual_msa:^8d} / {num_deleted_msa:^8d} / {num_unlabeled_msa:^8d}") + print( + f"All: {num_total:^8d} / 
{num_psg_events:^8d} / {num_manual_events:^8d} / {num_deleted:^8d} / {num_unlabeled:^8d}") + print("Score Statistics (only for non-deleted events and manual created events):") + print(f"Type {'Total':^8s} / {'Score 1':^8s} / {'Score 2':^8s} / {'Score 3':^8s}") + print(f"Hyp: {num_hyp:^8d} / {num_hyp_1_score:^8d} / {num_hyp_2_score:^8d} / {num_hyp_3_score:^8d}") + print(f"CSA: {num_csa:^8d} / {num_csa_1_score:^8d} / {num_csa_2_score:^8d} / {num_csa_3_score:^8d}") + print(f"OSA: {num_osa:^8d} / {num_osa_1_score:^8d} / {num_osa_2_score:^8d} / {num_osa_3_score:^8d}") + print(f"MSA: {num_msa:^8d} / {num_msa_1_score:^8d} / {num_msa_2_score:^8d} / {num_msa_3_score:^8d}") + print(f"All: {num_total:^8d} / {num_1_score:^8d} / {num_2_score:^8d} / {num_3_score:^8d}") df["Start"] = df["Start"].astype(int) df["End"] = df["End"].astype(int) @@ -139,4 +168,4 @@ def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame: df["id"] = df["id"].astype(int) df["start"] = df["start"].astype(int) df["end"] = df["end"].astype(int) - return df \ No newline at end of file + return df From 0900a9b489d7cf4638cf1bf9c0002fef344f3651 Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 16 Oct 2025 19:21:05 +0800 Subject: [PATCH 06/28] Add event mask generation and sliding window segmentation for signal analysis --- HYS_process.py | 5 ++++ signal_method/rule_base_event.py | 2 -- utils/__init__.py | 3 +- utils/event_map.py | 7 +++++ utils/operation_tools.py | 49 ++++++++++++++++++++++++++++---- 5 files changed, 57 insertions(+), 9 deletions(-) create mode 100644 utils/event_map.py diff --git a/HYS_process.py b/HYS_process.py index f3de6d8..c0f4ddb 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -52,6 +52,7 @@ def process_one_signal(samp_id): label_data = utils.read_label_csv(label_path) + label_mask = utils.generate_event_mask(signal_second, label_data) manual_disable_mask = utils.generate_disable_mask(signal_second, all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) 
print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") @@ -60,6 +61,10 @@ def process_one_signal(samp_id): + + + + if __name__ == '__main__': yaml_path = Path("./dataset_config/HYS_config.yaml") disable_df_path = Path("./排除区间.xlsx") diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 74e5929..ddeeaf3 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -201,7 +201,5 @@ def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=1 else: position_changes.append(0) # 0表示不存在姿势变化 - # print(i,movement_start[i], movement_end[i], round(left_amplitude_change, 2), round(right_amplitude_change, 2), round(left_energy_change, 2), round(right_energy_change, 2)) - return position_changes, position_change_times diff --git a/utils/__init__.py b/utils/__init__.py index d2c0727..87cc5b9 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,2 +1,3 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_info, generate_disable_mask \ No newline at end of file +from utils.operation_tools import load_dataset_info, generate_disable_mask, generate_event_mask +from utils.event_map import E2N \ No newline at end of file diff --git a/utils/event_map.py b/utils/event_map.py new file mode 100644 index 0000000..c85a027 --- /dev/null +++ b/utils/event_map.py @@ -0,0 +1,7 @@ +# apnea event type to number mapping +E2N = { + "Hypopnea": 1, + "Central apnea": 2, + "Obstructive apnea": 3, + "Mixed apnea": 4 +} \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index a4118b9..a63b2ad 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd from matplotlib import pyplot as plt import yaml - +from utils.event_map import E2N plt.rcParams['font.sans-serif'] = ['SimHei'] # 
用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 @@ -272,14 +272,51 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: return disable_mask -def generate_event_mask(signal_second: int, event_df) -> np.ndarray: +def generate_event_mask(signal_second: int, event_df): event_mask = np.zeros(signal_second, dtype=int) + score_mask = np.zeros(signal_second, dtype=int) + # 剔除start = -1 的行 + event_df = event_df[event_df["correct_Start"] >= 0] for _, row in event_df.iterrows(): - start = row["start"] - end = row["end"] - event_mask[start:end] = 1 - return event_mask + start = row["correct_Start"] + end = row["correct_End"] + 1 + event_mask[start:end] = E2N[row["correct_EventsType"]] + score_mask[start:end] = row["score"] + return event_mask, score_mask + + +def slide_window_segment(signal_second: int, window_second, step_second, event_mask, score_mask, disable_mask, ): + # 避开不可用区域进行滑窗分割 + # 滑动到不可用区域时,如果窗口内一侧的不可用区域不超过1/2 windows_second,则继续滑动, 用reflect填充 + # 如果不可用区间大于1/2的window_second,则跳过该不可用区间,继续滑动 + # TODO 对于短时强体动区间 考虑填充或者掩码覆盖 + # + half_window_second = window_second // 2 + for start_second in range(0, signal_second - window_second + 1, step_second): + end_second = start_second + window_second + + # 检查当前窗口是否包含不可用区域 + windows_middle_second = (start_second + end_second) // 2 + if np.sum(disable_mask[start_second:end_second] > 1) > half_window_second: + # 如果窗口内不可用区域超过一半,跳过该窗口 + continue + + if disable_mask[start_second:end_second] > half_window_second: + + + + + + # 确保新的起始位置不超过信号长度 + if start_second + window_second > signal_second: + break + + window_event = event_mask[start_second:end_second] + window_score = score_mask[start_second:end_second] + window_disable = disable_mask[start_second:end_second] + + yield start_second, end_second, window_event, window_score, window_disable From ddedfbf2cf65e0e0e183f289d0d56decfb24dce8 Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 16 Oct 2025 19:58:13 +0800 Subject: [PATCH 07/28] Enhance 
position-based sleep recognition with version updates and amplitude calculations --- signal_method/rule_base_event.py | 108 ++++++++++++++++++++++++++----- utils/operation_tools.py | 2 - 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index ddeeaf3..e9ad3df 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -2,6 +2,7 @@ from utils.operation_tools import timing_decorator import numpy as np from utils.operation_tools import merge_short_gaps, remove_short_durations + @timing_decorator() def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): @@ -110,8 +111,9 @@ def get_typical_segment_for_continues_signal(signal_data, sampling_rate=100, win # 仅对比相邻片段的幅值指标,如果存在显著差异,则认为存在睡姿变化,即每个体动相邻的30秒内存在睡姿变化,如果片段不足30秒,则按实际长度对比 @timing_decorator() -def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=100, window_size_sec=30, - interval_to_movement=10): +def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rate=100, window_size_sec=30, + interval_to_movement=10): + mav_calc_window_sec = 2 # 计算mav的窗口大小,单位秒 # 获取有效片段起止位置 valid_mask = 1 - movement_mask valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0] @@ -150,17 +152,21 @@ def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=1 right_end = end # 新的end - start确保为200的整数倍 - if (left_end - left_start) % (2 * sampling_rate) != 0: - left_end = left_start + ((left_end - left_start) // (2 * sampling_rate)) * (2 * sampling_rate) - if (right_end - right_start) % (2 * sampling_rate) != 0: - right_end = right_start + ((right_end - right_start) // (2 * sampling_rate)) * (2 * sampling_rate) + if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0: + left_end = left_start + ((left_end - left_start) // 
(mav_calc_window_sec * sampling_rate)) * ( + mav_calc_window_sec * sampling_rate) + if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0: + right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * ( + mav_calc_window_sec * sampling_rate) # 计算每个片段的幅值指标 - left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, 2 * sampling_rate), axis=0)) - np.mean( - np.min(signal_data[left_start:left_end].reshape(-1, 2 * sampling_rate), axis=0)) + left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.mean( + np.min(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) right_mav = np.mean( - np.max(signal_data[right_start:right_end].reshape(-1, 2 * sampling_rate), axis=0)) - np.mean( - np.min(signal_data[right_start:right_end].reshape(-1, 2 * sampling_rate), axis=0)) + np.max(signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.mean( + np.min(signal_data[right_start:right_end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) segment_left_average_amplitude.append(left_mav) segment_right_average_amplitude.append(right_mav) @@ -171,6 +177,10 @@ def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=1 position_changes = [] position_change_times = [] + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + for i in range(1, len(segment_left_average_amplitude)): # 计算幅值指标的变化率 left_amplitude_change = abs(segment_left_average_amplitude[i] - segment_left_average_amplitude[i - 1]) / max( @@ -184,15 +194,11 @@ def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=1 right_energy_change = abs(segment_right_average_energy[i] - segment_right_average_energy[i - 1]) / max( segment_right_average_energy[i - 1], 1e-6) - # 判断是否存在显著变化 (可根据实际情况调整阈值) - threshold_amplitude = 
0.1 # 幅值变化阈值 - threshold_energy = 0.1 # 能量变化阈值 - # 如果左右通道中的任一通道同时满足幅值和能量的变化阈值,则认为存在姿势变化 left_significant_change = (left_amplitude_change > threshold_amplitude) and ( - left_energy_change > threshold_energy) + left_energy_change > threshold_energy) right_significant_change = (right_amplitude_change > threshold_amplitude) and ( - right_energy_change > threshold_energy) + right_energy_change > threshold_energy) if left_significant_change or right_significant_change: # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 @@ -203,3 +209,73 @@ def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=1 return position_changes, position_change_times + +def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100, window_size_sec=30): + """ + + :param signal_data: + :param movement_mask: mask的采样率为1Hz + :param sampling_rate: + :param window_size_sec: + :return: + """ + mav_calc_window_sec = 2 # 计算mav的窗口大小,单位秒 + # 获取有效片段起止位置 + valid_mask = 1 - movement_mask + valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0] + valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0] + + # 对于有效区间大于12分钟的,拆成多个5分钟 + + movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] + movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + + segment_average_amplitude = [] + segment_average_energy = [] + + for start, end in zip(valid_starts, valid_ends): + start *= sampling_rate + end *= sampling_rate + # 避免过短的片段 + if end - start <= sampling_rate: # 小于1秒的片段不考虑 + continue + + # 新的end - start确保为200的整数倍 + if (end - start) % (mav_calc_window_sec * sampling_rate) != 0: + end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * ( + mav_calc_window_sec * sampling_rate) + + # 计算每个片段的幅值指标 + mav = np.mean( + np.max(signal_data[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.mean( + np.min(signal_data[start:end].reshape(-1, mav_calc_window_sec * 
sampling_rate), axis=0)) + segment_average_amplitude.append(mav) + + energy = np.sum(np.abs(signal_data[start:end] ** 2)) + segment_average_energy.append(energy) + + position_changes = [] + position_change_times = [] + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + + for i in range(1, len(segment_average_amplitude)): + # 计算幅值指标的变化率 + amplitude_change = abs(segment_average_amplitude[i] - segment_average_amplitude[i - 1]) / max( + segment_average_amplitude[i - 1], 1e-6) + + # 计算能量指标的变化率 + energy_change = abs(segment_average_energy[i] - segment_average_energy[i - 1]) / max( + segment_average_energy[i - 1], 1e-6) + + significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + + if significant_change: + # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 + position_changes.append(1) + position_change_times.append((movement_start[i - 1], movement_end[i - 1])) + else: + position_changes.append(0) # 0表示不存在姿势变化 + + return position_changes, position_change_times diff --git a/utils/operation_tools.py b/utils/operation_tools.py index a63b2ad..c83c82e 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -318,5 +318,3 @@ def slide_window_segment(signal_second: int, window_second, step_second, event_m yield start_second, end_second, window_event, window_score, window_disable - - From f79f42fae7999a25134e4d31721012dce4f36bfe Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 16 Oct 2025 20:46:42 +0800 Subject: [PATCH 08/28] Remove unnecessary comment regarding splitting valid intervals in rule_base_event.py --- signal_method/rule_base_event.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index e9ad3df..95a540f 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -225,8 +225,6 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat valid_starts = 
np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0] valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0] - # 对于有效区间大于12分钟的,拆成多个5分钟 - movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] From 40aad46d6f042d03226a3032a462888d859a0ed7 Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 23 Oct 2025 15:43:28 +0800 Subject: [PATCH 09/28] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E5=A4=9A?= =?UTF-8?q?=E4=B8=AA=E6=96=87=E4=BB=B6=EF=BC=8C=E5=AE=8C=E6=88=90=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E8=AF=BB=E5=8F=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HYS_process.py | 40 +++++++- dataset_config/HYS_config.yaml | 24 ++++- signal_method/__init__.py | 1 + signal_method/rule_base_event.py | 169 +++++++++++++++++++++++++++++++ utils/HYS_FileReader.py | 5 +- utils/__init__.py | 5 +- utils/operation_tools.py | 122 +--------------------- utils/signal_process.py | 92 +++++++++++++++++ 8 files changed, 331 insertions(+), 127 deletions(-) create mode 100644 utils/signal_process.py diff --git a/HYS_process.py b/HYS_process.py index c0f4ddb..9422d9c 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -24,7 +24,7 @@ from pathlib import Path from typing import Union import utils import numpy as np - +import signal_method @@ -50,13 +50,40 @@ def process_one_signal(samp_id): signal_second = signal_length // signal_fs print(f"signal_second: {signal_second}") + # 滤波 + # 50Hz陷波滤波器 + # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) + resp_data = utils.butterworth(data=signal_data, _type=conf["resp"]["filter_type"], low_cut=conf["resp"]["low_cut"], + high_cut=conf["resp"]["high_cut"], order=conf["resp"]["order"], sample_rate=signal_fs) - label_data = utils.read_label_csv(label_path) - label_mask = 
utils.generate_event_mask(signal_second, label_data) + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg"]["filter_type"], low_cut=conf["bcg"]["low_cut"], + high_cut=conf["bcg"]["high_cut"], order=conf["bcg"]["order"], sample_rate=signal_fs) - manual_disable_mask = utils.generate_disable_mask(signal_second, all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) + + label_data = utils.read_label_csv(path=label_path) + label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + + manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") + # 分析Resp的低幅值区间 + resp_low_amp_conf = getattr(conf, "resp_low_amp", None) + if resp_low_amp_conf is not None: + resp_low_amp_mask = signal_method.detect_low_amplitude_signal( + signal_data=resp_data, + sampling_rate=signal_fs, + window_size_sec=resp_low_amp_conf["window_size_sec"], + stride_sec=resp_low_amp_conf["stride_sec"], + amplitude_threshold=resp_low_amp_conf["amplitude_threshold"], + merge_gap_sec=resp_low_amp_conf["merge_gap_sec"], + min_duration_sec=resp_low_amp_conf["min_duration_sec"] + ) + else: + resp_low_amp_mask = None + + # 分析Resp的高幅值伪迹区间 + resp_move + @@ -69,7 +96,10 @@ if __name__ == '__main__': yaml_path = Path("./dataset_config/HYS_config.yaml") disable_df_path = Path("./排除区间.xlsx") - select_ids, root_path = utils.load_dataset_info(yaml_path) + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index d30264f..dfff364 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -1,4 +1,4 @@ -select_id: +select_ids: - 1302 - 286 - 950 @@ -10,4 
+10,24 @@ select_id: - 684 - 960 -root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS \ No newline at end of file +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS + +resp_filter: + filter_type: bandpass + low_cut: 0.01 + high_cut: 0.7 + order: 10 + +resp_low_amp: + windows_size_sec: 1 + stride_sec: None + amplitude_threshold: 50 + merge_gap_sec: 10 + min_duration_sec: 5 + +bcg_filter: + filter_type: bandpass + low_cut: 1 + high_cut: 10 + order: 10 + diff --git a/signal_method/__init__.py b/signal_method/__init__.py index e69de29..46eac36 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -0,0 +1 @@ +from signal_method.rule_base_event import detect_low_amplitude_signal \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 95a540f..8de49da 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -3,6 +3,175 @@ import numpy as np from utils.operation_tools import merge_short_gaps, remove_short_durations +@timing_decorator() +def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=None, + std_median_multiplier=4.5, compare_intervals_sec=[30, 60], + interval_multiplier=2.5, + merge_gap_sec=10, min_duration_sec=5, + low_amplitude_periods=None): + """ + 检测信号中的体动状态,结合两种方法:标准差比较和前后窗口幅值对比。 + 使用反射填充处理信号边界。 + + 参数: + - signal_data: numpy array,输入的信号数据 + - sampling_rate: int,信号的采样率(Hz) + - window_size_sec: float,分析窗口的时长(秒),默认值为 2 秒 + - stride_sec: float,窗口滑动步长(秒),默认值为None(等于window_size_sec,无重叠) + - std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5 + - compare_intervals_sec: list,用于比较的时间间隔列表(秒),默认为 [30, 60] + - interval_multiplier: float,间隔中位数的上限乘数,默认值为 2.5 + - merge_gap_sec: float,要合并的体动状态之间的最大间隔(秒),默认值为 10 秒 + - min_duration_sec: float,要保留的体动状态的最小持续时间(秒),默认值为 5 秒 + - low_amplitude_periods: numpy array,低幅值期间的掩码(1表示低幅值期间),默认为None + + 返回: + - movement_mask: numpy array,体动状态的掩码(1表示体动,0表示睡眠) + """ + # 计算窗口大小(样本数) + 
window_samples = int(window_size_sec * sampling_rate) + + # 如果未指定步长,设置为窗口大小(无重叠) + if stride_sec is None: + stride_sec = window_size_sec + + # 计算步长(样本数) + stride_samples = int(stride_sec * sampling_rate) + + # 确保步长至少为1 + stride_samples = max(1, stride_samples) + + # 计算需要的最大填充大小(基于比较间隔) + max_interval_samples = int(max(compare_intervals_sec) * sampling_rate) + + # 应用反射填充以正确处理边界 + # 填充大小为最大比较间隔的一半,以确保边界有足够的上下文 + pad_size = max_interval_samples + padded_signal = np.pad(signal_data, pad_size, mode='reflect') + + # 计算填充后的窗口数量 + num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1) + + # 初始化窗口标准差数组 + window_std = np.zeros(num_windows) + # 计算每个窗口的标准差 + # 分窗计算标准差 + for i in range(num_windows): + start_idx = i * stride_samples + end_idx = min(start_idx + window_samples, len(padded_signal)) + + # 处理窗口,包括可能不完整的最后一个窗口 + window_data = padded_signal[start_idx:end_idx] + if len(window_data) > 0: + window_std[i] = np.std(window_data, ddof=1) + else: + window_std[i] = 0 + + # 计算原始信号对应的窗口索引范围 + # 填充后,原始信号从pad_size开始 + orig_start_window = pad_size // stride_samples + if stride_sec == 1: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + else: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1 + + # 只保留原始信号对应的窗口标准差 + original_window_std = window_std[orig_start_window:orig_end_window] + num_original_windows = len(original_window_std) + + # 创建时间点数组(秒) + time_points = np.arange(num_original_windows) * stride_sec + + # # 如果提供了低幅值期间的掩码,则在计算全局中位数时排除这些期间 + # if low_amplitude_periods is not None and len(low_amplitude_periods) == num_original_windows: + # valid_std = original_window_std[low_amplitude_periods == 0] + # if len(valid_std) == 0: # 如果所有窗口都在低幅值期间 + # valid_std = original_window_std # 使用全部窗口 + # else: + # valid_std = original_window_std + + valid_std = original_window_std ##20250418新修改 + + #---------------------- 方法一:基于STD的体动判定 ----------------------# + # 计算所有有效窗口标准差的中位数 + median_std = 
np.median(valid_std) + + # 当窗口标准差大于中位数的倍数,判定为体动状态 + std_movement = np.where(original_window_std > median_std * std_median_multiplier, 1, 0) + + #------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# + amplitude_movement = np.zeros(num_original_windows, dtype=int) + + # 定义基于时间粒度的比较间隔索引 + compare_intervals_idx = [int(interval // stride_sec) for interval in compare_intervals_sec] + + # 逐窗口判断 + for win_idx in range(num_original_windows): + # 全局索引(在填充后的窗口数组中) + global_win_idx = win_idx + orig_start_window + + # 对每个比较间隔进行检查 + for interval_idx in compare_intervals_idx: + # 确定比较范围的结束索引(在填充后的窗口数组中) + end_idx = min(global_win_idx + interval_idx, len(window_std)) + + # 提取相应时间范围内的标准差值 + if global_win_idx < end_idx: + interval_std = window_std[global_win_idx:end_idx] + + # 计算该间隔的中位数 + interval_median = np.median(interval_std) + + # 计算上下阈值 + upper_threshold = interval_median * interval_multiplier + + # 检查当前窗口是否超出阈值范围,如果超出则直接标记为体动 + if window_std[global_win_idx] > upper_threshold: + amplitude_movement[win_idx] = 1 + break # 一旦确定为体动就不需要继续检查其他间隔 + + # 将两种方法的结果合并:只要其中一种方法判定为体动,最终结果就是体动 + movement_mask = np.logical_or(std_movement, amplitude_movement).astype(int) + raw_movement_mask = movement_mask + + # 如果需要合并间隔小的体动状态 + if merge_gap_sec > 0: + movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec) + + # 如果需要移除短时体动状态 + if min_duration_sec > 0: + movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec) + + # raw_movement_mask, movement_mask恢复对应秒数,而不是点数 + raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + + + # 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除 + removed_movement_mask = (raw_movement_mask - movement_mask) > 0 + removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0] + removed_movement_end = np.where(np.diff(np.concatenate([removed_movement_mask, [0]])) == -1)[0] 
+ + for start, end in zip(removed_movement_start, removed_movement_end): + # print(start ,end) + # 计算剔除的体动区域的幅值 + if np.nanmax(signal_data[start*sampling_rate:(end+1)*sampling_rate]) > median_std * std_median_multiplier: + movement_mask[start:end+1] = 1 + + # raw体动起止位置 [[start, end], [start, end], ...] + raw_movement_start = np.where(np.diff(np.concatenate([[0], raw_movement_mask])) == 1)[0] + raw_movement_end = np.where(np.diff(np.concatenate([raw_movement_mask, [0]])) == -1)[0] + 1 + raw_movement_position_list = [[start, end] for start, end in zip(raw_movement_start, raw_movement_end)] + + # merge体动起止位置 [[start, end], [start, end], ...] + movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] + movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + 1 + movement_position_list = [[start, end] for start, end in zip(movement_start, movement_end)] + + return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list + + + @timing_decorator() def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index 5812da8..dd65bab 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -29,7 +29,7 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray: if HAS_POLARS: df = pl.read_csv(path, has_header=False, infer_schema_length=0) - return df[:, 0].to_numpy() + return df[:, 0].to_numpy().astype(float) else: df = pd.read_csv(path, header=None, dtype=float) return df.iloc[:, 0].to_numpy() @@ -41,8 +41,11 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: Args: path (str | Path): Path to the CSV file. + verbose (bool): Returns: pd.DataFrame: The content of the CSV file as a pandas DataFrame. 
+ :param path: + :param verbose: """ path = Path(path) if not path.exists(): diff --git a/utils/__init__.py b/utils/__init__.py index 87cc5b9..faaebaf 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,4 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_info, generate_disable_mask, generate_event_mask -from utils.event_map import E2N \ No newline at end of file +from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask +from utils.event_map import E2N +from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index c83c82e..f775d73 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -47,85 +47,6 @@ def read_auto(file_path): else: raise ValueError('这个文件类型不支持,需要自己写读取程序') -@timing_decorator() -def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): - - if type == "lowpass": # 低通滤波处理 - sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') - return signal.sosfiltfilt(sos, np.array(data)) - elif type == "bandpass": # 带通滤波处理 - low = low_cut / (sample_rate * 0.5) - high = high_cut / (sample_rate * 0.5) - sos = signal.butter(order, [low, high], btype='bandpass', output='sos') - return signal.sosfiltfilt(sos, np.array(data)) - elif type == "highpass": # 高通滤波处理 - sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos') - return signal.sosfiltfilt(sos, np.array(data)) - else: # 警告,滤波器类型必须有 - raise ValueError("Please choose a type of fliter") - -@timing_decorator() -def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): - """ - 高效整数倍降采样长信号(适合8小时以上),分段处理以优化内存和速度。 - - 参数: - original_signal : array-like, 原始信号数组 - original_fs : float, 原始采样率 (Hz) - target_fs : float, 
目标采样率 (Hz) - chunk_size : int, 每段处理的样本数,默认100000 - - 返回: - downsampled_signal : array-like, 降采样后的信号 - """ - # 输入验证 - if not isinstance(original_signal, np.ndarray): - original_signal = np.array(original_signal) - if original_fs <= target_fs: - raise ValueError("目标采样率必须小于原始采样率") - if target_fs <= 0 or original_fs <= 0: - raise ValueError("采样率必须为正数") - - # 计算降采样因子(必须为整数) - downsample_factor = original_fs / target_fs - if not downsample_factor.is_integer(): - raise ValueError("降采样因子必须为整数倍") - downsample_factor = int(downsample_factor) - - # 计算总输出长度 - total_length = len(original_signal) - output_length = total_length // downsample_factor - - # 初始化输出数组 - downsampled_signal = np.zeros(output_length) - - # 分段处理 - for start in range(0, total_length, chunk_size): - end = min(start + chunk_size, total_length) - chunk = original_signal[start:end] - - # 使用decimate进行整数倍降采样 - chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True) - - # 计算输出位置 - out_start = start // downsample_factor - out_end = out_start + len(chunk_downsampled) - if out_end > output_length: - chunk_downsampled = chunk_downsampled[:output_length - out_start] - - downsampled_signal[out_start:out_end] = chunk_downsampled - - return downsampled_signal - -@timing_decorator() -def average_filter(raw_data, sample_rate, window_size=20): - kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) - filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') - convolve_filter_signal = raw_data - filtered - return convolve_filter_signal - - - def merge_short_gaps(state_sequence, time_points, max_gap_sec): """ @@ -252,14 +173,14 @@ def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, return values_nan, values -def load_dataset_info(yaml_path): +def load_dataset_conf(yaml_path): with open(yaml_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) - select_ids = config.get('select_id', []) - root_path = config.get('root_path', None) - 
data_path = Path(root_path) - return select_ids, data_path + # select_ids = config.get('select_id', []) + # root_path = config.get('root_path', None) + # data_path = Path(root_path) + return config def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: @@ -285,36 +206,3 @@ def generate_event_mask(signal_second: int, event_df): score_mask[start:end] = row["score"] return event_mask, score_mask - -def slide_window_segment(signal_second: int, window_second, step_second, event_mask, score_mask, disable_mask, ): - # 避开不可用区域进行滑窗分割 - # 滑动到不可用区域时,如果窗口内一侧的不可用区域不超过1/2 windows_second,则继续滑动, 用reflect填充 - # 如果不可用区间大于1/2的window_second,则跳过该不可用区间,继续滑动 - # TODO 对于短时强体动区间 考虑填充或者掩码覆盖 - # - half_window_second = window_second // 2 - for start_second in range(0, signal_second - window_second + 1, step_second): - end_second = start_second + window_second - - # 检查当前窗口是否包含不可用区域 - windows_middle_second = (start_second + end_second) // 2 - if np.sum(disable_mask[start_second:end_second] > 1) > half_window_second: - # 如果窗口内不可用区域超过一半,跳过该窗口 - continue - - if disable_mask[start_second:end_second] > half_window_second: - - - - - - # 确保新的起始位置不超过信号长度 - if start_second + window_second > signal_second: - break - - window_event = event_mask[start_second:end_second] - window_score = score_mask[start_second:end_second] - window_disable = disable_mask[start_second:end_second] - - yield start_second, end_second, window_event, window_score, window_disable - diff --git a/utils/signal_process.py b/utils/signal_process.py new file mode 100644 index 0000000..dea0ea1 --- /dev/null +++ b/utils/signal_process.py @@ -0,0 +1,92 @@ +from utils.operation_tools import timing_decorator +import numpy as np +from scipy import signal, ndimage + + +@timing_decorator() +def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): + + if _type == "lowpass": # 低通滤波处理 + sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') + return 
signal.sosfiltfilt(sos, np.array(data)) + elif _type == "bandpass": # 带通滤波处理 + low = low_cut / (sample_rate * 0.5) + high = high_cut / (sample_rate * 0.5) + sos = signal.butter(order, [low, high], btype='bandpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + elif _type == "highpass": # 高通滤波处理 + sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + else: # 警告,滤波器类型必须有 + raise ValueError("Please choose a type of fliter") + + +@timing_decorator() +def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): + """ + 高效整数倍降采样长信号,分段处理以优化内存和速度。 + + 参数: + original_signal : array-like, 原始信号数组 + original_fs : float, 原始采样率 (Hz) + target_fs : float, 目标采样率 (Hz) + chunk_size : int, 每段处理的样本数,默认100000 + + 返回: + downsampled_signal : array-like, 降采样后的信号 + """ + # 输入验证 + if not isinstance(original_signal, np.ndarray): + original_signal = np.array(original_signal) + if original_fs <= target_fs: + raise ValueError("目标采样率必须小于原始采样率") + if target_fs <= 0 or original_fs <= 0: + raise ValueError("采样率必须为正数") + + # 计算降采样因子(必须为整数) + downsample_factor = original_fs / target_fs + if not downsample_factor.is_integer(): + raise ValueError("降采样因子必须为整数倍") + downsample_factor = int(downsample_factor) + + # 计算总输出长度 + total_length = len(original_signal) + output_length = total_length // downsample_factor + + # 初始化输出数组 + downsampled_signal = np.zeros(output_length) + + # 分段处理 + for start in range(0, total_length, chunk_size): + end = min(start + chunk_size, total_length) + chunk = original_signal[start:end] + + # 使用decimate进行整数倍降采样 + chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True) + + # 计算输出位置 + out_start = start // downsample_factor + out_end = out_start + len(chunk_downsampled) + if out_end > output_length: + chunk_downsampled = chunk_downsampled[:output_length - out_start] + + downsampled_signal[out_start:out_end] = 
chunk_downsampled + + return downsampled_signal + +@timing_decorator() +def average_filter(raw_data, sample_rate, window_size=20): + kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) + filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') + convolve_filter_signal = raw_data - filtered + return convolve_filter_signal + + +# 陷波滤波器 +@timing_decorator() +def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000): + nyquist = 0.5 * sample_rate + norm_notch_freq = notch_freq / nyquist + b, a = signal.iirnotch(norm_notch_freq, quality_factor) + filtered_data = signal.filtfilt(b, a, data) + return filtered_data From 9fdbc4a1cba0e7e4e56cd0b93a11c968d433c290 Mon Sep 17 00:00:00 2001 From: marques Date: Wed, 29 Oct 2025 10:53:14 +0800 Subject: [PATCH 10/28] Add signal drawing functionality and enhance signal processing methods --- HYS_process.py | 160 ++++++++++++++++++++++++++++--- dataset_config/HYS_config.yaml | 26 ++++- draw_tools/__init__.py | 1 + draw_tools/draw_statics.py | 130 ++++++++++++++++++++++++- signal_method/__init__.py | 2 +- signal_method/rule_base_event.py | 4 +- utils/__init__.py | 2 +- utils/signal_process.py | 20 +++- 8 files changed, 318 insertions(+), 27 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 9422d9c..85fb7a3 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -21,12 +21,14 @@ todo: 使用mask 屏蔽无用区间 """ from pathlib import Path -from typing import Union + +import draw_tools import utils import numpy as np import signal_method - - +import os +from matplotlib import pyplot as plt +os.environ['DISPLAY'] = "localhost:10.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -41,7 +43,6 @@ def process_one_signal(samp_id): label_path = list(label_path)[0] print(f"Processing Label_corrected file: {label_path}") - signal_data = utils.read_signal_txt(signal_path) signal_length = len(signal_data) print(f"signal_length: 
{signal_length}") @@ -50,43 +51,174 @@ def process_one_signal(samp_id): signal_second = signal_length // signal_fs print(f"signal_second: {signal_second}") + # 根据采样率进行截断 + signal_data = signal_data[:signal_second * signal_fs] + # 滤波 # 50Hz陷波滤波器 # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - resp_data = utils.butterworth(data=signal_data, _type=conf["resp"]["filter_type"], low_cut=conf["resp"]["low_cut"], - high_cut=conf["resp"]["high_cut"], order=conf["resp"]["order"], sample_rate=signal_fs) + print("Applying 50Hz notch filter...") + signal_data = utils.notch_filter(data=signal_data, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) - bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg"]["filter_type"], low_cut=conf["bcg"]["low_cut"], - high_cut=conf["bcg"]["high_cut"], order=conf["bcg"]["order"], sample_rate=signal_fs) + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + print("Begin plotting signal data...") + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, 
color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + # 降采样 + old_resp_fs = resp_fs + resp_fs = conf["resp"]["downsample_fs_2"] + resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) + bcg_fs = conf["bcg"]["downsample_fs"] + bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) + label_data = utils.read_label_csv(path=label_path) label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) - manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) + manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ + all_samp_disable_df["id"] == samp_id]) print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") # 分析Resp的低幅值区间 - resp_low_amp_conf = getattr(conf, "resp_low_amp", None) + resp_low_amp_conf = conf.get("resp_low_amp", None) if resp_low_amp_conf is not None: - resp_low_amp_mask = signal_method.detect_low_amplitude_signal( + resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( signal_data=resp_data, - sampling_rate=signal_fs, + sampling_rate=resp_fs, window_size_sec=resp_low_amp_conf["window_size_sec"], stride_sec=resp_low_amp_conf["stride_sec"], 
amplitude_threshold=resp_low_amp_conf["amplitude_threshold"], merge_gap_sec=resp_low_amp_conf["merge_gap_sec"], min_duration_sec=resp_low_amp_conf["min_duration_sec"] ) + print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") else: - resp_low_amp_mask = None + resp_low_amp_mask, resp_low_amp_position_list = None, None + print("resp_low_amp_mask is None") # 分析Resp的高幅值伪迹区间 - resp_move + resp_movement_conf = conf.get("resp_movement", None) + if resp_movement_conf is not None: + raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( + signal_data=resp_data, + sampling_rate=resp_fs, + window_size_sec=resp_movement_conf["window_size_sec"], + stride_sec=resp_movement_conf["stride_sec"], + std_median_multiplier=resp_movement_conf["std_median_multiplier"], + compare_intervals_sec=resp_movement_conf["compare_intervals_sec"], + interval_multiplier=resp_movement_conf["interval_multiplier"], + merge_gap_sec=resp_movement_conf["merge_gap_sec"], + min_duration_sec=resp_movement_conf["min_duration_sec"] + ) + print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") + else: + resp_movement_mask = None + print("resp_movement_mask is None") + # 分析Resp的幅值突变区间 + if resp_movement_mask is not None: + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2( + signal_data=resp_data, + movement_mask=resp_movement_mask, + sampling_rate=resp_fs) + print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}") + else: + resp_amp_change_mask = None + print("amp_change_mask is None") + # 分析Bcg的低幅值区间 + bcg_low_amp_conf = conf.get("bcg_low_amp", None) + if bcg_low_amp_conf is not None: + bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=bcg_data, + 
sampling_rate=bcg_fs, + window_size_sec=bcg_low_amp_conf["window_size_sec"], + stride_sec=bcg_low_amp_conf["stride_sec"], + amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"], + merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"], + min_duration_sec=bcg_low_amp_conf["min_duration_sec"] + ) + print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") + else: + bcg_low_amp_mask, bcg_low_amp_position_list = None, None + print("bcg_low_amp_mask is None") + # 分析Bcg的高幅值伪迹区间 + bcg_movement_conf = conf.get("bcg_movement", None) + if bcg_movement_conf is not None: + raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( + signal_data=bcg_data, + sampling_rate=bcg_fs, + window_size_sec=bcg_movement_conf["window_size_sec"], + stride_sec=bcg_movement_conf["stride_sec"], + std_median_multiplier=bcg_movement_conf["std_median_multiplier"], + compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"], + interval_multiplier=bcg_movement_conf["interval_multiplier"], + merge_gap_sec=bcg_movement_conf["merge_gap_sec"], + min_duration_sec=bcg_movement_conf["min_duration_sec"] + ) + print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") + else: + bcg_movement_mask = None + print("bcg_movement_mask is None") + # 分析Bcg的幅值突变区间 + if bcg_movement_mask is not None: + bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( + signal_data=bcg_data, + movement_mask=bcg_movement_mask, + sampling_rate=bcg_fs) + print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}") + else: + bcg_amp_change_mask = None + print("bcg_amp_change_mask is None") + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + 
signal_fs = 100 + + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=None, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index dfff364..c30c3c5 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -12,19 +12,37 @@ select_ids: root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +resp: + downsample_fs_1: 100 + downsample_fs_2: 10 + resp_filter: filter_type: bandpass low_cut: 0.01 high_cut: 0.7 - order: 10 + order: 2 resp_low_amp: - windows_size_sec: 1 - stride_sec: None - amplitude_threshold: 50 + window_size_sec: 1 + stride_sec: + amplitude_threshold: 20 merge_gap_sec: 10 min_duration_sec: 5 +resp_movement: + window_size_sec: 2 + stride_sec: + std_median_multiplier: 4.5 + compare_intervals_sec: + - 30 + - 60 + interval_multiplier: 2.5 + merge_gap_sec: 10 + min_duration_sec: 5 + +bcg: + downsample_fs: 100 + bcg_filter: filter_type: bandpass low_cut: 1 diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py index e69de29..5d4efe2 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -0,0 +1 @@ +from draw_tools.draw_statics import draw_signal_with_mask \ No newline at end of file diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index 88f790e..b4e401a 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -1,6 +1,8 @@ from matplotlib.axes import Axes from matplotlib.gridspec import GridSpec from matplotlib.colors import PowerNorm +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches import seaborn as sns 
import numpy as np @@ -74,8 +76,7 @@ def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, co ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)", ha='center', va='center') -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches + def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values, movement_position_list, low_amp_position_list, signal_second_length, aml_list=None): @@ -172,4 +173,127 @@ def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_s if show: plt.show() - plt.close() \ No newline at end of file + plt.close() + + +def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs, + signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask, + resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask + ): + # 第一行绘制去工频噪声的原始信号,右侧为不可用区间标记,左侧为信号幅值纵坐标 + # 第二行绘制呼吸分量,右侧低幅值、高幅值、幅值变换标记、SA标签,左侧为呼吸幅值纵坐标 + # 第三行绘制心冲击分量,右侧为低幅值、高幅值、幅值变换标记、,左侧为心冲击幅值纵坐标 + # mask为None,则生成全Nan掩码 + def _none_to_nan_mask(mask, ref): + if mask is None: + return np.full_like(ref, np.nan) + else: + # 将mask中的0替换为nan,1替换为1 + mask = np.where(mask == 0, np.nan, 1) + return mask + + signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data) + resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data) + resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data) + resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data) + resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data) + bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data) + bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data) + bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data) + + + fig = plt.figure(figsize=(18, 10)) + ax0 = fig.add_subplot(3, 1, 1) + ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), 
signal_data, color='blue') + ax0.set_title(f'Sample {samp_id} - Raw Signal Data') + ax0.set_ylabel('Amplitude') + # ax0.set_xticklabels([]) + + ax0_twin = ax0.twinx() + ax0_twin.plot(np.linspace(0, len(signal_disable_mask), len(signal_disable_mask)), signal_disable_mask, + color='red', alpha=0.5) + ax0_twin.autoscale(enable=False, axis='y', tight=True) + ax0_twin.set_ylim((-2, 2)) + ax0_twin.set_ylabel('Disable Mask') + ax0_twin.set_yticks([0, 1]) + ax0_twin.set_yticklabels(['Enabled', 'Disabled']) + ax0_twin.grid(False) + ax0_twin.legend(['Disable Mask'], loc='upper right') + + + ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') + ax1.set_ylabel('Amplitude') + ax1.set_xticklabels([]) + ax1_twin = ax1.twinx() + ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1, + color='blue', alpha=0.5, label='Low Amplitude Mask') + ax1_twin.plot(np.linspace(0, len(resp_movement_mask), len(resp_movement_mask)), resp_movement_mask*-2, + color='red', alpha=0.5, label='Movement Mask') + ax1_twin.plot(np.linspace(0, len(resp_change_mask), len(resp_change_mask)), resp_change_mask*-3, + color='green', alpha=0.5, label='Amplitude Change Mask') + ax1_twin.plot(np.linspace(0, len(resp_sa_mask), len(resp_sa_mask)), resp_sa_mask, + color='purple', alpha=0.5, label='SA Mask') + ax1_twin.autoscale(enable=False, axis='y', tight=True) + ax1_twin.set_ylim((-4, 5)) + # ax1_twin.set_ylabel('Resp Masks') + # ax1_twin.set_yticks([0, 1]) + # ax1_twin.set_yticklabels(['No', 'Yes']) + ax1_twin.grid(False) + + ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right') + ax1.set_title(f'Sample {samp_id} - Respiration Component') + + ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green') + ax2.set_ylabel('Amplitude') + ax2.set_xlabel('Time 
(s)') + ax2_twin = ax2.twinx() + ax2_twin.plot(np.linspace(0, len(bcg_low_amp_mask), len(bcg_low_amp_mask)), bcg_low_amp_mask*-1, + color='blue', alpha=0.5, label='Low Amplitude Mask') + ax2_twin.plot(np.linspace(0, len(bcg_movement_mask), len(bcg_movement_mask)), bcg_movement_mask*-2, + color='red', alpha=0.5, label='Movement Mask') + ax2_twin.plot(np.linspace(0, len(bcg_change_mask), len(bcg_change_mask)), bcg_change_mask*-3, + color='green', alpha=0.5, label='Amplitude Change Mask') + ax2_twin.autoscale(enable=False, axis='y', tight=True) + ax2_twin.set_ylim((-4, 2)) + ax2_twin.set_ylabel('BCG Masks') + # ax2_twin.set_yticks([0, 1]) + # ax2_twin.set_yticklabels(['No', 'Yes']) + ax2_twin.grid(False) + ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right') + # ax2.set_title(f'Sample {samp_id} - BCG Component') + + ax0_twin._lim_lock = False + ax1_twin._lim_lock = False + ax2_twin._lim_lock = False + + def on_lims_change(event_ax): + if getattr(event_ax, '_lim_lock', False): + return + try: + event_ax._lim_lock = True + + if event_ax == ax0_twin: + # 重新锁定 ax1 的 Y 轴范围 + ax0_twin.set_ylim(-2, 2) + elif event_ax == ax1_twin: + ax1_twin.set_ylim(-3, 5) + elif event_ax == ax2_twin: + ax2_twin.set_ylim(-4, 2) + + finally: + event_ax._lim_lock = False + + + ax0_twin.callbacks.connect('ylim_changed', on_lims_change) + ax1_twin.callbacks.connect('ylim_changed', on_lims_change) + ax2_twin.callbacks.connect('ylim_changed', on_lims_change) + + + plt.tight_layout() + plt.show() + + + + diff --git a/signal_method/__init__.py b/signal_method/__init__.py index 46eac36..a1d61f2 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1 +1 @@ -from signal_method.rule_base_event import detect_low_amplitude_signal \ No newline at end of file +from signal_method.rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 \ No newline at end of file diff --git 
a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 8de49da..a8e2480 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -379,7 +379,7 @@ def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rat return position_changes, position_change_times -def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100, window_size_sec=30): +def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100): """ :param signal_data: @@ -445,4 +445,4 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat else: position_changes.append(0) # 0表示不存在姿势变化 - return position_changes, position_change_times + return np.array(position_changes), position_change_times diff --git a/utils/__init__.py b/utils/__init__.py index faaebaf..ae2ee06 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,4 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask from utils.event_map import E2N -from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter \ No newline at end of file +from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/signal_process.py b/utils/signal_process.py index dea0ea1..c657d3f 100644 --- a/utils/signal_process.py +++ b/utils/signal_process.py @@ -21,6 +21,22 @@ def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=100 raise ValueError("Please choose a type of fliter") +def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000): + if _type == "lowpass": # 低通滤波处理 + b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag') + return signal.filtfilt(b, a, np.array(data)) + 
elif _type == "bandpass": # 带通滤波处理 + low = low_cut / (sample_rate * 0.5) + high = high_cut / (sample_rate * 0.5) + b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag') + return signal.filtfilt(b, a, np.array(data)) + elif _type == "highpass": # 高通滤波处理 + b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag') + return signal.filtfilt(b, a, np.array(data)) + else: # 警告,滤波器类型必须有 + raise ValueError("Please choose a type of fliter") + + @timing_decorator() def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): """ @@ -75,8 +91,8 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1 return downsampled_signal @timing_decorator() -def average_filter(raw_data, sample_rate, window_size=20): - kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) +def average_filter(raw_data, sample_rate, window_size_sec=20): + kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate) filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') convolve_filter_signal = raw_data - filtered return convolve_filter_signal From 965f88843ae340f8a02bd312d40a746faf114d64 Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 30 Oct 2025 15:46:08 +0800 Subject: [PATCH 11/28] Refactor signal processing configurations and improve mask generation logic --- HYS_process.py | 34 ++++++---------------------- dataset_config/HYS_config.yaml | 38 ++++++++++++++++++++++---------- draw_tools/draw_statics.py | 6 ++--- signal_method/rule_base_event.py | 8 +++---- utils/operation_tools.py | 4 ++-- 5 files changed, 41 insertions(+), 49 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 85fb7a3..edb4065 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:10.0" +os.environ['DISPLAY'] = 
"localhost:11.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -98,7 +98,7 @@ def process_one_signal(samp_id): bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) label_data = utils.read_label_csv(path=label_path) - label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ all_samp_disable_df["id"] == samp_id]) @@ -110,11 +110,7 @@ def process_one_signal(samp_id): resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( signal_data=resp_data, sampling_rate=resp_fs, - window_size_sec=resp_low_amp_conf["window_size_sec"], - stride_sec=resp_low_amp_conf["stride_sec"], - amplitude_threshold=resp_low_amp_conf["amplitude_threshold"], - merge_gap_sec=resp_low_amp_conf["merge_gap_sec"], - min_duration_sec=resp_low_amp_conf["min_duration_sec"] + **resp_low_amp_conf ) print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") else: @@ -127,13 +123,7 @@ def process_one_signal(samp_id): raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( signal_data=resp_data, sampling_rate=resp_fs, - window_size_sec=resp_movement_conf["window_size_sec"], - stride_sec=resp_movement_conf["stride_sec"], - std_median_multiplier=resp_movement_conf["std_median_multiplier"], - compare_intervals_sec=resp_movement_conf["compare_intervals_sec"], - interval_multiplier=resp_movement_conf["interval_multiplier"], - merge_gap_sec=resp_movement_conf["merge_gap_sec"], - min_duration_sec=resp_movement_conf["min_duration_sec"] + **resp_movement_conf ) 
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") else: @@ -159,11 +149,7 @@ def process_one_signal(samp_id): bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( signal_data=bcg_data, sampling_rate=bcg_fs, - window_size_sec=bcg_low_amp_conf["window_size_sec"], - stride_sec=bcg_low_amp_conf["stride_sec"], - amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"], - merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"], - min_duration_sec=bcg_low_amp_conf["min_duration_sec"] + **bcg_low_amp_conf ) print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") else: @@ -175,13 +161,7 @@ def process_one_signal(samp_id): raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( signal_data=bcg_data, sampling_rate=bcg_fs, - window_size_sec=bcg_movement_conf["window_size_sec"], - stride_sec=bcg_movement_conf["stride_sec"], - std_median_multiplier=bcg_movement_conf["std_median_multiplier"], - compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"], - interval_multiplier=bcg_movement_conf["interval_multiplier"], - merge_gap_sec=bcg_movement_conf["merge_gap_sec"], - min_duration_sec=bcg_movement_conf["min_duration_sec"] + **bcg_movement_conf ) print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") else: @@ -215,7 +195,7 @@ def process_one_signal(samp_id): resp_low_amp_mask=resp_low_amp_mask, resp_movement_mask=resp_movement_mask, resp_change_mask=resp_amp_change_mask, - resp_sa_mask=None, + resp_sa_mask=event_mask, bcg_low_amp_mask=bcg_low_amp_mask, bcg_movement_mask=bcg_movement_mask, bcg_change_mask=bcg_amp_change_mask) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index c30c3c5..2d50926 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ 
-20,25 +20,25 @@ resp_filter: filter_type: bandpass low_cut: 0.01 high_cut: 0.7 - order: 2 + order: 3 resp_low_amp: - window_size_sec: 1 + window_size_sec: 30 stride_sec: - amplitude_threshold: 20 - merge_gap_sec: 10 - min_duration_sec: 5 + amplitude_threshold: 5 + merge_gap_sec: 180 + min_duration_sec: 30 resp_movement: - window_size_sec: 2 - stride_sec: - std_median_multiplier: 4.5 + window_size_sec: 30 + stride_sec: 5 + std_median_multiplier: 5 compare_intervals_sec: - - 30 - 60 - interval_multiplier: 2.5 - merge_gap_sec: 10 - min_duration_sec: 5 + - 90 + interval_multiplier: 3.5 + merge_gap_sec: 45 + min_duration_sec: 10 bcg: downsample_fs: 100 @@ -49,3 +49,17 @@ bcg_filter: high_cut: 10 order: 10 +bcg_low_amp: + window_size_sec: 1 + stride_sec: + amplitude_threshold: 10 + merge_gap_sec: 20 + min_duration_sec: 3 + + +bcg_movement: + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index b4e401a..ad74a3e 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -188,8 +188,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, if mask is None: return np.full_like(ref, np.nan) else: - # 将mask中的0替换为nan,1替换为1 - mask = np.where(mask == 0, np.nan, 1) + # 将mask中的0替换为nan,其他的保持 + mask = np.where(mask == 0, np.nan, mask) return mask signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data) @@ -224,7 +224,7 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') ax1.set_ylabel('Amplitude') - ax1.set_xticklabels([]) + # ax1.set_xticklabels([]) ax1_twin = ax1.twinx() ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1, color='blue', alpha=0.5, label='Low Amplitude Mask') diff --git 
a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index a8e2480..eb8db1d 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -421,7 +421,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat energy = np.sum(np.abs(signal_data[start:end] ** 2)) segment_average_energy.append(energy) - position_changes = [] + position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) position_change_times = [] # 判断是否存在显著变化 (可根据实际情况调整阈值) threshold_amplitude = 0.1 # 幅值变化阈值 @@ -440,9 +440,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat if significant_change: # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 - position_changes.append(1) + position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1 position_change_times.append((movement_start[i - 1], movement_end[i - 1])) - else: - position_changes.append(0) # 0表示不存在姿势变化 - return np.array(position_changes), position_change_times + return position_changes, position_change_times diff --git a/utils/operation_tools.py b/utils/operation_tools.py index f775d73..75b6e75 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -184,12 +184,12 @@ def load_dataset_conf(yaml_path): def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: - disable_mask = np.ones(signal_second, dtype=int) + disable_mask = np.zeros(signal_second, dtype=int) for _, row in disable_df.iterrows(): start = row["start"] end = row["end"] - disable_mask[start:end] = 0 + disable_mask[start:end] = 1 return disable_mask From 7e3459d9f1e6610d45447376801c56fbf7cbbe5e Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 30 Oct 2025 16:05:53 +0800 Subject: [PATCH 12/28] Adjust Y-axis limits in draw_statics.py and update resp_movement parameters in HYS_config.yaml --- dataset_config/HYS_config.yaml | 6 +++--- draw_tools/draw_statics.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 2d50926..bc07225 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -30,15 +30,15 @@ resp_low_amp: min_duration_sec: 30 resp_movement: - window_size_sec: 30 + window_size_sec: 20 stride_sec: 5 std_median_multiplier: 5 compare_intervals_sec: - 60 - 90 interval_multiplier: 3.5 - merge_gap_sec: 45 - min_duration_sec: 10 + merge_gap_sec: 30 + min_duration_sec: 5 bcg: downsample_fs: 100 diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index ad74a3e..adbc3d9 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -278,7 +278,7 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, # 重新锁定 ax1 的 Y 轴范围 ax0_twin.set_ylim(-2, 2) elif event_ax == ax1_twin: - ax1_twin.set_ylim(-3, 5) + ax1_twin.set_ylim(-4, 5) elif event_ax == ax2_twin: ax2_twin.set_ylim(-4, 2) From c4f163eacc36d2c2e53ca14e1298c345a8023739 Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 30 Oct 2025 21:41:37 +0800 Subject: [PATCH 13/28] Update signal processing configurations and improve event mask handling --- HYS_process.py | 16 ++++++++-------- dataset_config/HYS_config.yaml | 4 ++-- signal_method/rule_base_event.py | 16 +++++----------- utils/__init__.py | 3 ++- utils/operation_tools.py | 6 ++++++ utils/signal_process.py | 4 ++-- 6 files changed, 25 insertions(+), 24 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index edb4065..4f53640 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:11.0" +os.environ['DISPLAY'] = "localhost:10.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -112,7 +112,7 @@ def process_one_signal(samp_id): sampling_rate=resp_fs, **resp_low_amp_conf ) - print(f"resp_low_amp_mask_shape: 
{resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") + print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") else: resp_low_amp_mask, resp_low_amp_position_list = None, None print("resp_low_amp_mask is None") @@ -125,7 +125,7 @@ def process_one_signal(samp_id): sampling_rate=resp_fs, **resp_movement_conf ) - print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") + print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: resp_movement_mask = None print("resp_movement_mask is None") @@ -137,7 +137,7 @@ def process_one_signal(samp_id): signal_data=resp_data, movement_mask=resp_movement_mask, sampling_rate=resp_fs) - print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}") + print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: resp_amp_change_mask = None print("amp_change_mask is None") @@ -151,7 +151,7 @@ def process_one_signal(samp_id): sampling_rate=bcg_fs, **bcg_low_amp_conf ) - print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") + print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") else: bcg_low_amp_mask, bcg_low_amp_position_list = None, None print("bcg_low_amp_mask is None") @@ -163,7 +163,7 @@ def process_one_signal(samp_id): sampling_rate=bcg_fs, **bcg_movement_conf ) - print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") + 
print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") else: bcg_movement_mask = None print("bcg_movement_mask is None") @@ -173,7 +173,7 @@ def process_one_signal(samp_id): signal_data=bcg_data, movement_mask=bcg_movement_mask, sampling_rate=bcg_fs) - print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}") + print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") else: bcg_amp_change_mask = None print("bcg_amp_change_mask is None") @@ -220,4 +220,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[0]) + process_one_signal(select_ids[2]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index bc07225..1c70389 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -25,7 +25,7 @@ resp_filter: resp_low_amp: window_size_sec: 30 stride_sec: - amplitude_threshold: 5 + amplitude_threshold: 3 merge_gap_sec: 180 min_duration_sec: 30 @@ -52,7 +52,7 @@ bcg_filter: bcg_low_amp: window_size_sec: 1 stride_sec: - amplitude_threshold: 10 + amplitude_threshold: 5 merge_gap_sec: 20 min_duration_sec: 3 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index eb8db1d..1253f06 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -1,6 +1,6 @@ from utils.operation_tools import timing_decorator import numpy as np -from utils.operation_tools import merge_short_gaps, remove_short_durations +from utils import merge_short_gaps, remove_short_durations, event_mask_2_list @timing_decorator() @@ -159,14 +159,10 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No 
movement_mask[start:end+1] = 1 # raw体动起止位置 [[start, end], [start, end], ...] - raw_movement_start = np.where(np.diff(np.concatenate([[0], raw_movement_mask])) == 1)[0] - raw_movement_end = np.where(np.diff(np.concatenate([raw_movement_mask, [0]])) == -1)[0] + 1 - raw_movement_position_list = [[start, end] for start, end in zip(raw_movement_start, raw_movement_end)] + raw_movement_position_list = event_mask_2_list(raw_movement_mask) # merge体动起止位置 [[start, end], [start, end], ...] - movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] - movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + 1 - movement_position_list = [[start, end] for start, end in zip(movement_start, movement_end)] + movement_position_list = event_mask_2_list(movement_mask) return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list @@ -201,7 +197,7 @@ def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, s stride_samples = int(stride_sec * sampling_rate) # 确保步长至少为1 - stride_samples = max(1, stride_samples) + stride_samples = max(sampling_rate, stride_samples) # 处理信号边界,使用反射填充 pad_size = window_samples // 2 @@ -255,9 +251,7 @@ def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, s low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] # 低幅值状态起止位置 [[start, end], [start, end], ...] 
- low_amplitude_start = np.where(np.diff(np.concatenate([[0], low_amplitude_mask])) == 1)[0] - low_amplitude_end = np.where(np.diff(np.concatenate([low_amplitude_mask, [0]])) == -1)[0] - low_amplitude_position_list = [[start, end] for start, end in zip(low_amplitude_start, low_amplitude_end)] + low_amplitude_position_list = event_mask_2_list(low_amplitude_mask) return low_amplitude_mask, low_amplitude_position_list diff --git a/utils/__init__.py b/utils/__init__.py index ae2ee06..51a18ba 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,5 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask +from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list +from utils.operation_tools import merge_short_gaps, remove_short_durations from utils.event_map import E2N from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 75b6e75..4feabdf 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -206,3 +206,9 @@ def generate_event_mask(signal_second: int, event_df): score_mask[start:end] = row["score"] return event_mask, score_mask + +def event_mask_2_list(mask): + mask_start = np.where(np.diff(mask, append=0) == 1)[0] + mask_end = np.where(np.diff(mask, append=0) == -1)[0] + 1 + event_list =[[start, end] for start, end in zip(mask_start, mask_end)] + return event_list \ No newline at end of file diff --git a/utils/signal_process.py b/utils/signal_process.py index c657d3f..e690c33 100644 --- a/utils/signal_process.py +++ b/utils/signal_process.py @@ -4,8 +4,7 @@ from scipy import signal, ndimage @timing_decorator() -def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): - +def 
butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000): if _type == "lowpass": # 低通滤波处理 sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') return signal.sosfiltfilt(sos, np.array(data)) @@ -90,6 +89,7 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1 return downsampled_signal + @timing_decorator() def average_filter(raw_data, sample_rate, window_size_sec=20): kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate) From 998890377b93219e551c23a1769a0dd58e525547 Mon Sep 17 00:00:00 2001 From: marques Date: Wed, 5 Nov 2025 10:29:24 +0800 Subject: [PATCH 14/28] Update HYS_config.yaml and HYS_process.py for signal processing parameters and add movement revision function --- HYS_process.py | 36 +++++++++++++++++++++++++++----- dataset_config/HYS_config.yaml | 12 +++++------ signal_method/rule_base_event.py | 35 +++++++++++++++++++++++++++---- 3 files changed, 68 insertions(+), 15 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 4f53640..8dec227 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -43,8 +43,8 @@ def process_one_signal(samp_id): label_path = list(label_path)[0] print(f"Processing Label_corrected file: {label_path}") - signal_data = utils.read_signal_txt(signal_path) - signal_length = len(signal_data) + signal_data_raw = utils.read_signal_txt(signal_path) + signal_length = len(signal_data_raw) print(f"signal_length: {signal_length}") signal_fs = int(signal_path.stem.split("_")[-1]) print(f"signal_fs: {signal_fs}") @@ -52,13 +52,13 @@ def process_one_signal(samp_id): print(f"signal_second: {signal_second}") # 根据采样率进行截断 - signal_data = signal_data[:signal_second * signal_fs] + signal_data_raw = signal_data_raw[:signal_second * signal_fs] # 滤波 # 50Hz陷波滤波器 # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) print("Applying 50Hz notch filter...") - 
signal_data = utils.notch_filter(data=signal_data, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) resp_fs = conf["resp"]["downsample_fs_1"] @@ -130,6 +130,30 @@ def process_one_signal(samp_id): resp_movement_mask = None print("resp_movement_mask is None") + if resp_movement_mask is not None: + # 左右翻转resp_data + reverse_resp_data = resp_data[::-1] + _, resp_movement_mask_reverse, _, resp_movement_position_list_reverse = signal_method.detect_movement( + signal_data=reverse_resp_data, + sampling_rate=resp_fs, + **resp_movement_conf + ) + print(f"resp_movement_mask_reverse_shape: {resp_movement_mask_reverse.shape}, num_movement_reverse: {np.sum(resp_movement_mask_reverse == 1)}, count_movement_positions_reverse: {len(resp_movement_position_list_reverse)}") + # 将resp_movement_mask_reverse翻转回来 + resp_movement_mask_reverse = resp_movement_mask_reverse[::-1] + else: + resp_movement_mask_reverse = None + print("resp_movement_mask_reverse is None") + + + # 取交集 + if resp_movement_mask is not None and resp_movement_mask_reverse is not None: + combined_resp_movement_mask = np.logical_and(resp_movement_mask, resp_movement_mask_reverse).astype(int) + resp_movement_mask = combined_resp_movement_mask + print(f"combined_resp_movement_mask_shape: {combined_resp_movement_mask.shape}, num_combined_movement: {np.sum(combined_resp_movement_mask == 1)}") + else: + print("combined_resp_movement_mask is None") + # 分析Resp的幅值突变区间 if resp_movement_mask is not None: @@ -143,6 +167,7 @@ def process_one_signal(samp_id): print("amp_change_mask is None") + # 分析Bcg的低幅值区间 bcg_low_amp_conf = conf.get("bcg_low_amp", None) if bcg_low_amp_conf is not None: @@ -182,6 +207,7 @@ def process_one_signal(samp_id): # 如果signal_data采样率过,进行降采样 if signal_fs == 1000: 
signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) signal_fs = 100 draw_tools.draw_signal_with_mask(samp_id=samp_id, @@ -220,4 +246,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[2]) + process_one_signal(select_ids[5]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 1c70389..605c010 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -26,19 +26,19 @@ resp_low_amp: window_size_sec: 30 stride_sec: amplitude_threshold: 3 - merge_gap_sec: 180 - min_duration_sec: 30 + merge_gap_sec: 60 + min_duration_sec: 60 resp_movement: window_size_sec: 20 - stride_sec: 5 - std_median_multiplier: 5 + stride_sec: 1 + std_median_multiplier: 3.5 compare_intervals_sec: - 60 - 90 interval_multiplier: 3.5 merge_gap_sec: 30 - min_duration_sec: 5 + min_duration_sec: 2 bcg: downsample_fs: 100 @@ -52,7 +52,7 @@ bcg_filter: bcg_low_amp: window_size_sec: 1 stride_sec: - amplitude_threshold: 5 + amplitude_threshold: 8 merge_gap_sec: 20 min_duration_sec: 3 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 1253f06..76b6302 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -168,6 +168,24 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No + +def movement_revise(signal_data, sampling_rate, movement_mask, std_median_multiplier=4.5): + """ + 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置修正 + + 参数: + - signal_data: numpy array,输入的信号数据 + - sampling_rate: int,信号的采样率(Hz) + - movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠) + - std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5 + + 返回: + - revised_movement_mask: numpy array,修正后的体动掩码 + """ + pass + + + @timing_decorator() def 
detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): @@ -394,6 +412,15 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat segment_average_amplitude = [] segment_average_energy = [] + signal_data_no_movement = signal_data.copy() + for start, end in zip(movement_start, movement_end): + signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan + + # from matplotlib import pyplot as plt + # plt.plot(signal_data, alpha=0.3, color='gray') + # plt.plot(signal_data_no_movement, color='blue', linewidth=1) + # plt.show() + for start, end in zip(valid_starts, valid_ends): start *= sampling_rate end *= sampling_rate @@ -407,12 +434,12 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat mav_calc_window_sec * sampling_rate) # 计算每个片段的幅值指标 - mav = np.mean( - np.max(signal_data[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.mean( - np.min(signal_data[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + mav = np.nanmean( + np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.nanmean( + np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) segment_average_amplitude.append(mav) - energy = np.sum(np.abs(signal_data[start:end] ** 2)) + energy = np.nansum(np.abs(signal_data_no_movement[start:end] ** 2)) segment_average_energy.append(energy) position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) From 2a2604a3237e7fdf52bcc9e35411f131d3ee360f Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 6 Nov 2025 17:15:14 +0800 Subject: [PATCH 15/28] Refactor imports in __init__.py, enhance resp_movement handling in HYS_process.py, and update HYS_config.yaml for movement revision parameters --- HYS_process.py | 34 +++++-------- 
dataset_config/HYS_config.yaml | 14 ++++-- draw_tools/__init__.py | 2 +- draw_tools/draw_statics.py | 5 +- signal_method/__init__.py | 4 +- signal_method/rule_base_event.py | 84 ++++++++++++++++++++++++-------- signal_method/time_metrics.py | 12 +++-- utils/__init__.py | 11 +++-- utils/operation_tools.py | 44 +++++++++++------ 9 files changed, 140 insertions(+), 70 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 8dec227..9f66a51 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:10.0" +os.environ['DISPLAY'] = "localhost:14.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -127,33 +127,23 @@ def process_one_signal(samp_id): ) print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: - resp_movement_mask = None + resp_movement_mask, resp_movement_position_list = None, None print("resp_movement_mask is None") - if resp_movement_mask is not None: - # 左右翻转resp_data - reverse_resp_data = resp_data[::-1] - _, resp_movement_mask_reverse, _, resp_movement_position_list_reverse = signal_method.detect_movement( - signal_data=reverse_resp_data, + resp_movement_revise_conf = conf.get("resp_movement_revise", None) + if resp_movement_mask is not None and resp_movement_revise_conf is not None: + resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, sampling_rate=resp_fs, - **resp_movement_conf + **resp_movement_revise_conf ) - print(f"resp_movement_mask_reverse_shape: {resp_movement_mask_reverse.shape}, num_movement_reverse: {np.sum(resp_movement_mask_reverse == 1)}, count_movement_positions_reverse: 
{len(resp_movement_position_list_reverse)}") - # 将resp_movement_mask_reverse翻转回来 - resp_movement_mask_reverse = resp_movement_mask_reverse[::-1] + print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") else: - resp_movement_mask_reverse = None - print("resp_movement_mask_reverse is None") + print("resp_movement_mask revise is skipped") - # 取交集 - if resp_movement_mask is not None and resp_movement_mask_reverse is not None: - combined_resp_movement_mask = np.logical_and(resp_movement_mask, resp_movement_mask_reverse).astype(int) - resp_movement_mask = combined_resp_movement_mask - print(f"combined_resp_movement_mask_shape: {combined_resp_movement_mask.shape}, num_combined_movement: {np.sum(combined_resp_movement_mask == 1)}") - else: - print("combined_resp_movement_mask is None") - # 分析Resp的幅值突变区间 if resp_movement_mask is not None: @@ -246,4 +236,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[5]) + process_one_signal(select_ids[0]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 605c010..50e0c93 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -32,13 +32,21 @@ resp_low_amp: resp_movement: window_size_sec: 20 stride_sec: 1 - std_median_multiplier: 3.5 + std_median_multiplier: 5 compare_intervals_sec: - 60 - - 90 + - 120 + - 180 interval_multiplier: 3.5 merge_gap_sec: 30 - min_duration_sec: 2 + min_duration_sec: 1 + +resp_movement_revise: + up_interval_multiplier: 3 + down_interval_multiplier: 1.5 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 bcg: downsample_fs: 100 diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py index 5d4efe2..281cc34 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -1 +1 @@ -from draw_tools.draw_statics import draw_signal_with_mask \ No newline at end of file +from 
.draw_statics import draw_signal_with_mask \ No newline at end of file diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index adbc3d9..f44cde0 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -222,7 +222,10 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) - ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') + ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='gray', alpha=0.5) + resp_data_no_movement = resp_data.copy() + resp_data_no_movement[resp_movement_mask.repeat(int(resp_fs)) == 1] = np.nan + ax1.plot(np.linspace(0, len(resp_data_no_movement) // resp_fs, len(resp_data_no_movement)), resp_data_no_movement, color='orange') ax1.set_ylabel('Amplitude') # ax1.set_xticklabels([]) ax1_twin = ax1.twinx() diff --git a/signal_method/__init__.py b/signal_method/__init__.py index a1d61f2..ab44ea9 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1 +1,3 @@ -from signal_method.rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 \ No newline at end of file +from .rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 +from .rule_base_event import movement_revise +from .time_metrics import calc_mav \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 76b6302..072c375 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -1,6 +1,7 @@ from utils.operation_tools import timing_decorator import numpy as np -from utils import merge_short_gaps, remove_short_durations, event_mask_2_list +from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values +from signal_method.time_metrics import calc_mav @timing_decorator() @@ -90,16 +91,16 
@@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No # else: # valid_std = original_window_std - valid_std = original_window_std ##20250418新修改 + valid_std = original_window_std ##20250418新修改 - #---------------------- 方法一:基于STD的体动判定 ----------------------# + # ---------------------- 方法一:基于STD的体动判定 ----------------------# # 计算所有有效窗口标准差的中位数 median_std = np.median(valid_std) # 当窗口标准差大于中位数的倍数,判定为体动状态 - std_movement = np.where(original_window_std > median_std * std_median_multiplier, 1, 0) + std_movement = np.where((original_window_std > (median_std * std_median_multiplier)), 1, 0) - #------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# + # ------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# amplitude_movement = np.zeros(num_original_windows, dtype=int) # 定义基于时间粒度的比较间隔索引 @@ -146,7 +147,6 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] - # 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除 removed_movement_mask = (raw_movement_mask - movement_mask) > 0 removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0] @@ -155,8 +155,8 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No for start, end in zip(removed_movement_start, removed_movement_end): # print(start ,end) # 计算剔除的体动区域的幅值 - if np.nanmax(signal_data[start*sampling_rate:(end+1)*sampling_rate]) > median_std * std_median_multiplier: - movement_mask[start:end+1] = 1 + if np.nanmax(signal_data[start * sampling_rate:(end + 1) * sampling_rate]) > median_std * std_median_multiplier: + movement_mask[start:end + 1] = 1 # raw体动起止位置 [[start, end], [start, end], ...] 
raw_movement_position_list = event_mask_2_list(raw_movement_mask) @@ -164,25 +164,70 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No # merge体动起止位置 [[start, end], [start, end], ...] movement_position_list = event_mask_2_list(movement_mask) - return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list + return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list - - -def movement_revise(signal_data, sampling_rate, movement_mask, std_median_multiplier=4.5): +def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float, + down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec): """ - 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置修正 + 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置精细修正 参数: - signal_data: numpy array,输入的信号数据 - sampling_rate: int,信号的采样率(Hz) - movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠) - - std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5 返回: - revised_movement_mask: numpy array,修正后的体动掩码 """ - pass + window_size = sampling_rate + stride_size = sampling_rate + + time_points = np.arange(len(signal_data)) + + compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) + + _, mav = calc_mav(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, + window_second=2, step_second=1, + inner_window_second=1) + + # 往左右两边取compare_size个点的mav,取平均值 + for start, end in movement_list: + left_values = collect_values(arr=mav, index=start - 1, step=-1, limit=compare_size, mask=movement_mask) + right_values = collect_values(arr=mav, index=end + 5, step=1, limit=compare_size, mask=movement_mask) + left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0 + right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0 + if left_value_metrics == 0: + value_metrics = right_value_metrics + elif right_value_metrics == 0: + value_metrics = left_value_metrics 
+ else: + value_metrics = np.mean([left_value_metrics, right_value_metrics]) + + # 逐秒遍历mav,判断是否需要修正 + # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") + for i in range(start, end + 5): + # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") + if mav[i] > (value_metrics * up_interval_multiplier): + movement_mask[i] = 1 + # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}") + elif mav[i] < (value_metrics * down_interval_multiplier): + movement_mask[i] = 0 + # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}") + # else: + # print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}") + # + # 如果需要合并间隔小的体动状态 + if merge_gap_sec > 0: + movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec) + + # 如果需要移除短时体动状态 + if min_duration_sec > 0: + movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec) + + movement_list = event_mask_2_list(movement_mask) + return movement_mask, movement_list + @@ -335,10 +380,10 @@ def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rat # 新的end - start确保为200的整数倍 if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0: left_end = left_start + ((left_end - left_start) // (mav_calc_window_sec * sampling_rate)) * ( - mav_calc_window_sec * sampling_rate) + mav_calc_window_sec * sampling_rate) if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0: right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * ( - mav_calc_window_sec * sampling_rate) + mav_calc_window_sec * sampling_rate) # 
计算每个片段的幅值指标 left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate), @@ -431,11 +476,12 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat # 新的end - start确保为200的整数倍 if (end - start) % (mav_calc_window_sec * sampling_rate) != 0: end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * ( - mav_calc_window_sec * sampling_rate) + mav_calc_window_sec * sampling_rate) # 计算每个片段的幅值指标 mav = np.nanmean( - np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.nanmean( + np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.nanmean( np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) segment_average_amplitude.append(mav) diff --git a/signal_method/time_metrics.py b/signal_method/time_metrics.py index a6a2a94..6d8c196 100644 --- a/signal_method/time_metrics.py +++ b/signal_method/time_metrics.py @@ -5,10 +5,14 @@ import numpy as np @timing_decorator() def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): - assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" - assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" - # print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}") - processed_mask = movement_mask.copy() + if movement_mask is not None: + assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" + # assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != 
{len(low_amp_mask)}" + # print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}") + processed_mask = movement_mask.copy() + else: + processed_mask = None + def mav_func(x): return np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2 mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate, diff --git a/utils/__init__.py b/utils/__init__.py index 51a18ba..c89b90c 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,5 +1,6 @@ -from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list -from utils.operation_tools import merge_short_gaps, remove_short_durations -from utils.event_map import E2N -from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file +from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel +from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list +from .operation_tools import merge_short_gaps, remove_short_durations +from .operation_tools import collect_values +from .event_map import E2N +from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 4feabdf..0229a06 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -125,8 +125,8 @@ def remove_short_durations(state_sequence, time_points, min_duration_sec): @timing_decorator() def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None): # 处理标志位长度与 signal_data 对齐 - if calc_mask is None: - calc_mask = 
np.zeros(len(signal_data), dtype=bool) + # if calc_mask is None: + # calc_mask = np.zeros(len(signal_data), dtype=bool) if step_second is None: step_second = window_second @@ -157,18 +157,21 @@ def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, values_nan = values_nan.repeat(step_second)[:origin_seconds] - for i in range(len(values_nan)): - if calc_mask[i]: - values_nan[i] = np.nan + if calc_mask is not None: + for i in range(len(values_nan)): + if calc_mask[i]: + values_nan[i] = np.nan - values = values_nan.copy() + values = values_nan.copy() - # 插值处理体动区域的 NaN 值 - def interpolate_nans(x, t): - valid_mask = ~np.isnan(x) - return np.interp(t, t[valid_mask], x[valid_mask]) + # 插值处理体动区域的 NaN 值 + def interpolate_nans(x, t): + valid_mask = ~np.isnan(x) + return np.interp(t, t[valid_mask], x[valid_mask]) - values = interpolate_nans(values, np.arange(len(values))) + values = interpolate_nans(values, np.arange(len(values))) + else: + values = values_nan.copy() return values_nan, values @@ -208,7 +211,20 @@ def generate_event_mask(signal_second: int, event_df): def event_mask_2_list(mask): - mask_start = np.where(np.diff(mask, append=0) == 1)[0] - mask_end = np.where(np.diff(mask, append=0) == -1)[0] + 1 + mask_start = np.where(np.diff(mask, append=0) == -1)[0] + mask_end = np.where(np.diff(mask, append=0) == 1)[0] + 1 event_list =[[start, end] for start, end in zip(mask_start, mask_end)] - return event_list \ No newline at end of file + return event_list + + +def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list: + """收集非 NaN 值,直到达到指定数量或边界""" + values = [] + count = 0 + mask = mask if mask is not None else arr + while count < limit and 0 <= index < len(mask): + if not np.isnan(mask[index]): + values.append(arr[index]) + count += 1 + index += step + return values \ No newline at end of file From fd7941a80a5e92a05b081f43bb6761ef577bcfd9 Mon Sep 17 00:00:00 2001 From: marques Date: Fri, 7 Nov 2025 14:48:42 +0800 
Subject: [PATCH 16/28] Update HYS_config.yaml, HYS_process.py, and operation_tools.py for movement revision and event mask handling --- HYS_process.py | 3 ++- dataset_config/HYS_config.yaml | 2 +- utils/operation_tools.py | 12 +++++++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 9f66a51..4873572 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -132,6 +132,7 @@ def process_one_signal(samp_id): resp_movement_revise_conf = conf.get("resp_movement_revise", None) if resp_movement_mask is not None and resp_movement_revise_conf is not None: + print(resp_movement_position_list) resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( signal_data=resp_data, movement_mask=resp_movement_mask, @@ -236,4 +237,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[0]) + process_one_signal(select_ids[7]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 50e0c93..97d650e 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -36,7 +36,7 @@ resp_movement: compare_intervals_sec: - 60 - 120 - - 180 +# - 180 interval_multiplier: 3.5 merge_gap_sec: 30 min_duration_sec: 1 diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 0229a06..ffd1b9f 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -210,9 +210,15 @@ def generate_event_mask(signal_second: int, event_df): return event_mask, score_mask -def event_mask_2_list(mask): - mask_start = np.where(np.diff(mask, append=0) == -1)[0] - mask_end = np.where(np.diff(mask, append=0) == 1)[0] + 1 +def event_mask_2_list(mask, event_true=True): + if event_true: + event_2_normal = 1 + normal_2_event = -1 + else: + event_2_normal = -1 + normal_2_event = 1 + mask_start = np.where(np.diff(mask, append=0) == normal_2_event)[0] + mask_end = np.where(np.diff(mask, append=0) == normal_2_event)[0] + 1 
event_list =[[start, end] for start, end in zip(mask_start, mask_end)] return event_list From 265fcd958ab8449f58dcd7f43feece37cb02e156 Mon Sep 17 00:00:00 2001 From: marques Date: Fri, 7 Nov 2025 16:52:31 +0800 Subject: [PATCH 17/28] Refactor signal processing functions in HYS_process.py and rule_base_event.py, update imports in __init__.py, and enhance event mask handling in operation_tools.py --- HYS_process.py | 6 +- signal_method/__init__.py | 5 +- signal_method/rule_base_event.py | 117 +++++++++++++++++++++++++++++-- signal_method/time_metrics.py | 6 +- utils/operation_tools.py | 14 ++-- 5 files changed, 129 insertions(+), 19 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 4873572..0c3627c 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -132,7 +132,6 @@ def process_one_signal(samp_id): resp_movement_revise_conf = conf.get("resp_movement_revise", None) if resp_movement_mask is not None and resp_movement_revise_conf is not None: - print(resp_movement_position_list) resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( signal_data=resp_data, movement_mask=resp_movement_mask, @@ -140,7 +139,7 @@ def process_one_signal(samp_id): sampling_rate=resp_fs, **resp_movement_revise_conf ) - print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") + print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: print("resp_movement_mask revise is skipped") @@ -148,9 +147,10 @@ def process_one_signal(samp_id): # 分析Resp的幅值突变区间 if resp_movement_mask is not None: - resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2( + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( signal_data=resp_data, movement_mask=resp_movement_mask, + 
movement_list=resp_movement_position_list, sampling_rate=resp_fs) print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: diff --git a/signal_method/__init__.py b/signal_method/__init__.py index ab44ea9..eaea6ea 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1,3 +1,4 @@ -from .rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 +from .rule_base_event import detect_low_amplitude_signal, detect_movement +from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3 from .rule_base_event import movement_revise -from .time_metrics import calc_mav \ No newline at end of file +from .time_metrics import calc_mav_by_slide_windows \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 072c375..161f378 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -1,7 +1,8 @@ +import utils from utils.operation_tools import timing_decorator import numpy as np from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values -from signal_method.time_metrics import calc_mav +from signal_method.time_metrics import calc_mav_by_slide_windows @timing_decorator() @@ -187,9 +188,9 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) - _, mav = calc_mav(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, - window_second=2, step_second=1, - inner_window_second=1) + _, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, + window_second=2, step_second=1, + inner_window_second=1) # 往左右两边取compare_size个点的mav,取平均值 for start, end in movement_list: @@ 
-207,6 +208,8 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up # 逐秒遍历mav,判断是否需要修正 # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") for i in range(start, end + 5): + if i < 0 or i >= len(mav): + continue # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") if mav[i] > (value_metrics * up_interval_multiplier): movement_mask[i] = 1 @@ -229,8 +232,6 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up return movement_mask, movement_list - - @timing_decorator() def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): @@ -511,3 +512,107 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat position_change_times.append((movement_start[i - 1], movement_end[i - 1])) return position_changes, position_change_times + + +def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate=100): + """ + + :param movement_list: + :param signal_data: + :param movement_mask: mask的采样率为1Hz + :param sampling_rate: + :param window_size_sec: + :return: + """ + mav_calc_window_sec = 1 # 计算mav的窗口大小,单位秒 + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + + # 获取有效片段起止位置 + + valid_list = utils.event_mask_2_list(movement_mask, event_true=False) + + segment_average_amplitude = [] + segment_average_energy = [] + + signal_data_no_movement = signal_data.copy() + for start, end in movement_list: + signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan + + # from matplotlib import pyplot as plt + # plt.plot(signal_data, alpha=0.3, color='gray') + # plt.plot(signal_data_no_movement, color='blue', linewidth=1) + # plt.show() + 
+ if len(valid_list) < 2: + return [], [] + + def clac_mav(data_segment): + mav = np.nanmean( + np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.nanmean( + np.nanmin(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + return mav + + def clac_energy(data_segment): + energy = np.nansum(np.abs(data_segment ** 2)) + return energy + + position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) + position_change_list = [] + + pre_valid_start = valid_list[0][0] * sampling_rate + pre_valid_end = valid_list[0][1] * sampling_rate + + print(f"Total movement segments to analyze: {len(movement_list)}") + print(f"Total valid segments available: {len(valid_list)}") + + for i in range(len(movement_list)): + print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") + + if i + 1 >= len(valid_list): + print("No more valid segments to compare. Ending analysis.") + break + + next_valid_start = valid_list[i + 1][0] * sampling_rate + next_valid_end = valid_list[i + 1][1] * sampling_rate + + movement_start = movement_list[i][0] + movement_end = movement_list[i][1] + + # 避免过短的片段 + if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑 + print(f"Skipping movement segment {i + 1} due to insufficient length. 
movement start: {movement_start}, movement end: {movement_end}") + continue + + # 计算前后片段的幅值和能量 + left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) + left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + + # 计算幅值指标的变化率 + amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6) + # 计算能量指标的变化率 + energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) + + significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + if significant_change: + print( + f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 + position_changes[movement_start:movement_end] = 1 + position_change_list.append(movement_list[i]) + # 更新前后片段 + pre_valid_start = next_valid_start + pre_valid_end = next_valid_end + + else: + print( + f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # 仅更新前片段 + pre_valid_start = pre_valid_start + pre_valid_end = next_valid_end + + return position_changes, position_change_list diff --git a/signal_method/time_metrics.py b/signal_method/time_metrics.py index 6d8c196..7895c3d 100644 --- a/signal_method/time_metrics.py +++ b/signal_method/time_metrics.py @@ -4,7 +4,7 @@ import numpy as np @timing_decorator() -def 
calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): +def calc_mav_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): if movement_mask is not None: assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" # assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" @@ -21,7 +21,7 @@ def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window return mav_nan, mav @timing_decorator() -def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): +def calc_wavefactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100): assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" @@ -33,7 +33,7 @@ def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100) return wavefactor_nan, wavefactor @timing_decorator() -def calc_peakfactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): +def calc_peakfactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100): assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" diff --git a/utils/operation_tools.py b/utils/operation_tools.py index ffd1b9f..5739097 100644 --- a/utils/operation_tools.py +++ 
b/utils/operation_tools.py @@ -5,6 +5,8 @@ import numpy as np import pandas as pd from matplotlib import pyplot as plt import yaml +from numpy.ma.core import append + from utils.event_map import E2N plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 @@ -212,13 +214,15 @@ def generate_event_mask(signal_second: int, event_df): def event_mask_2_list(mask, event_true=True): if event_true: - event_2_normal = 1 - normal_2_event = -1 - else: event_2_normal = -1 normal_2_event = 1 - mask_start = np.where(np.diff(mask, append=0) == normal_2_event)[0] - mask_end = np.where(np.diff(mask, append=0) == normal_2_event)[0] + 1 + _append = 0 + else: + event_2_normal = 1 + normal_2_event = -1 + _append = 1 + mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0] + mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0] + 1 event_list =[[start, end] for start, end in zip(mask_start, mask_end)] return event_list From f258838a86f5d2b77e5fc7d511ce8b33e51a7633 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 10 Nov 2025 14:37:42 +0800 Subject: [PATCH 18/28] =?UTF-8?q?=E5=91=BC=E5=90=B8=E4=BD=93=E5=8A=A8?= =?UTF-8?q?=E6=A3=80=E6=B5=8B=E5=9F=BA=E6=9C=AC=E7=A8=B3=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HYS_process.py | 11 +++++---- dataset_config/HYS_config.yaml | 6 +++++ signal_method/rule_base_event.py | 41 ++++++++++++++++++++++---------- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 0c3627c..1eecdf5 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:14.0" +os.environ['DISPLAY'] = "localhost:10.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / 
f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -144,14 +144,15 @@ def process_one_signal(samp_id): print("resp_movement_mask revise is skipped") - # 分析Resp的幅值突变区间 - if resp_movement_mask is not None: + resp_amp_change_conf = conf.get("resp_amp_change", None) + if resp_amp_change_conf is not None and resp_movement_mask is not None: resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( signal_data=resp_data, movement_mask=resp_movement_mask, movement_list=resp_movement_position_list, - sampling_rate=resp_fs) + sampling_rate=resp_fs, + **resp_amp_change_conf) print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: resp_amp_change_mask = None @@ -237,4 +238,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[7]) + process_one_signal(select_ids[0]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 97d650e..101c1b0 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -48,6 +48,12 @@ resp_movement_revise: merge_gap_sec: 10 min_duration_sec: 1 +resp_amp_change: + mav_calc_window_sec: 5 + threshold_amplitude: 0.1 + threshold_energy: 0.4 + + bcg: downsample_fs: 100 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 161f378..4a408ab 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -194,10 +194,15 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up # 往左右两边取compare_size个点的mav,取平均值 for start, end in movement_list: - left_values = collect_values(arr=mav, index=start - 1, step=-1, limit=compare_size, mask=movement_mask) - right_values = collect_values(arr=mav, index=end + 5, step=1, limit=compare_size, mask=movement_mask) + left_points = start - 5 + right_points = end + 5 + + 
left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask) + right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask) + left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0 right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0 + if left_value_metrics == 0: value_metrics = right_value_metrics elif right_value_metrics == 0: @@ -207,7 +212,7 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up # 逐秒遍历mav,判断是否需要修正 # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") - for i in range(start, end + 5): + for i in range(left_points, right_points): if i < 0 or i >= len(mav): continue # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") @@ -514,9 +519,13 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat return position_changes, position_change_times -def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate=100): +def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate, mav_calc_window_sec, + threshold_amplitude, threshold_energy): """ + :param threshold_energy: + :param threshold_amplitude: + :param mav_calc_window_sec: :param movement_list: :param signal_data: :param movement_mask: mask的采样率为1Hz @@ -524,11 +533,6 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis :param window_size_sec: :return: """ - mav_calc_window_sec = 1 # 计算mav的窗口大小,单位秒 - # 判断是否存在显著变化 (可根据实际情况调整阈值) - threshold_amplitude = 0.1 # 幅值变化阈值 - threshold_energy = 0.1 # 能量变化阈值 - # 获取有效片段起止位置 valid_list = utils.event_mask_2_list(movement_mask, event_true=False) @@ -549,16 +553,26 @@ def 
position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis return [], [] def clac_mav(data_segment): - mav = np.nanmean( + # 确定data_segment长度为mav_calc_window_sec的整数倍 + if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0: + data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))] + + mav = np.nanstd( np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), - axis=0)) - np.nanmean( + axis=0) - np.nanmin(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) return mav def clac_energy(data_segment): - energy = np.nansum(np.abs(data_segment ** 2)) + energy = np.nansum(np.abs(data_segment ** 2)) // (len(data_segment) // sampling_rate) return energy + def calc_mav_by_quantiles(data_segment): + # 先计算所有的mav值 + mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0) - np.nanmin( + data_segment.reshape(-1, mav_calc_window_sec * sampling_rate)) + # 计算分位数 + position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) position_change_list = [] @@ -583,7 +597,8 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis # 避免过短的片段 if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑 - print(f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") + print( + f"Skipping movement segment {i + 1} due to insufficient length. 
movement start: {movement_start}, movement end: {movement_end}") continue # 计算前后片段的幅值和能量 From 85f2408f13a0b8100637c5832293405c13ddf449 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 10 Nov 2025 14:41:25 +0800 Subject: [PATCH 19/28] =?UTF-8?q?=E5=91=BC=E5=90=B8=E4=BD=93=E5=8A=A8?= =?UTF-8?q?=E6=A3=80=E6=B5=8B=E5=9F=BA=E6=9C=AC=E7=A8=B3=E5=AE=9A=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- signal_method/rule_base_event.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 4a408ab..107d55b 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -195,7 +195,7 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up # 往左右两边取compare_size个点的mav,取平均值 for start, end in movement_list: left_points = start - 5 - right_points = end + 5 + right_points = end + 10 left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask) right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask) From 40fdda649791f3c945f1fe46325746d6551f762e Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 10 Nov 2025 18:38:26 +0800 Subject: [PATCH 20/28] =?UTF-8?q?=E4=B8=AD=E5=A4=A7=E4=BA=94=E9=99=A2?= =?UTF-8?q?=E4=B8=BA=E5=8D=A0=E4=BD=8D=EF=BC=8C=E5=91=BC=E7=A0=94=E6=89=80?= =?UTF-8?q?=E5=B7=B2=E5=8F=AF=E4=BB=A5=E6=AD=A3=E5=B8=B8=E5=AF=BC=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HYS_process.py | 56 ++++++- ZD5Y_process.py | 277 +++++++++++++++++++++++++++++++ dataset_config/HYS_config.yaml | 11 +- dataset_config/ZD5Y_config.yaml | 88 ++++++++++ draw_tools/draw_statics.py | 10 +- signal_method/rule_base_event.py | 131 ++++++++++----- utils/__init__.py | 1 + utils/operation_tools.py | 7 +- 8 files changed, 525 insertions(+), 56 
deletions(-) create mode 100644 ZD5Y_process.py create mode 100644 dataset_config/ZD5Y_config.yaml diff --git a/HYS_process.py b/HYS_process.py index 1eecdf5..50f3c6c 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -21,7 +21,7 @@ todo: 使用mask 屏蔽无用区间 """ from pathlib import Path - +import shutil import draw_tools import utils import numpy as np @@ -30,7 +30,7 @@ import os from matplotlib import pyplot as plt os.environ['DISPLAY'] = "localhost:10.0" -def process_one_signal(samp_id): +def process_one_signal(samp_id, show=False): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) if not signal_path: raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") @@ -43,6 +43,10 @@ def process_one_signal(samp_id): label_path = list(label_path)[0] print(f"Processing Label_corrected file: {label_path}") + # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + save_samp_path.mkdir(parents=True, exist_ok=True) + signal_data_raw = utils.read_signal_txt(signal_path) signal_length = len(signal_data_raw) print(f"signal_length: {signal_length}") @@ -137,7 +141,8 @@ def process_one_signal(samp_id): movement_mask=resp_movement_mask, movement_list=resp_movement_position_list, sampling_rate=resp_fs, - **resp_movement_revise_conf + **resp_movement_revise_conf, + verbose=False ) print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: @@ -152,7 +157,8 @@ def process_one_signal(samp_id): movement_mask=resp_movement_mask, movement_list=resp_movement_position_list, sampling_rate=resp_fs, - **resp_amp_change_conf) + **resp_amp_change_conf, + verbose=True) print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: resp_amp_change_mask = None @@ -202,7 +208,7 @@ def 
process_one_signal(samp_id): signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) signal_fs = 100 - draw_tools.draw_signal_with_mask(samp_id=samp_id, + draw_tools.draw_signal_with_mask(samp_id=samp_id, signal_data=signal_data, signal_fs=signal_fs, resp_data=resp_data, @@ -216,7 +222,35 @@ def process_one_signal(samp_id): resp_sa_mask=event_mask, bcg_low_amp_mask=bcg_low_amp_mask, bcg_movement_mask=bcg_movement_mask, - bcg_change_mask=bcg_amp_change_mask) + bcg_change_mask=bcg_amp_change_mask, + show=show, + save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") + + + + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + + # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 + save_dict = { + "Second": np.arange(signal_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": manual_disable_mask, + "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + @@ -229,13 +263,21 @@ if __name__ == '__main__': conf = utils.load_dataset_conf(yaml_path) 
select_ids = conf["select_ids"] root_path = Path(conf["root_path"]) + save_path = Path(conf["save_path"]) print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") + print(f"save_path: {save_path}") org_signal_root_path = root_path / "OrgBCG_Aligned" label_root_path = root_path / "Label" all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[0]) + process_one_signal(select_ids[9], show=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # process_one_signal(samp_id, show=False) + # print(f"Finished processing sample ID: {samp_id}\n\n") + diff --git a/ZD5Y_process.py b/ZD5Y_process.py new file mode 100644 index 0000000..23dc4c8 --- /dev/null +++ b/ZD5Y_process.py @@ -0,0 +1,277 @@ +""" +本脚本完成对呼研所数据的处理,包含以下功能: +1. 数据读取与预处理 + 从传入路径中,进行数据和标签的读取,并进行初步的预处理 + 预处理包括为数据进行滤波、去噪等操作 +2. 数据清洗与异常值处理 +3. 输出清晰后的统计信息 +4. 数据保存 + 将处理后的数据保存到指定路径,便于后续使用 + 主要是保存切分后的数据位置和标签 +5. 可视化 + 提供数据处理前后的可视化对比,帮助理解数据变化 + 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 + +todo: 使用mask 屏蔽无用区间 + + +# 低幅值区间规则标定与剔除 +# 高幅值连续体动规则标定与剔除 +# 手动标定不可用区间提剔除 +""" + +from pathlib import Path +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os +from matplotlib import pyplot as plt +os.environ['DISPLAY'] = "localhost:10.0" + +def process_one_signal(samp_id, show=False): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv") + if not label_path: + raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") + label_path = list(label_path)[0] + print(f"Processing Label_corrected file: {label_path}") + + signal_data_raw = utils.read_signal_txt(signal_path) + signal_length = 
len(signal_data_raw) + print(f"signal_length: {signal_length}") + signal_fs = int(signal_path.stem.split("_")[-1]) + print(f"signal_fs: {signal_fs}") + signal_second = signal_length // signal_fs + print(f"signal_second: {signal_second}") + + # 根据采样率进行截断 + signal_data_raw = signal_data_raw[:signal_second * signal_fs] + + # 滤波 + # 50Hz陷波滤波器 + # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) + print("Applying 50Hz notch filter...") + signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + print("Begin plotting signal data...") + + + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # 
plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + # 降采样 + old_resp_fs = resp_fs + resp_fs = conf["resp"]["downsample_fs_2"] + resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) + bcg_fs = conf["bcg"]["downsample_fs"] + bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) + + label_data = utils.read_label_csv(path=label_path) + event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + + manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ + all_samp_disable_df["id"] == samp_id]) + print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") + + # 分析Resp的低幅值区间 + resp_low_amp_conf = conf.get("resp_low_amp", None) + if resp_low_amp_conf is not None: + resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=resp_data, + sampling_rate=resp_fs, + **resp_low_amp_conf + ) + print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") + else: + resp_low_amp_mask, resp_low_amp_position_list = None, None + print("resp_low_amp_mask is None") + + # 分析Resp的高幅值伪迹区间 + resp_movement_conf = conf.get("resp_movement", None) + if resp_movement_conf is not None: + raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( + signal_data=resp_data, + sampling_rate=resp_fs, + **resp_movement_conf + ) + print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: 
{np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + else: + resp_movement_mask, resp_movement_position_list = None, None + print("resp_movement_mask is None") + + resp_movement_revise_conf = conf.get("resp_movement_revise", None) + if resp_movement_mask is not None and resp_movement_revise_conf is not None: + resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, + sampling_rate=resp_fs, + **resp_movement_revise_conf, + verbose=False + ) + print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + else: + print("resp_movement_mask revise is skipped") + + + # 分析Resp的幅值突变区间 + resp_amp_change_conf = conf.get("resp_amp_change", None) + if resp_amp_change_conf is not None and resp_movement_mask is not None: + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, + sampling_rate=resp_fs, + **resp_amp_change_conf) + print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") + else: + resp_amp_change_mask = None + print("amp_change_mask is None") + + + + # 分析Bcg的低幅值区间 + bcg_low_amp_conf = conf.get("bcg_low_amp", None) + if bcg_low_amp_conf is not None: + bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=bcg_data, + sampling_rate=bcg_fs, + **bcg_low_amp_conf + ) + print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") + else: + bcg_low_amp_mask, 
bcg_low_amp_position_list = None, None + print("bcg_low_amp_mask is None") + # 分析Bcg的高幅值伪迹区间 + bcg_movement_conf = conf.get("bcg_movement", None) + if bcg_movement_conf is not None: + raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( + signal_data=bcg_data, + sampling_rate=bcg_fs, + **bcg_movement_conf + ) + print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") + else: + bcg_movement_mask = None + print("bcg_movement_mask is None") + # 分析Bcg的幅值突变区间 + if bcg_movement_mask is not None: + bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( + signal_data=bcg_data, + movement_mask=bcg_movement_mask, + sampling_rate=bcg_fs) + print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") + else: + bcg_amp_change_mask = None + print("bcg_amp_change_mask is None") + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) + signal_fs = 100 + if show: + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=event_mask, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask) + + + # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + 
save_samp_path.mkdir(parents=True, exist_ok=True) + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + + # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 + save_dict = { + "Second": np.arange(signal_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": manual_disable_mask, + "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + + + + + + +if __name__ == '__main__': + yaml_path = Path("./dataset_config/ZD5Y_config.yaml") + disable_df_path = Path("./排除区间.xlsx") + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + save_path = Path(conf["save_path"]) + + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + label_root_path = root_path / "Label" + + all_samp_disable_df = utils.read_disable_excel(disable_df_path) + + process_one_signal(select_ids[1], show=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") 
+ # process_one_signal(samp_id, show=False) + # print(f"Finished processing sample ID: {samp_id}\n\n") \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 101c1b0..e6d02dc 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -11,6 +11,7 @@ select_ids: - 960 root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +save_path: /mnt/disk_code/marques/dataprepare/output/HYS resp: downsample_fs_1: 100 @@ -32,25 +33,25 @@ resp_low_amp: resp_movement: window_size_sec: 20 stride_sec: 1 - std_median_multiplier: 5 + std_median_multiplier: 4 compare_intervals_sec: - 60 - 120 # - 180 - interval_multiplier: 3.5 + interval_multiplier: 3 merge_gap_sec: 30 min_duration_sec: 1 resp_movement_revise: up_interval_multiplier: 3 - down_interval_multiplier: 1.5 + down_interval_multiplier: 2 compare_intervals_sec: 30 merge_gap_sec: 10 min_duration_sec: 1 resp_amp_change: - mav_calc_window_sec: 5 - threshold_amplitude: 0.1 + mav_calc_window_sec: 1 + threshold_amplitude: 0.25 threshold_energy: 0.4 diff --git a/dataset_config/ZD5Y_config.yaml b/dataset_config/ZD5Y_config.yaml new file mode 100644 index 0000000..ff479fd --- /dev/null +++ b/dataset_config/ZD5Y_config.yaml @@ -0,0 +1,88 @@ +select_ids: + - 3103 + - 3105 + - 3106 + - 3107 + - 3108 + - 3110 + - 3203 + - 3204 + - 3205 + - 3207 + - 3208 + - 3209 + - 3212 + - 3301 + - 3303 + - 3307 + - 3403 + - 3504 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y +save_path: /mnt/disk_code/marques/dataprepare/output/ZD5Y + +resp: + downsample_fs_1: 100 + downsample_fs_2: 10 + +resp_filter: + filter_type: bandpass + low_cut: 0.01 + high_cut: 0.7 + order: 3 + +resp_low_amp: + window_size_sec: 30 + stride_sec: + amplitude_threshold: 3 + merge_gap_sec: 60 + min_duration_sec: 60 + +resp_movement: + window_size_sec: 20 + stride_sec: 1 + std_median_multiplier: 5 + compare_intervals_sec: + - 60 + - 120 +# - 180 + interval_multiplier: 3.5 + 
merge_gap_sec: 30 + min_duration_sec: 1 + +resp_movement_revise: + up_interval_multiplier: 3 + down_interval_multiplier: 2 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 + +resp_amp_change: + mav_calc_window_sec: 1 + threshold_amplitude: 0.25 + threshold_energy: 0.4 + + +bcg: + downsample_fs: 100 + +bcg_filter: + filter_type: bandpass + low_cut: 1 + high_cut: 10 + order: 10 + +bcg_low_amp: + window_size_sec: 1 + stride_sec: + amplitude_threshold: 8 + merge_gap_sec: 20 + min_duration_sec: 3 + + +bcg_movement: + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index f44cde0..f94679d 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -178,7 +178,7 @@ def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_s def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs, signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask, - resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask + resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask, show=False, save_path=None ): # 第一行绘制去工频噪声的原始信号,右侧为不可用区间标记,左侧为信号幅值纵坐标 # 第二行绘制呼吸分量,右侧低幅值、高幅值、幅值变换标记、SA标签,左侧为呼吸幅值纵坐标 @@ -292,10 +292,12 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax0_twin.callbacks.connect('ylim_changed', on_lims_change) ax1_twin.callbacks.connect('ylim_changed', on_lims_change) ax2_twin.callbacks.connect('ylim_changed', on_lims_change) - - plt.tight_layout() - plt.show() + + if save_path is not None: + plt.savefig(save_path, dpi=300) + if show: + plt.show() diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 107d55b..d13e6cb 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -169,7 +169,7 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, 
stride_sec=No def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float, - down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec): + down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec, verbose=False): """ 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置精细修正 @@ -189,13 +189,13 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) _, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, - window_second=2, step_second=1, - inner_window_second=1) + window_second=4, step_second=1, + inner_window_second=4) # 往左右两边取compare_size个点的mav,取平均值 for start, end in movement_list: - left_points = start - 5 - right_points = end + 10 + left_points = start - 20 + right_points = end + 20 left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask) right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask) @@ -203,28 +203,58 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0 right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0 - if left_value_metrics == 0: - value_metrics = right_value_metrics - elif right_value_metrics == 0: - value_metrics = left_value_metrics - else: - value_metrics = np.mean([left_value_metrics, right_value_metrics]) + # if left_value_metrics == 0: + # value_metrics = right_value_metrics + # elif right_value_metrics == 0: + # value_metrics = left_value_metrics + # else: + # value_metrics = np.mean([left_value_metrics, right_value_metrics]) + + if left_value_metrics == 0: + left_value_metrics = right_value_metrics + elif right_value_metrics == 0: + right_value_metrics = 
left_value_metrics + + if verbose: + print(f"Revising movement from index {start} to {end}, left_metric: {left_value_metrics:.2f}, right_metric: {right_value_metrics:.2f}") - # 逐秒遍历mav,判断是否需要修正 - # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") for i in range(left_points, right_points): if i < 0 or i >= len(mav): continue - # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") - if mav[i] > (value_metrics * up_interval_multiplier): + if i < start: + value_metrics = left_value_metrics + elif i > end: + value_metrics = right_value_metrics + else: + value_metrics = (left_value_metrics + right_value_metrics) / 2 + + if mav[i] > (value_metrics * up_interval_multiplier) and movement_mask[i] == 0: movement_mask[i] = 1 - # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}") - elif mav[i] < (value_metrics * down_interval_multiplier): + if verbose: + print(f"Normal revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * up_interval_multiplier:.2f}") + elif mav[i] < (value_metrics * down_interval_multiplier) and movement_mask[i] == 1: movement_mask[i] = 0 - # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}") - # else: - # print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}") - # + if verbose: + print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * down_interval_multiplier:.2f}") + else: + if verbose: + print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_metrics * up_interval_multiplier:.2f}, down_threshold: {value_metrics * down_interval_multiplier:.2f}") + # + # 逐秒遍历mav,判断是否需要修正 + # 
print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") + # for i in range(left_points, right_points): + # if i < 0 or i >= len(mav): + # continue + # # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") + # if mav[i] > (value_metrics * up_interval_multiplier): + # movement_mask[i] = 1 + # # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}") + # elif mav[i] < (value_metrics * down_interval_multiplier): + # movement_mask[i] = 0 + # # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}") + # # else: + # # print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}") + # # # 如果需要合并间隔小的体动状态 if merge_gap_sec > 0: movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec) @@ -520,7 +550,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate, mav_calc_window_sec, - threshold_amplitude, threshold_energy): + threshold_amplitude, threshold_energy, verbose=False): """ :param threshold_energy: @@ -569,9 +599,18 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis def calc_mav_by_quantiles(data_segment): # 先计算所有的mav值 + if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0: + data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))] + mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0) - np.nanmin( data_segment.reshape(-1, mav_calc_window_sec * sampling_rate)) # 计算分位数 + q20 = np.nanpercentile(mav_values, 20) + q80 = 
np.nanpercentile(mav_values, 80) + + mav_values = mav_values[(mav_values >= q20) & (mav_values <= q80)] + mav = np.nanmean(mav_values) + return mav position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) position_change_list = [] @@ -579,14 +618,17 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis pre_valid_start = valid_list[0][0] * sampling_rate pre_valid_end = valid_list[0][1] * sampling_rate - print(f"Total movement segments to analyze: {len(movement_list)}") - print(f"Total valid segments available: {len(valid_list)}") + if verbose: + print(f"Total movement segments to analyze: {len(movement_list)}") + print(f"Total valid segments available: {len(valid_list)}") for i in range(len(movement_list)): - print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") + if verbose: + print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") if i + 1 >= len(valid_list): - print("No more valid segments to compare. Ending analysis.") + if verbose: + print("No more valid segments to compare. Ending analysis.") break next_valid_start = valid_list[i + 1][0] * sampling_rate @@ -597,25 +639,33 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis # 避免过短的片段 if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑 - print( - f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") + if verbose: + print( + f"Skipping movement segment {i + 1} due to insufficient length. 
movement start: {movement_start}, movement end: {movement_end}") continue # 计算前后片段的幅值和能量 - left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) - right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) - left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) - right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + # left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) + # right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) + # left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) + # right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + + left_mav = calc_mav_by_quantiles(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_mav = calc_mav_by_quantiles(signal_data_no_movement[next_valid_start:next_valid_end]) + # 计算幅值指标的变化率 amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6) - # 计算能量指标的变化率 - energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) + # # 计算能量指标的变化率 + # energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) - significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + # significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + significant_change = (amplitude_change > threshold_amplitude) if significant_change: - print( - f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # print( + # f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} 
left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + if verbose: + print(f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}") # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 position_changes[movement_start:movement_end] = 1 position_change_list.append(movement_list[i]) @@ -624,8 +674,11 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis pre_valid_end = next_valid_end else: - print( - f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # print( + # f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + if verbose: + print(f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}") + # 仅更新前片段 pre_valid_start = pre_valid_start pre_valid_end = next_valid_end diff --git a/utils/__init__.py b/utils/__init__.py index c89b90c..68e7772 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -2,5 +2,6 @@ from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel from .operation_tools import load_dataset_conf, generate_disable_mask, 
generate_event_mask, event_mask_2_list from .operation_tools import merge_short_gaps, remove_short_durations from .operation_tools import collect_values +from .operation_tools import save_process_label from .event_map import E2N from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 5739097..23205be 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -237,4 +237,9 @@ def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None values.append(arr[index]) count += 1 index += step - return values \ No newline at end of file + return values + + +def save_process_label(save_path: Path, save_dict: dict): + save_df = pd.DataFrame(save_dict) + save_df.to_csv(save_path, index=False) From 60e245b1e39e5f0146b62b681854e9c61f9ca1b4 Mon Sep 17 00:00:00 2001 From: marques Date: Tue, 11 Nov 2025 10:39:23 +0800 Subject: [PATCH 21/28] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=B9=85=E5=80=BC?= =?UTF-8?q?=E6=94=B9=E5=8F=98=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HYS_process.py | 93 ++++++++++++++++---------------- dataset_config/HYS_config.yaml | 2 +- signal_method/rule_base_event.py | 10 ++-- utils/operation_tools.py | 2 +- 4 files changed, 55 insertions(+), 52 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 50f3c6c..06156eb 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -12,7 +12,6 @@ 提供数据处理前后的可视化对比,帮助理解数据变化 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 -todo: 使用mask 屏蔽无用区间 # 低幅值区间规则标定与剔除 @@ -28,8 +27,10 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt + os.environ['DISPLAY'] = "localhost:10.0" + def process_one_signal(samp_id, show=False): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) if not signal_path: @@ -69,12 +70,11 @@ def 
process_one_signal(samp_id, show=False): resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], - low_cut=conf["resp_filter"]["low_cut"], - high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], - sample_rate=resp_fs) + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) print("Begin plotting signal data...") - # fig = plt.figure(figsize=(12, 8)) # # 绘制三个图raw_data、resp_data_1、resp_data_2 # ax0 = fig.add_subplot(3, 1, 1) @@ -116,7 +116,8 @@ def process_one_signal(samp_id, show=False): sampling_rate=resp_fs, **resp_low_amp_conf ) - print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") + print( + f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") else: resp_low_amp_mask, resp_low_amp_position_list = None, None print("resp_low_amp_mask is None") @@ -129,7 +130,8 @@ def process_one_signal(samp_id, show=False): sampling_rate=resp_fs, **resp_movement_conf ) - print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + print( + f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: resp_movement_mask, resp_movement_position_list = None, None print("resp_movement_mask is None") @@ -144,11 +146,11 @@ def process_one_signal(samp_id, show=False): 
**resp_movement_revise_conf, verbose=False ) - print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + print( + f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: print("resp_movement_mask revise is skipped") - # 分析Resp的幅值突变区间 resp_amp_change_conf = conf.get("resp_amp_change", None) if resp_amp_change_conf is not None and resp_movement_mask is not None: @@ -159,13 +161,12 @@ def process_one_signal(samp_id, show=False): sampling_rate=resp_fs, **resp_amp_change_conf, verbose=True) - print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") + print( + f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: resp_amp_change_mask = None print("amp_change_mask is None") - - # 分析Bcg的低幅值区间 bcg_low_amp_conf = conf.get("bcg_low_amp", None) if bcg_low_amp_conf is not None: @@ -174,10 +175,12 @@ def process_one_signal(samp_id, show=False): sampling_rate=bcg_fs, **bcg_low_amp_conf ) - print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") + print( + f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") else: bcg_low_amp_mask, bcg_low_amp_position_list = None, None print("bcg_low_amp_mask is None") + # 分析Bcg的高幅值伪迹区间 bcg_movement_conf = conf.get("bcg_movement", None) if bcg_movement_conf is not None: @@ -186,48 +189,48 @@ def process_one_signal(samp_id, show=False): 
sampling_rate=bcg_fs, **bcg_movement_conf ) - print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") + print( + f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") else: bcg_movement_mask = None print("bcg_movement_mask is None") + # 分析Bcg的幅值突变区间 if bcg_movement_mask is not None: bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( signal_data=bcg_data, movement_mask=bcg_movement_mask, sampling_rate=bcg_fs) - print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") + print( + f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") else: bcg_amp_change_mask = None print("bcg_amp_change_mask is None") - # 如果signal_data采样率过,进行降采样 if signal_fs == 1000: signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) - signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) + signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, + target_fs=100) signal_fs = 100 draw_tools.draw_signal_with_mask(samp_id=samp_id, - signal_data=signal_data, - signal_fs=signal_fs, - resp_data=resp_data, - resp_fs=resp_fs, - bcg_data=bcg_data, - bcg_fs=bcg_fs, - signal_disable_mask=manual_disable_mask, - resp_low_amp_mask=resp_low_amp_mask, - resp_movement_mask=resp_movement_mask, - resp_change_mask=resp_amp_change_mask, - resp_sa_mask=event_mask, - bcg_low_amp_mask=bcg_low_amp_mask, - bcg_movement_mask=bcg_movement_mask, - 
bcg_change_mask=bcg_amp_change_mask, - show=show, - save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") - - - + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=event_mask, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask, + show=show, + save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") # 复制事件文件 到保存路径 sa_label_save_name = f"{samp_id}" + label_path.name @@ -240,22 +243,21 @@ def process_one_signal(samp_id, show=False): "SA_Score": score_mask, "Disable_Label": manual_disable_mask, "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, + dtype=int), + "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, + dtype=int), "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) + "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, + dtype=int), + "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, + 
dtype=int) } mask_label_save_name = f"{samp_id}_Processed_Labels.csv" utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) - - - - - if __name__ == '__main__': yaml_path = Path("./dataset_config/HYS_config.yaml") disable_df_path = Path("./排除区间.xlsx") @@ -280,4 +282,3 @@ if __name__ == '__main__': # print(f"Processing sample ID: {samp_id}") # process_one_signal(samp_id, show=False) # print(f"Finished processing sample ID: {samp_id}\n\n") - diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index e6d02dc..0f5be51 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -50,7 +50,7 @@ resp_movement_revise: min_duration_sec: 1 resp_amp_change: - mav_calc_window_sec: 1 + mav_calc_window_sec: 4 threshold_amplitude: 0.25 threshold_energy: 0.4 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index d13e6cb..ee6b6ba 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -602,8 +602,8 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0: data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))] - mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0) - np.nanmin( - data_segment.reshape(-1, mav_calc_window_sec * sampling_rate)) + mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=1) - np.nanmin( + data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=1) # 计算分位数 q20 = np.nanpercentile(mav_values, 20) q80 = np.nanpercentile(mav_values, 80) @@ -638,10 +638,12 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis movement_end = movement_list[i][1] # 避免过短的片段 - if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑 + if movement_end - movement_start <= 1: # 小于1秒的片段不考虑 
if verbose: print( f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") + pre_valid_start = pre_valid_start + pre_valid_end = next_valid_end continue # 计算前后片段的幅值和能量 @@ -665,7 +667,7 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis # print( # f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") if verbose: - print(f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}") + print(f"Significant position change detected between segments {movement_start}s and {movement_end}:s left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right:{next_valid_start}to{next_valid_end} right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}") # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 position_changes[movement_start:movement_end] = 1 position_change_list.append(movement_list[i]) diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 23205be..095dd5c 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -222,7 +222,7 @@ def event_mask_2_list(mask, event_true=True): normal_2_event = -1 _append = 1 mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0] - mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0] + 1 + mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0] event_list =[[start, end] for start, end in zip(mask_start, mask_end)] return event_list From 
1a0761c6c8667fd32f3a842be8e4a7172f73bb51 Mon Sep 17 00:00:00 2001 From: marques Date: Tue, 11 Nov 2025 14:21:25 +0800 Subject: [PATCH 22/28] =?UTF-8?q?=E5=87=86=E5=A4=87=E6=9E=84=E5=BB=BA?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- SHHS_process.py => dataset_builder/HYS_dataset.py | 0 dataset_builder/__init__.py | 0 HYS_process.py => event_mask_process/HYS_process.py | 10 +++++----- event_mask_process/SHHS_process.py | 0 ZD5Y_process.py => event_mask_process/ZD5Y_process.py | 4 ++-- event_mask_process/__init__.py | 0 6 files changed, 7 insertions(+), 7 deletions(-) rename SHHS_process.py => dataset_builder/HYS_dataset.py (100%) create mode 100644 dataset_builder/__init__.py rename HYS_process.py => event_mask_process/HYS_process.py (98%) create mode 100644 event_mask_process/SHHS_process.py rename ZD5Y_process.py => event_mask_process/ZD5Y_process.py (99%) create mode 100644 event_mask_process/__init__.py diff --git a/SHHS_process.py b/dataset_builder/HYS_dataset.py similarity index 100% rename from SHHS_process.py rename to dataset_builder/HYS_dataset.py diff --git a/dataset_builder/__init__.py b/dataset_builder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/HYS_process.py b/event_mask_process/HYS_process.py similarity index 98% rename from HYS_process.py rename to event_mask_process/HYS_process.py index 06156eb..1054dba 100644 --- a/HYS_process.py +++ b/event_mask_process/HYS_process.py @@ -160,7 +160,7 @@ def process_one_signal(samp_id, show=False): movement_list=resp_movement_position_list, sampling_rate=resp_fs, **resp_amp_change_conf, - verbose=True) + verbose=False) print( f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: @@ -259,8 +259,8 @@ def process_one_signal(samp_id, show=False): if __name__ == '__main__': - yaml_path = 
Path("./dataset_config/HYS_config.yaml") - disable_df_path = Path("./排除区间.xlsx") + yaml_path = Path("../dataset_config/HYS_config.yaml") + disable_df_path = Path("../排除区间.xlsx") conf = utils.load_dataset_conf(yaml_path) select_ids = conf["select_ids"] @@ -276,8 +276,8 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[9], show=True) - + process_one_signal(select_ids[6], show=True) + # # for samp_id in select_ids: # print(f"Processing sample ID: {samp_id}") # process_one_signal(samp_id, show=False) diff --git a/event_mask_process/SHHS_process.py b/event_mask_process/SHHS_process.py new file mode 100644 index 0000000..e69de29 diff --git a/ZD5Y_process.py b/event_mask_process/ZD5Y_process.py similarity index 99% rename from ZD5Y_process.py rename to event_mask_process/ZD5Y_process.py index 23dc4c8..8701d65 100644 --- a/ZD5Y_process.py +++ b/event_mask_process/ZD5Y_process.py @@ -252,8 +252,8 @@ def process_one_signal(samp_id, show=False): if __name__ == '__main__': - yaml_path = Path("./dataset_config/ZD5Y_config.yaml") - disable_df_path = Path("./排除区间.xlsx") + yaml_path = Path("../dataset_config/ZD5Y_config.yaml") + disable_df_path = Path("../排除区间.xlsx") conf = utils.load_dataset_conf(yaml_path) select_ids = conf["select_ids"] diff --git a/event_mask_process/__init__.py b/event_mask_process/__init__.py new file mode 100644 index 0000000..e69de29 From ed4205f5b8a12c40490b87244571d794721192e2 Mon Sep 17 00:00:00 2001 From: marques Date: Fri, 14 Nov 2025 18:39:50 +0800 Subject: [PATCH 23/28] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=A4=84=E7=90=86=E6=A8=A1=E5=9D=97=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E4=BF=A1=E5=8F=B7=E6=A0=87=E5=87=86=E5=8C=96=E5=92=8C=E7=BB=98?= =?UTF-8?q?=E5=9B=BE=E5=8A=9F=E8=83=BD=EF=BC=8C=E9=87=8D=E6=9E=84=E9=83=A8?= =?UTF-8?q?=E5=88=86=E5=87=BD=E6=95=B0=E4=BB=A5=E6=8F=90=E9=AB=98=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + dataset_builder/HYS_dataset.py | 166 ++++++++++++ dataset_config/HYS_config.yaml | 27 +- draw_tools/__init__.py | 3 +- draw_tools/draw_label.py | 230 ++++++++++++++++ event_mask_process/HYS_process.py | 114 +++----- event_mask_process/SHHS_process.py | 0 event_mask_process/ZD5Y_process.py | 277 -------------------- signal_method/__init__.py | 4 +- signal_method/normalize_method.py | 36 +++ signal_method/signal_process.py | 62 +++++ utils/HYS_FileReader.py | 134 +++++++++- utils/__init__.py | 9 +- utils/event_map.py | 37 ++- utils/{signal_process.py => filter_func.py} | 0 utils/operation_tools.py | 28 +- utils/split_method.py | 27 ++ 17 files changed, 774 insertions(+), 382 deletions(-) create mode 100644 draw_tools/draw_label.py delete mode 100644 event_mask_process/SHHS_process.py delete mode 100644 event_mask_process/ZD5Y_process.py create mode 100644 signal_method/normalize_method.py create mode 100644 signal_method/signal_process.py rename utils/{signal_process.py => filter_func.py} (100%) create mode 100644 utils/split_method.py diff --git a/.gitignore b/.gitignore index 2429834..1119c69 100644 --- a/.gitignore +++ b/.gitignore @@ -253,3 +253,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ +output/* +!output/ diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index e69de29..50cca9a 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -0,0 +1,166 @@ +import sys +from pathlib import Path + +import os + +import numpy as np + +os.environ['DISPLAY'] = "localhost:10.0" + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import utils +import signal_method +import draw_tools +import shutil + + +def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") + print(f"mask_excel_path: {mask_excel_path}") + + event_mask, event_list = utils.read_mask_execl(mask_excel_path) + + bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float) + + bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs) + normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"]) + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100) + bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs, + target_fs=100) + signal_fs = 100 + + if bcg_fs == 1000: + bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100) + bcg_fs = 100 + + # 
draw_tools.draw_signal_with_mask(samp_id=samp_id, + # signal_data=resp_signal, + # signal_fs=resp_fs, + # resp_data=normalized_resp_signal, + # resp_fs=resp_fs, + # bcg_data=bcg_signal, + # bcg_fs=bcg_fs, + # signal_disable_mask=event_mask["Disable_Label"], + # resp_low_amp_mask=event_mask["Resp_LowAmp_Label"], + # resp_movement_mask=event_mask["Resp_Movement_Label"], + # resp_change_mask=event_mask["Resp_AmpChange_Label"], + # resp_sa_mask=event_mask["SA_Label"], + # bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"], + # bcg_movement_mask=event_mask["BCG_Movement_Label"], + # bcg_change_mask=event_mask["BCG_AmpChange_Label"], + # show=show, + # save_path=None) + + segment_list = utils.resp_split(dataset_config, event_mask, event_list) + print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") + + + # 复制mask到processed_Labels文件夹 + save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv" + shutil.copyfile(mask_excel_path, save_mask_excel_path) + + # 复制SA Label_corrected.csv到processed_Labels文件夹 + sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv") + if sa_label_corrected_path.exists(): + save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" + shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) + else: + print(f"Warning: {sa_label_corrected_path} does not exist.") + + # 保存处理后的信号和截取的片段列表 + save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" + save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz" + + bcg_data = { + "bcg_signal_notch": { + "name": "BCG_Signal_Notch", + "data": bcg_signal_notch, + "fs": signal_fs, + "length": len(bcg_signal_notch), + "second": len(bcg_signal_notch) // signal_fs + }, + "bcg_signal":{ + "name": "BCG_Signal_Raw", + "data": bcg_signal, + "fs": bcg_fs, + "length": len(bcg_signal), + "second": len(bcg_signal) // bcg_fs + }, + "resp_signal": { + 
"name": "Resp_Signal", + "data": normalized_resp_signal, + "fs": resp_fs, + "length": len(normalized_resp_signal), + "second": len(normalized_resp_signal) // resp_fs + } + } + + np.savez_compressed(save_signal_path, **bcg_data) + np.savez_compressed(save_segment_path, + segment_list=segment_list) + print(f"Saved processed signals to: {save_signal_path}") + print(f"Saved segment list to: {save_segment_path}") + + if draw_segment: + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8]) + psg_data["HR"] = { + "name": "HR", + "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]), + "fs": psg_data["ECG_Sync"]["fs"], + "length": psg_data["ECG_Sync"]["length"], + "second": psg_data["ECG_Sync"]["second"] + } + + + psg_label = utils.read_psg_label(sa_label_corrected_path) + psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False) + draw_tools.draw_psg_bcg_label(psg_data=psg_data, + psg_label=psg_event_mask, + bcg_data=bcg_data, + event_mask=event_mask, + segment_list=segment_list) + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_config.yaml" + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + mask_path = Path(conf["mask_save_path"]) + save_path = Path(conf["dataset_config"]["dataset_save_path"]) + dataset_config = conf["dataset_config"] + + save_processed_signal_path = save_path / "Signals" + save_processed_signal_path.mkdir(parents=True, exist_ok=True) + + save_segment_list_path = save_path / "Segments_List" + save_segment_list_path.mkdir(parents=True, exist_ok=True) + + save_processed_label_path = save_path / "Labels" + save_processed_label_path.mkdir(parents=True, exist_ok=True) + + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + + org_signal_root_path = 
root_path / "OrgBCG_Aligned" + psg_signal_root_path = root_path / "PSG_Aligned" + + build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) + # + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # build_HYS_dataset_segment(samp_id, show=False) \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 0f5be51..a15e03d 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -11,7 +11,7 @@ select_ids: - 960 root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS -save_path: /mnt/disk_code/marques/dataprepare/output/HYS +mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS resp: downsample_fs_1: 100 @@ -43,11 +43,11 @@ resp_movement: min_duration_sec: 1 resp_movement_revise: - up_interval_multiplier: 3 - down_interval_multiplier: 2 - compare_intervals_sec: 30 - merge_gap_sec: 10 - min_duration_sec: 1 + up_interval_multiplier: 3 + down_interval_multiplier: 2 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 resp_amp_change: mav_calc_window_sec: 4 @@ -56,7 +56,7 @@ resp_amp_change: bcg: - downsample_fs: 100 + downsample_fs: 100 bcg_filter: filter_type: bandpass @@ -73,8 +73,13 @@ bcg_low_amp: bcg_movement: - window_size_sec: 2 - stride_sec: - merge_gap_sec: 20 - min_duration_sec: 4 + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py index 281cc34..3386b90 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -1 +1,2 @@ -from .draw_statics import draw_signal_with_mask \ No newline at end of file +from .draw_statics import draw_signal_with_mask +from .draw_label import draw_psg_bcg_label, draw_resp_label \ No newline at end of file diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py 
new file mode 100644 index 0000000..605de63 --- /dev/null +++ b/draw_tools/draw_label.py @@ -0,0 +1,230 @@ +from matplotlib.axes import Axes +from matplotlib.gridspec import GridSpec +from matplotlib.colors import PowerNorm +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import seaborn as sns +import numpy as np + +import utils + +# 添加with_prediction参数 + +psg_chn_name2ax = { + "SpO2": 0, + "Flow T": 1, + "Flow P": 2, + "Effort Tho": 3, + "Effort Abd": 4, + "HR": 5, + "resp": 6, + "bcg": 7, + "Stage": 8 +} + +resp_chn_name2ax = { + "resp": 0, + "bcg": 1, +} + + +def create_psg_bcg_figure(): + fig = plt.figure(figsize=(12, 8), dpi=100) + gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(9): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + + axes[0].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[0].set_ylim((85, 100)) + axes[0].tick_params(axis='x', colors="white") + + axes[1].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[1].tick_params(axis='x', colors="white") + + axes[2].grid(True) + # axes[2].xaxis.set_major_formatter(Params.FORMATTER) + axes[2].tick_params(axis='x', colors="white") + + axes[3].grid(True) + # axes[3].xaxis.set_major_formatter(Params.FORMATTER) + axes[3].tick_params(axis='x', colors="white") + + axes[4].grid(True) + # axes[4].xaxis.set_major_formatter(Params.FORMATTER) + axes[4].tick_params(axis='x', colors="white") + + axes[5].grid(True) + axes[5].tick_params(axis='x', colors="white") + + axes[6].grid(True) + # axes[5].xaxis.set_major_formatter(Params.FORMATTER) + axes[6].tick_params(axis='x', colors="white") + + axes[7].grid(True) + # axes[6].xaxis.set_major_formatter(Params.FORMATTER) + axes[7].tick_params(axis='x', colors="white") + + axes[8].grid(True) + # axes[7].xaxis.set_major_formatter(Params.FORMATTER) + + return fig, axes + + +def 
create_resp_figure(): + fig = plt.figure(figsize=(12, 6), dpi=100) + gs = GridSpec(2, 1, height_ratios=[3, 2]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(2): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + + axes[0].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[0].tick_params(axis='x', colors="white") + + axes[1].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[1].tick_params(axis='x', colors="white") + + return fig, axes + + +def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None, + event_codes: list[int] = None, multi_labels=None): + signal_fs = signal_data["fs"] + chn_signal = signal_data["data"] + time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs) + ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black', + label=signal_data["name"]) + if event_mask is not None: + if multi_labels is None and event_codes is not None: + for event_code in event_codes: + mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code + y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float) + np.place(y, y == 0, np.nan) + ax.plot(time_axis, y, color=utils.ColorCycle[event_code]) + elif multi_labels == "resp" and event_codes is not None: + ax.set_ylim(-6, 6) + # 建立第二个y轴坐标 + ax2 = ax.twinx() + ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, + color='blue', alpha=0.8, label='Low Amplitude Mask') + ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, + color='orange', alpha=0.8, label='Movement Mask') + ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3, + color='green', alpha=0.8, label='Amplitude Change Mask') + for event_code in 
event_codes: + sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code + score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs) + y = (sa_mask * score_mask).astype(float) + np.place(y, y == 0, np.nan) + ax2.plot(time_axis, y, color=utils.ColorCycle[event_code]) + ax2.set_ylim(-4, 5) + elif multi_labels == "bcg" and event_codes is not None: + # 建立第二个y轴坐标 + ax2 = ax.twinx() + ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, + color='blue', alpha=0.8, label='Low Amplitude Mask') + ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, + color='orange', alpha=0.8, label='Movement Mask') + ax2.plot(time_axis, event_mask["BCG_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3, + color='green', alpha=0.8, label='Amplitude Change Mask') + + ax2.set_ylim(-4, 4) + + ax.set_ylabel("Amplitude") + ax.legend(loc=1) + + +def plt_stage_on_ax(ax, stage_data, segment_start, segment_end): + stage_signal = stage_data["data"] + stage_fs = stage_data["fs"] + time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start) + ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"]) + ax.set_ylim(0, 6) + ax.set_yticks([1, 2, 3, 4, 5]) + ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"]) + ax.set_ylabel("Stage") + ax.legend(loc=1) + + +def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end): + spo2_signal = spo2_data["data"] + spo2_fs = spo2_data["fs"] + time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start) + ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"]) + + if spo2_signal[segment_start:segment_end].min() < 85: + ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100)) + else: + ax.set_ylim((85, 100)) 
+ ax.set_ylabel("SpO2 (%)") + ax.legend(loc=1) + + +def score_mask2alpha(score_mask): + alpha_mask = np.zeros_like(score_mask, dtype=float) + alpha_mask[score_mask <= 0] = 0 + alpha_mask[score_mask == 1] = 0.9 + alpha_mask[score_mask == 2] = 0.6 + alpha_mask[score_mask == 3] = 0.1 + return alpha_mask + + +def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): + for mask in event_mask.keys(): + if mask.startswith("Resp_") or mask.endswith("BCG_"): + event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) + + event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) + event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) + + fig, axes = create_psg_bcg_figure() + for segment_start, segment_end in segment_list: + print(f"Drawing segment: {segment_start} to {segment_end} seconds") + for ax in axes: + ax.cla() + + plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end) + plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4]) + 
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4]) + plt.show() + print(f"Finished drawing segment: {segment_start} to {segment_end} seconds") + + +def draw_resp_label(resp_data, resp_label, segment_list): + for mask in resp_label.keys(): + if mask.startswith("Resp_"): + resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0) + + resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"]) + resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0) + + fig, axes = create_resp_figure() + for segment_start, segment_end in segment_list: + for ax in axes: + ax.cla() + + plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end, + resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end, + resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4]) + plt.show() diff --git a/event_mask_process/HYS_process.py b/event_mask_process/HYS_process.py index 1054dba..dc33461 100644 --- a/event_mask_process/HYS_process.py +++ b/event_mask_process/HYS_process.py @@ -18,15 +18,19 @@ # 高幅值连续体动规则标定与剔除 # 手动标定不可用区间提剔除 """ - +import sys from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + import shutil import draw_tools import utils import numpy as np import signal_method import os -from matplotlib import pyplot as plt + os.environ['DISPLAY'] = "localhost:10.0" @@ -48,56 +52,14 @@ def process_one_signal(samp_id, show=False): save_samp_path = save_path / f"{samp_id}" save_samp_path.mkdir(parents=True, exist_ok=True) - signal_data_raw = utils.read_signal_txt(signal_path) - signal_length = len(signal_data_raw) - print(f"signal_length: {signal_length}") - signal_fs = 
int(signal_path.stem.split("_")[-1]) - print(f"signal_fs: {signal_fs}") - signal_second = signal_length // signal_fs - print(f"signal_second: {signal_second}") + signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True) - # 根据采样率进行截断 - signal_data_raw = signal_data_raw[:signal_second * signal_fs] - - # 滤波 - # 50Hz陷波滤波器 - # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - print("Applying 50Hz notch filter...") - signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) - - resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) - resp_fs = conf["resp"]["downsample_fs_1"] - resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) - resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) - resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], - low_cut=conf["resp_filter"]["low_cut"], - high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], - sample_rate=resp_fs) - print("Begin plotting signal data...") - - # fig = plt.figure(figsize=(12, 8)) - # # 绘制三个图raw_data、resp_data_1、resp_data_2 - # ax0 = fig.add_subplot(3, 1, 1) - # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') - # ax0.set_title('Raw Signal Data') - # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) - # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') - # ax1.set_title('Resp Data after Average Filtering') - # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) - # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') - # ax2.set_title('Resp Data after Butterworth Filtering') - # 
plt.tight_layout() - # plt.show() - - bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], - low_cut=conf["bcg_filter"]["low_cut"], - high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], - sample_rate=signal_fs) + signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs) # 降采样 old_resp_fs = resp_fs resp_fs = conf["resp"]["downsample_fs_2"] - resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) + resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs) bcg_fs = conf["bcg"]["downsample_fs"] bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) @@ -214,26 +176,26 @@ def process_one_signal(samp_id, show=False): target_fs=100) signal_fs = 100 - draw_tools.draw_signal_with_mask(samp_id=samp_id, - signal_data=signal_data, - signal_fs=signal_fs, - resp_data=resp_data, - resp_fs=resp_fs, - bcg_data=bcg_data, - bcg_fs=bcg_fs, - signal_disable_mask=manual_disable_mask, - resp_low_amp_mask=resp_low_amp_mask, - resp_movement_mask=resp_movement_mask, - resp_change_mask=resp_amp_change_mask, - resp_sa_mask=event_mask, - bcg_low_amp_mask=bcg_low_amp_mask, - bcg_movement_mask=bcg_movement_mask, - bcg_change_mask=bcg_amp_change_mask, - show=show, - save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=event_mask, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + 
bcg_change_mask=bcg_amp_change_mask, + show=show, + save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") # 复制事件文件 到保存路径 - sa_label_save_name = f"{samp_id}" + label_path.name + sa_label_save_name = f"{samp_id}_" + label_path.name shutil.copyfile(label_path, save_samp_path / sa_label_save_name) # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 @@ -247,10 +209,10 @@ def process_one_signal(samp_id, show=False): dtype=int), "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, + "BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, + "BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) } @@ -259,13 +221,13 @@ def process_one_signal(samp_id, show=False): if __name__ == '__main__': - yaml_path = Path("../dataset_config/HYS_config.yaml") - disable_df_path = Path("../排除区间.xlsx") + yaml_path = project_root_path / "dataset_config/HYS_config.yaml" + disable_df_path = project_root_path / "排除区间.xlsx" conf = utils.load_dataset_conf(yaml_path) select_ids = conf["select_ids"] root_path = Path(conf["root_path"]) - save_path = Path(conf["save_path"]) + save_path = Path(conf["mask_save_path"]) print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") @@ -276,9 +238,9 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[6], show=True) + 
# process_one_signal(select_ids[6], show=True) # - # for samp_id in select_ids: - # print(f"Processing sample ID: {samp_id}") - # process_one_signal(samp_id, show=False) - # print(f"Finished processing sample ID: {samp_id}\n\n") + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + process_one_signal(samp_id, show=False) + print(f"Finished processing sample ID: {samp_id}\n\n") diff --git a/event_mask_process/SHHS_process.py b/event_mask_process/SHHS_process.py deleted file mode 100644 index e69de29..0000000 diff --git a/event_mask_process/ZD5Y_process.py b/event_mask_process/ZD5Y_process.py deleted file mode 100644 index 8701d65..0000000 --- a/event_mask_process/ZD5Y_process.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -本脚本完成对呼研所数据的处理,包含以下功能: -1. 数据读取与预处理 - 从传入路径中,进行数据和标签的读取,并进行初步的预处理 - 预处理包括为数据进行滤波、去噪等操作 -2. 数据清洗与异常值处理 -3. 输出清晰后的统计信息 -4. 数据保存 - 将处理后的数据保存到指定路径,便于后续使用 - 主要是保存切分后的数据位置和标签 -5. 可视化 - 提供数据处理前后的可视化对比,帮助理解数据变化 - 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 - -todo: 使用mask 屏蔽无用区间 - - -# 低幅值区间规则标定与剔除 -# 高幅值连续体动规则标定与剔除 -# 手动标定不可用区间提剔除 -""" - -from pathlib import Path -import shutil -import draw_tools -import utils -import numpy as np -import signal_method -import os -from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:10.0" - -def process_one_signal(samp_id, show=False): - signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) - if not signal_path: - raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") - signal_path = signal_path[0] - print(f"Processing OrgBCG_Sync signal file: {signal_path}") - - label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv") - if not label_path: - raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") - label_path = list(label_path)[0] - print(f"Processing Label_corrected file: {label_path}") - - signal_data_raw = utils.read_signal_txt(signal_path) - signal_length = len(signal_data_raw) - 
print(f"signal_length: {signal_length}") - signal_fs = int(signal_path.stem.split("_")[-1]) - print(f"signal_fs: {signal_fs}") - signal_second = signal_length // signal_fs - print(f"signal_second: {signal_second}") - - # 根据采样率进行截断 - signal_data_raw = signal_data_raw[:signal_second * signal_fs] - - # 滤波 - # 50Hz陷波滤波器 - # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - print("Applying 50Hz notch filter...") - signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) - - resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) - resp_fs = conf["resp"]["downsample_fs_1"] - resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) - resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) - resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], - low_cut=conf["resp_filter"]["low_cut"], - high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], - sample_rate=resp_fs) - print("Begin plotting signal data...") - - - # fig = plt.figure(figsize=(12, 8)) - # # 绘制三个图raw_data、resp_data_1、resp_data_2 - # ax0 = fig.add_subplot(3, 1, 1) - # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') - # ax0.set_title('Raw Signal Data') - # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) - # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') - # ax1.set_title('Resp Data after Average Filtering') - # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) - # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') - # ax2.set_title('Resp Data after Butterworth Filtering') - # plt.tight_layout() - # plt.show() - - bcg_data = 
utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], - low_cut=conf["bcg_filter"]["low_cut"], - high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], - sample_rate=signal_fs) - - # 降采样 - old_resp_fs = resp_fs - resp_fs = conf["resp"]["downsample_fs_2"] - resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) - bcg_fs = conf["bcg"]["downsample_fs"] - bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) - - label_data = utils.read_label_csv(path=label_path) - event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) - - manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ - all_samp_disable_df["id"] == samp_id]) - print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") - - # 分析Resp的低幅值区间 - resp_low_amp_conf = conf.get("resp_low_amp", None) - if resp_low_amp_conf is not None: - resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( - signal_data=resp_data, - sampling_rate=resp_fs, - **resp_low_amp_conf - ) - print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") - else: - resp_low_amp_mask, resp_low_amp_position_list = None, None - print("resp_low_amp_mask is None") - - # 分析Resp的高幅值伪迹区间 - resp_movement_conf = conf.get("resp_movement", None) - if resp_movement_conf is not None: - raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( - signal_data=resp_data, - sampling_rate=resp_fs, - **resp_movement_conf - ) - print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, 
count_movement_positions: {len(resp_movement_position_list)}") - else: - resp_movement_mask, resp_movement_position_list = None, None - print("resp_movement_mask is None") - - resp_movement_revise_conf = conf.get("resp_movement_revise", None) - if resp_movement_mask is not None and resp_movement_revise_conf is not None: - resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( - signal_data=resp_data, - movement_mask=resp_movement_mask, - movement_list=resp_movement_position_list, - sampling_rate=resp_fs, - **resp_movement_revise_conf, - verbose=False - ) - print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") - else: - print("resp_movement_mask revise is skipped") - - - # 分析Resp的幅值突变区间 - resp_amp_change_conf = conf.get("resp_amp_change", None) - if resp_amp_change_conf is not None and resp_movement_mask is not None: - resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( - signal_data=resp_data, - movement_mask=resp_movement_mask, - movement_list=resp_movement_position_list, - sampling_rate=resp_fs, - **resp_amp_change_conf) - print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") - else: - resp_amp_change_mask = None - print("amp_change_mask is None") - - - - # 分析Bcg的低幅值区间 - bcg_low_amp_conf = conf.get("bcg_low_amp", None) - if bcg_low_amp_conf is not None: - bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( - signal_data=bcg_data, - sampling_rate=bcg_fs, - **bcg_low_amp_conf - ) - print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") - else: - bcg_low_amp_mask, bcg_low_amp_position_list = None, None - 
print("bcg_low_amp_mask is None") - # 分析Bcg的高幅值伪迹区间 - bcg_movement_conf = conf.get("bcg_movement", None) - if bcg_movement_conf is not None: - raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( - signal_data=bcg_data, - sampling_rate=bcg_fs, - **bcg_movement_conf - ) - print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") - else: - bcg_movement_mask = None - print("bcg_movement_mask is None") - # 分析Bcg的幅值突变区间 - if bcg_movement_mask is not None: - bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( - signal_data=bcg_data, - movement_mask=bcg_movement_mask, - sampling_rate=bcg_fs) - print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") - else: - bcg_amp_change_mask = None - print("bcg_amp_change_mask is None") - - - # 如果signal_data采样率过,进行降采样 - if signal_fs == 1000: - signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) - signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) - signal_fs = 100 - if show: - draw_tools.draw_signal_with_mask(samp_id=samp_id, - signal_data=signal_data, - signal_fs=signal_fs, - resp_data=resp_data, - resp_fs=resp_fs, - bcg_data=bcg_data, - bcg_fs=bcg_fs, - signal_disable_mask=manual_disable_mask, - resp_low_amp_mask=resp_low_amp_mask, - resp_movement_mask=resp_movement_mask, - resp_change_mask=resp_amp_change_mask, - resp_sa_mask=event_mask, - bcg_low_amp_mask=bcg_low_amp_mask, - bcg_movement_mask=bcg_movement_mask, - bcg_change_mask=bcg_amp_change_mask) - - - # 保存处理后的数据和标签 - save_samp_path = save_path / f"{samp_id}" - save_samp_path.mkdir(parents=True, exist_ok=True) 
- - # 复制事件文件 到保存路径 - sa_label_save_name = f"{samp_id}" + label_path.name - shutil.copyfile(label_path, save_samp_path / sa_label_save_name) - - # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 - save_dict = { - "Second": np.arange(signal_second), - "SA_Label": event_mask, - "SA_Score": score_mask, - "Disable_Label": manual_disable_mask, - "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), - "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) - } - - mask_label_save_name = f"{samp_id}_Processed_Labels.csv" - utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) - - - - - - - -if __name__ == '__main__': - yaml_path = Path("../dataset_config/ZD5Y_config.yaml") - disable_df_path = Path("../排除区间.xlsx") - - conf = utils.load_dataset_conf(yaml_path) - select_ids = conf["select_ids"] - root_path = Path(conf["root_path"]) - save_path = Path(conf["save_path"]) - - print(f"select_ids: {select_ids}") - print(f"root_path: {root_path}") - print(f"save_path: {save_path}") - - org_signal_root_path = root_path / "OrgBCG_Aligned" - label_root_path = root_path / "Label" - - all_samp_disable_df = utils.read_disable_excel(disable_df_path) - - process_one_signal(select_ids[1], show=True) - - # for samp_id in select_ids: - # print(f"Processing sample ID: {samp_id}") - # process_one_signal(samp_id, show=False) - # 
print(f"Finished processing sample ID: {samp_id}\n\n") \ No newline at end of file diff --git a/signal_method/__init__.py b/signal_method/__init__.py index eaea6ea..7ce8cdb 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1,4 +1,6 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3 from .rule_base_event import movement_revise -from .time_metrics import calc_mav_by_slide_windows \ No newline at end of file +from .time_metrics import calc_mav_by_slide_windows +from .signal_process import signal_filter_split, rpeak2hr +from .normalize_method import normalize_resp_signal \ No newline at end of file diff --git a/signal_method/normalize_method.py b/signal_method/normalize_method.py new file mode 100644 index 0000000..8ed89ce --- /dev/null +++ b/signal_method/normalize_method.py @@ -0,0 +1,36 @@ +import utils +import pandas as pd +import numpy as np +from scipy import signal + +def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list): + # 根据呼吸信号的幅值改变区间,对每段进行Z-Score标准化 + normalized_resp_signal = np.zeros_like(resp_signal) + # 全部填成nan + normalized_resp_signal[:] = np.nan + + resp_signal_no_movement = resp_signal.copy() + + + resp_signal_no_movement[np.array(movement_mask == 1).repeat(resp_fs)] = np.nan + + + for i in range(len(enable_list)): + enable_start = enable_list[i][0] * resp_fs + enable_end = enable_list[i][1] * resp_fs + segment = resp_signal_no_movement[enable_start:enable_end] + + # print(f"Normalizing segment {i+1}/{len(enable_list)}: start={enable_start}, end={enable_end}, length={len(segment)}") + + segment_mean = np.nanmean(segment) + segment_std = np.nanstd(segment) + if segment_std == 0: + raise ValueError(f"segment_std is zero! 
segment_start: {enable_start}, segment_end: {enable_end}") + + # 同下一个enable区间的体动一起进行标准化 + if i <= len(enable_list) - 2: + enable_end = enable_list[i + 1][0] * resp_fs + raw_segment = resp_signal[enable_start:enable_end] + normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std + + return normalized_resp_signal diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py new file mode 100644 index 0000000..eaaea59 --- /dev/null +++ b/signal_method/signal_process.py @@ -0,0 +1,62 @@ +import numpy as np + +import utils + +def signal_filter_split(conf, signal_data_raw, signal_fs): + # 滤波 + # 50Hz陷波滤波器 + # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) + print("Applying 50Hz notch filter...") + signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + print("Begin plotting signal data...") + + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), 
resp_data_1, color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + + return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs + + + +def rpeak2hr(rpeak_indices, signal_length): + hr_signal = np.zeros(signal_length) + for i in range(1, len(rpeak_indices)): + rri = rpeak_indices[i] - rpeak_indices[i - 1] + if rri == 0: + continue + hr = 60 * 1000 / rri # 心率,单位:bpm + if hr > 120: + hr = 120 + elif hr < 30: + hr = 30 + hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = hr + # 填充最后一个R峰之后的心率值 + if len(rpeak_indices) > 1: + hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]] + return hr_signal + diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index dd65bab..6f7d95a 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -1,9 +1,11 @@ from pathlib import Path from typing import Union +import utils +from .event_map import N2Chn import numpy as np import pandas as pd - +from .operation_tools import event_mask_2_list # 尝试导入 Polars try: import polars as pl @@ -13,15 +15,17 @@ except ImportError: HAS_POLARS = False -def read_signal_txt(path: Union[str, Path]) -> np.ndarray: +def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False): """ Read a txt file and return the first column as a numpy array. Args: - path (str | Path): Path to the txt file. - + :param path: + :param verbose: + :param dtype: Returns: np.ndarray: The first column of the txt file as a numpy array. 
+ """ path = Path(path) if not path.exists(): @@ -29,10 +33,30 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray: if HAS_POLARS: df = pl.read_csv(path, has_header=False, infer_schema_length=0) - return df[:, 0].to_numpy().astype(float) + signal_data_raw = df[:, 0].to_numpy().astype(dtype) else: - df = pd.read_csv(path, header=None, dtype=float) - return df.iloc[:, 0].to_numpy() + df = pd.read_csv(path, header=None, dtype=dtype) + signal_data_raw = df.iloc[:, 0].to_numpy() + + signal_original_length = len(signal_data_raw) + signal_fs = int(path.stem.split("_")[-1]) + if is_peak: + signal_second = None + signal_length = None + else: + signal_second = signal_original_length // signal_fs + # 根据采样率进行截断 + signal_data_raw = signal_data_raw[:signal_second * signal_fs] + signal_length = len(signal_data_raw) + + if verbose: + print(f"Signal file read from {path}") + print(f"signal_fs: {signal_fs}") + print(f"signal_original_length: {signal_original_length}") + print(f"signal_after_cut_off_length: {signal_length}") + print(f"signal_second: {signal_second}") + + return signal_data_raw, signal_length, signal_fs, signal_second def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: @@ -172,3 +196,99 @@ def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame: df["start"] = df["start"].astype(int) df["end"] = df["end"].astype(int) return df + + +def read_mask_execl(path: Union[str, Path]): + """ + Read an Excel file and return the mask as a numpy array. + Args: + path (str | Path): Path to the Excel file. + Returns: + np.ndarray: The mask as a numpy array. 
+ """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + df = pd.read_csv(path) + event_mask = df.to_dict(orient="list") + for key in event_mask: + event_mask[key] = np.array(event_mask[key]) + + event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]), + "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]), + "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),} + + + return event_mask, event_list + + + +def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]): + """ + 读取PSG文件中特定通道的数据。 + + 参数: + path_str (Union[str, Path]): 存放PSG文件的文件夹路径。 + channel_name (str): 需要读取的通道名称。 + 返回: + np.ndarray: 指定通道的数据数组。 + """ + path = Path(path_str) + if not path.exists(): + raise FileNotFoundError(f"PSG Dir not found: {path}") + + if not path.is_dir(): + raise NotADirectoryError(f"PSG Dir not found: {path}") + channel_data = {} + # 遍历检查通道对应的文件是否存在 + for ch_id in channel_number: + ch_name = N2Chn[ch_id] + ch_path = list(path.glob(f"{ch_name}*.txt")) + + if not any(ch_path): + raise FileNotFoundError(f"PSG Channel file not found: {ch_path}") + + if len(ch_path) > 1: + print(f"Warning!!! 
PSG Channel file more than one: {ch_path}") + + if ch_id == 8: + # sleep stage 特例 读取为整数 + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True) + # 转换为整数数组 + for stage_str, stage_number in utils.Stage2N.items(): + np.place(ch_signal, ch_signal == stage_str, stage_number) + ch_signal = ch_signal.astype(int) + elif ch_id == 1: + # Rpeak 特例 读取为整数 + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True) + else: + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True) + channel_data[ch_name] = { + "name": ch_name, + "path": ch_path[0], + "data": ch_signal, + "length": ch_length, + "fs": ch_fs, + "second": ch_second + } + + return channel_data + + +def read_psg_label(path: Union[str, Path], verbose=True): + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # 直接用pandas读取 包含中文 故指定编码 + df = pd.read_csv(path, encoding="gbk") + if verbose: + print(f"Label file read from {path}, number of rows: {len(df)}") + + # 丢掉Event type为空的行 + df = df.dropna(subset=["Event type"], how='all').reset_index(drop=True) + + return df + + diff --git a/utils/__init__.py b/utils/__init__.py index 68e7772..362297e 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,7 +1,10 @@ -from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel +from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list from .operation_tools import merge_short_gaps, remove_short_durations from .operation_tools import collect_values from .operation_tools import save_process_label -from .event_map import E2N -from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file +from .operation_tools import 
none_to_nan_mask +from .split_method import resp_split +from .HYS_FileReader import read_mask_execl, read_psg_channel +from .event_map import E2N, N2Chn, Stage2N, ColorCycle +from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/event_map.py b/utils/event_map.py index c85a027..20c6c58 100644 --- a/utils/event_map.py +++ b/utils/event_map.py @@ -4,4 +4,39 @@ E2N = { "Central apnea": 2, "Obstructive apnea": 3, "Mixed apnea": 4 -} \ No newline at end of file +} + +N2Chn = { + 1: "Rpeak", + 2: "ECG_Sync", + 3: "Effort Tho", + 4: "Effort Abd", + 5: "Flow P", + 6: "Flow T", + 7: "SpO2", + 8: "5_class" +} + +Stage2N = { + "W": 5, + "N1": 3, + "N2": 2, + "N3": 1, + "R": 4, +} + +# 设定事件和其对应颜色 +# event_code color event +# 0 黑色 背景 +# 1 粉色 低通气 +# 2 蓝色 中枢性 +# 3 红色 阻塞型 +# 4 灰色 混合型 +# 5 绿色 血氧饱和度下降 +# 6 橙色 大体动 +# 7 橙色 小体动 +# 8 橙色 深呼吸 +# 9 橙色 脉冲体动 +# 10 橙色 无效片段 +ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange", + "orange"] \ No newline at end of file diff --git a/utils/signal_process.py b/utils/filter_func.py similarity index 100% rename from utils/signal_process.py rename to utils/filter_func.py diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 095dd5c..866029d 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -198,16 +198,25 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: return disable_mask -def generate_event_mask(signal_second: int, event_df): +def generate_event_mask(signal_second: int, event_df, use_correct=True): event_mask = np.zeros(signal_second, dtype=int) score_mask = np.zeros(signal_second, dtype=int) + if use_correct: + start_name = "correct_Start" + end_name = "correct_End" + event_type_name = "correct_EventsType" + else: + start_name = "Start" + end_name = "End" + event_type_name = "Event type" + # 剔除start = -1 的行 - event_df = 
event_df[event_df["correct_Start"] >= 0] + event_df = event_df[event_df[start_name] >= 0] for _, row in event_df.iterrows(): - start = row["correct_Start"] - end = row["correct_End"] + 1 - event_mask[start:end] = E2N[row["correct_EventsType"]] + start = row[start_name] + end = row[end_name] + 1 + event_mask[start:end] = E2N[row[event_type_name]] score_mask[start:end] = row["score"] return event_mask, score_mask @@ -243,3 +252,12 @@ def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None def save_process_label(save_path: Path, save_dict: dict): save_df = pd.DataFrame(save_dict) save_df.to_csv(save_path, index=False) + +def none_to_nan_mask(mask, ref): + """将None转换为与ref形状相同的nan掩码""" + if mask is None: + return np.full_like(ref, np.nan) + else: + # 将mask中的0替换为nan,其他的保持 + mask = np.where(mask == 0, np.nan, mask) + return mask \ No newline at end of file diff --git a/utils/split_method.py b/utils/split_method.py new file mode 100644 index 0000000..e9c151b --- /dev/null +++ b/utils/split_method.py @@ -0,0 +1,27 @@ + + + +def resp_split(dataset_config, event_mask, event_list): + # 提取体动区间和呼吸低幅值区间 + enable_list = event_list["EnableSegment"] + + # 读取数据集配置 + window_sec = dataset_config["window_sec"] + stride_sec = dataset_config["stride_sec"] + + segment_list = [] + + # 遍历每个enable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以enable_end为结尾截取一个窗口 + for enable_start, enable_end in enable_list: + current_start = enable_start + while current_start + window_sec <= enable_end: + segment_list.append((current_start, current_start + window_sec)) + current_start += stride_sec + # 检查最后一个窗口是否需要添加 + if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec): + segment_list.append((enable_end - window_sec, enable_end)) + + return segment_list + + + From 19d476d489f7624b485dff5db7b74f3fc9145a69 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 17 Nov 2025 08:05:42 +0800 Subject: [PATCH 24/28] =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=BB=98=E5=9B=BE?= 
=?UTF-8?q?=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=A2=9E=E5=8A=A0=E5=8F=8Cy?= =?UTF-8?q?=E8=BD=B4=E6=94=AF=E6=8C=81=E5=B9=B6=E8=B0=83=E6=95=B4=E5=9B=BE?= =?UTF-8?q?=E5=83=8F=E4=BF=9D=E5=AD=98=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataset_builder/HYS_dataset.py | 17 ++++--- dataset_config/HYS_config.yaml | 1 + draw_tools/draw_label.py | 82 +++++++++++++++++++--------------- 3 files changed, 59 insertions(+), 41 deletions(-) diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index 50cca9a..4442936 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -130,7 +130,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): psg_label=psg_event_mask, bcg_data=bcg_data, event_mask=event_mask, - segment_list=segment_list) + segment_list=segment_list, + save_path=visual_path / f"{samp_id}") if __name__ == '__main__': @@ -141,8 +142,11 @@ if __name__ == '__main__': root_path = Path(conf["root_path"]) mask_path = Path(conf["mask_save_path"]) save_path = Path(conf["dataset_config"]["dataset_save_path"]) + visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) dataset_config = conf["dataset_config"] + visual_path.mkdir(parents=True, exist_ok=True) + save_processed_signal_path = save_path / "Signals" save_processed_signal_path.mkdir(parents=True, exist_ok=True) @@ -155,12 +159,13 @@ if __name__ == '__main__': print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") print(f"save_path: {save_path}") + print(f"visual_path: {visual_path}") org_signal_root_path = root_path / "OrgBCG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned" - build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) - # - # for samp_id in select_ids: - # print(f"Processing sample ID: {samp_id}") - # build_HYS_dataset_segment(samp_id, show=False) \ No newline at end of file + # build_HYS_dataset_segment(select_ids[0], show=False, 
draw_segment=True) + + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index a15e03d..80b8165 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -83,3 +83,4 @@ dataset_config: window_sec: 180 stride_sec: 60 dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset/visualization diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py index 605de63..134202e 100644 --- a/draw_tools/draw_label.py +++ b/draw_tools/draw_label.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import matplotlib.patches as mpatches import seaborn as sns import numpy as np - +from tqdm.rich import tqdm import utils # 添加with_prediction参数 @@ -19,7 +19,9 @@ psg_chn_name2ax = { "HR": 5, "resp": 6, "bcg": 7, - "Stage": 8 + "Stage": 8, + "resp_twinx": 9, + "bcg_twinx": 10, } resp_chn_name2ax = { @@ -29,7 +31,7 @@ resp_chn_name2ax = { def create_psg_bcg_figure(): - fig = plt.figure(figsize=(12, 8), dpi=100) + fig = plt.figure(figsize=(12, 8), dpi=200) gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1]) fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) axes = [] @@ -37,39 +39,41 @@ def create_psg_bcg_figure(): ax = fig.add_subplot(gs[i]) axes.append(ax) - axes[0].grid(True) + axes[psg_chn_name2ax["SpO2"]].grid(True) # axes[0].xaxis.set_major_formatter(Params.FORMATTER) - axes[0].set_ylim((85, 100)) - axes[0].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100)) + axes[psg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white") - axes[1].grid(True) + axes[psg_chn_name2ax["Flow T"]].grid(True) # axes[1].xaxis.set_major_formatter(Params.FORMATTER) - axes[1].tick_params(axis='x', colors="white") + 
axes[psg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white") - axes[2].grid(True) + axes[psg_chn_name2ax["Flow P"]].grid(True) # axes[2].xaxis.set_major_formatter(Params.FORMATTER) - axes[2].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white") - axes[3].grid(True) + axes[psg_chn_name2ax["Effort Tho"]].grid(True) # axes[3].xaxis.set_major_formatter(Params.FORMATTER) - axes[3].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white") - axes[4].grid(True) + axes[psg_chn_name2ax["Effort Abd"]].grid(True) # axes[4].xaxis.set_major_formatter(Params.FORMATTER) - axes[4].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white") - axes[5].grid(True) - axes[5].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["HR"]].grid(True) + axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white") - axes[6].grid(True) + axes[psg_chn_name2ax["resp"]].grid(True) + axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_chn_name2ax["resp"]].twinx()) + + axes[psg_chn_name2ax["bcg"]].grid(True) # axes[5].xaxis.set_major_formatter(Params.FORMATTER) - axes[6].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_chn_name2ax["bcg"]].twinx()) - axes[7].grid(True) - # axes[6].xaxis.set_major_formatter(Params.FORMATTER) - axes[7].tick_params(axis='x', colors="white") - axes[8].grid(True) + axes[psg_chn_name2ax["Stage"]].grid(True) # axes[7].xaxis.set_major_formatter(Params.FORMATTER) return fig, axes @@ -96,7 +100,7 @@ def create_resp_figure(): def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None, - event_codes: list[int] = None, multi_labels=None): + event_codes: list[int] = None, multi_labels=None, ax2: Axes = None): signal_fs = signal_data["fs"] 
chn_signal = signal_data["data"] time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs) @@ -112,7 +116,7 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev elif multi_labels == "resp" and event_codes is not None: ax.set_ylim(-6, 6) # 建立第二个y轴坐标 - ax2 = ax.twinx() + ax2.cla() ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, color='blue', alpha=0.8, label='Low Amplitude Mask') ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, @@ -122,13 +126,15 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev for event_code in event_codes: sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs) - y = (sa_mask * score_mask).astype(float) + # y = (sa_mask * score_mask).astype(float) + y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float) np.place(y, y == 0, np.nan) - ax2.plot(time_axis, y, color=utils.ColorCycle[event_code]) + ax.plot(time_axis, y, color=utils.ColorCycle[event_code]) + ax2.plot(time_axis, score_mask, color="orange") ax2.set_ylim(-4, 5) elif multi_labels == "bcg" and event_codes is not None: # 建立第二个y轴坐标 - ax2 = ax.twinx() + ax2.cla() ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1, color='blue', alpha=0.8, label='Low Amplitude Mask') ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2, @@ -177,17 +183,19 @@ def score_mask2alpha(score_mask): return alpha_mask -def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): +def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None): + if save_path is not None: + save_path.mkdir(parents=True, 
exist_ok=True) + for mask in event_mask.keys(): - if mask.startswith("Resp_") or mask.endswith("BCG_"): + if mask.startswith("Resp_") or mask.startswith("BCG_"): event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) fig, axes = create_psg_bcg_figure() - for segment_start, segment_end in segment_list: - print(f"Drawing segment: {segment_start} to {segment_end} seconds") + for segment_start, segment_end in tqdm(segment_list): for ax in axes: ax.cla() @@ -203,13 +211,17 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list): psg_label, event_codes=[1, 2, 3, 4]) plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, - event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4]) + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]]) plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, - event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4]) - plt.show() - print(f"Finished drawing segment: {segment_start} to {segment_end} seconds") + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]]) + + if save_path is not None: + fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png") + tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + def draw_resp_label(resp_data, resp_label, segment_list): for mask in resp_label.keys(): if mask.startswith("Resp_"): From d829f3e43d6ebc7d6e6aaa3ed215dedee23bb63a Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 17 Nov 2025 
09:40:43 +0800 Subject: [PATCH 25/28] =?UTF-8?q?=E4=BF=AE=E6=AD=A3SA=5FScore=E7=9A=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84Alpha=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE=E6=8E=A9=E7=A0=81?= =?UTF-8?q?=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- draw_tools/draw_label.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py index 134202e..6b0dcb9 100644 --- a/draw_tools/draw_label.py +++ b/draw_tools/draw_label.py @@ -125,7 +125,7 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev color='green', alpha=0.8, label='Amplitude Change Mask') for event_code in event_codes: sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code - score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs) + score_mask = event_mask["SA_Score"][segment_start:segment_end].repeat(signal_fs) # y = (sa_mask * score_mask).astype(float) y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * sa_mask).astype(float) np.place(y, y == 0, np.nan) @@ -191,8 +191,10 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, if mask.startswith("Resp_") or mask.startswith("BCG_"): event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) - event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) - event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) + event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0) + + # event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) + # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) fig, axes = create_psg_bcg_figure() for segment_start, segment_end in 
tqdm(segment_list): @@ -227,8 +229,8 @@ def draw_resp_label(resp_data, resp_label, segment_list): if mask.startswith("Resp_"): resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0) - resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"]) - resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0) + # resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"]) + # resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0) fig, axes = create_resp_figure() for segment_start, segment_end in segment_list: From d09ffecf7031f08d03bd7bab987336b55188046e Mon Sep 17 00:00:00 2001 From: marques Date: Tue, 30 Dec 2025 16:54:45 +0800 Subject: [PATCH 26/28] =?UTF-8?q?=E4=BF=AE=E6=AD=A3SA=5FScore=E7=9A=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84Alpha=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE=E6=8E=A9=E7=A0=81?= =?UTF-8?q?=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 5 +- dataset_builder/HYS_dataset.py | 126 ++++++++++++++++++++++++------ dataset_config/HYS_config.yaml | 4 +- draw_tools/draw_label.py | 20 +++-- event_mask_process/HYS_process.py | 2 +- signal_method/signal_process.py | 8 +- utils/HYS_FileReader.py | 12 +-- utils/split_method.py | 41 ++++++++-- 8 files changed, 170 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index a83c153..24e92a3 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ # DataPrepare - +## 操作步骤 +1. 信号预处理 +2. 数据集构建 +3. 
数据可视化(可选) diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index 4442936..20024f5 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -1,8 +1,8 @@ +import multiprocessing import sys from pathlib import Path import os - import numpy as np os.environ['DISPLAY'] = "localhost:10.0" @@ -16,21 +16,23 @@ import draw_tools import shutil -def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): +def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) if not signal_path: raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") signal_path = signal_path[0] - print(f"Processing OrgBCG_Sync signal file: {signal_path}") + if verbose: + print(f"Processing OrgBCG_Sync signal file: {signal_path}") mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") - print(f"mask_excel_path: {mask_excel_path}") + if verbose: + print(f"mask_excel_path: {mask_excel_path}") event_mask, event_list = utils.read_mask_execl(mask_excel_path) - bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float) + bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose) - bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs) + bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose) normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"]) @@ -63,8 +65,9 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): # show=show, # save_path=None) - segment_list = 
utils.resp_split(dataset_config, event_mask, event_list) - print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") + segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) + if verbose: + print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") # 复制mask到processed_Labels文件夹 @@ -77,7 +80,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) else: - print(f"Warning: {sa_label_corrected_path} does not exist.") + if verbose: + print(f"Warning: {sa_label_corrected_path} does not exist.") # 保存处理后的信号和截取的片段列表 save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" @@ -109,12 +113,14 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): np.savez_compressed(save_signal_path, **bcg_data) np.savez_compressed(save_segment_path, - segment_list=segment_list) - print(f"Saved processed signals to: {save_signal_path}") - print(f"Saved segment list to: {save_segment_path}") + segment_list=segment_list, + disable_segment_list=disable_segment_list) + if verbose: + print(f"Saved processed signals to: {save_signal_path}") + print(f"Saved segment list to: {save_segment_path}") if draw_segment: - psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8]) + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) psg_data["HR"] = { "name": "HR", "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]), @@ -124,14 +130,86 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False): } - psg_label = utils.read_psg_label(sa_label_corrected_path) + psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose) 
psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False) + total_len = len(segment_list) + len(disable_segment_list) + if verbose: + print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}") + draw_tools.draw_psg_bcg_label(psg_data=psg_data, psg_label=psg_event_mask, bcg_data=bcg_data, event_mask=event_mask, segment_list=segment_list, - save_path=visual_path / f"{samp_id}") + save_path=visual_path / f"{samp_id}" / "enable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + draw_tools.draw_psg_bcg_label( + psg_data=psg_data, + psg_label=psg_event_mask, + bcg_data=bcg_data, + event_mask=event_mask, + segment_list=disable_segment_list, + save_path=visual_path / f"{samp_id}" / "disable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + + +def multiprocess_entry(_progress, task_id, _id): + build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) + + +def multiprocess_with_tqdm(args_list, n_processes): + from concurrent.futures import ProcessPoolExecutor + from rich import progress + + with progress.Progress( + "[progress.description]{task.description}", + progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + progress.MofNCompleteColumn(), + progress.TimeRemainingColumn(), + progress.TimeElapsedColumn(), + refresh_per_second=1, # bit slower updates + transient=False + ) as progress: + futures = [] + with multiprocessing.Manager() as manager: + _progress = manager.dict() + overall_progress_task = progress.add_task("[green]All jobs progress:") + with ProcessPoolExecutor(max_workers=n_processes) as executor: + for i_args in range(len(args_list)): + args = args_list[i_args] + task_id = progress.add_task(f"task {i_args}", visible=True) + futures.append(executor.submit(multiprocess_entry, _progress, task_id, 
args_list[i_args])) + # monitor the progress: + while (n_finished := sum([future.done() for future in futures])) < len( + futures + ): + progress.update( + overall_progress_task, completed=n_finished, total=len(futures) + ) + for task_id, update_data in _progress.items(): + desc = update_data.get("desc", "") + # update the progress bar for this task: + progress.update( + task_id, + completed=update_data.get("progress", 0), + total=update_data.get("total", 0), + description=desc + ) + + # raise any errors: + for future in futures: + future.result() + + if __name__ == '__main__': @@ -156,16 +234,18 @@ if __name__ == '__main__': save_processed_label_path = save_path / "Labels" save_processed_label_path.mkdir(parents=True, exist_ok=True) - print(f"select_ids: {select_ids}") - print(f"root_path: {root_path}") - print(f"save_path: {save_path}") - print(f"visual_path: {visual_path}") + # print(f"select_ids: {select_ids}") + # print(f"root_path: {root_path}") + # print(f"save_path: {save_path}") + # print(f"visual_path: {visual_path}") org_signal_root_path = root_path / "OrgBCG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned" - # build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) + build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) - for samp_id in select_ids: - print(f"Processing sample ID: {samp_id}") - build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) \ No newline at end of file + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) + + # multiprocess_with_tqdm(args_list=select_ids, n_processes=16) \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 80b8165..43544ea 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -19,8 +19,8 @@ resp: resp_filter: filter_type: bandpass - low_cut: 0.01 - high_cut: 0.7 + low_cut: 0.05 + 
high_cut: 0.6 order: 3 resp_low_amp: diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py index 6b0dcb9..b93cc6c 100644 --- a/draw_tools/draw_label.py +++ b/draw_tools/draw_label.py @@ -72,7 +72,6 @@ def create_psg_bcg_figure(): axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white") axes.append(axes[psg_chn_name2ax["bcg"]].twinx()) - axes[psg_chn_name2ax["Stage"]].grid(True) # axes[7].xaxis.set_major_formatter(Params.FORMATTER) @@ -183,7 +182,8 @@ def score_mask2alpha(score_mask): return alpha_mask -def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None): +def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True, + multi_p=None, multi_task_id=None): if save_path is not None: save_path.mkdir(parents=True, exist_ok=True) @@ -191,13 +191,13 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, if mask.startswith("Resp_") or mask.startswith("BCG_"): event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0) - event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0) + event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0) # event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"]) # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) fig, axes = create_psg_bcg_figure() - for segment_start, segment_end in tqdm(segment_list): + for i, (segment_start, segment_end) in enumerate(segment_list): for ax in axes: ax.cla() @@ -213,17 +213,21 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, psg_label, event_codes=[1, 2, 3, 4]) plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, - event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], 
ax2=axes[psg_chn_name2ax["resp_twinx"]]) + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], + ax2=axes[psg_chn_name2ax["resp_twinx"]]) plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, - event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]]) - + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], + ax2=axes[psg_chn_name2ax["bcg_twinx"]]) if save_path is not None: fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png") - tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}") + if multi_p is not None: + multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"} + + def draw_resp_label(resp_data, resp_label, segment_list): for mask in resp_label.keys(): if mask.startswith("Resp_"): diff --git a/event_mask_process/HYS_process.py b/event_mask_process/HYS_process.py index dc33461..45fa76d 100644 --- a/event_mask_process/HYS_process.py +++ b/event_mask_process/HYS_process.py @@ -52,7 +52,7 @@ def process_one_signal(samp_id, show=False): save_samp_path = save_path / f"{samp_id}" save_samp_path.mkdir(parents=True, exist_ok=True) - signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True) + signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True) signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs) diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py index eaaea59..c0c6699 100644 --- a/signal_method/signal_process.py +++ b/signal_method/signal_process.py @@ -2,11 +2,12 @@ import numpy as np import utils -def 
signal_filter_split(conf, signal_data_raw, signal_fs): +def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True): # 滤波 # 50Hz陷波滤波器 # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - print("Applying 50Hz notch filter...") + if verbose: + print("Applying 50Hz notch filter...") signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) @@ -17,7 +18,8 @@ def signal_filter_split(conf, signal_data_raw, signal_fs): low_cut=conf["resp_filter"]["low_cut"], high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], sample_rate=resp_fs) - print("Begin plotting signal data...") + if verbose: + print("Begin plotting signal data...") # fig = plt.figure(figsize=(12, 8)) # # 绘制三个图raw_data、resp_data_1、resp_data_2 diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index 6f7d95a..f41d16a 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -20,6 +20,7 @@ def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False): Read a txt file and return the first column as a numpy array. 
Args: + :param is_peak: :param path: :param verbose: :param dtype: @@ -217,14 +218,15 @@ def read_mask_execl(path: Union[str, Path]): event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]), "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]), - "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),} + "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]), + "DisableSegment": event_mask_2_list(event_mask["Disable_Label"])} return event_mask, event_list -def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]): +def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True): """ 读取PSG文件中特定通道的数据。 @@ -254,16 +256,16 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]): if ch_id == 8: # sleep stage 特例 读取为整数 - ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True) + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose) # 转换为整数数组 for stage_str, stage_number in utils.Stage2N.items(): np.place(ch_signal, ch_signal == stage_str, stage_number) ch_signal = ch_signal.astype(int) elif ch_id == 1: # Rpeak 特例 读取为整数 - ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True) + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True) else: - ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True) + ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose) channel_data[ch_name] = { "name": ch_name, "path": ch_path[0], diff --git a/utils/split_method.py b/utils/split_method.py index e9c151b..f113013 100644 --- a/utils/split_method.py +++ b/utils/split_method.py @@ -1,27 +1,58 @@ +def check_split(event_mask, current_start, window_sec, verbose=False): + # 
检查当前窗口是否包含在禁用区间或低幅值区间内 + resp_movement_mask = event_mask["Resp_Movement_Label"][current_start : current_start + window_sec] + resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start : current_start + window_sec] + + # 体动与低幅值进行与计算 + low_move_mask = resp_movement_mask | resp_low_amp_mask + if low_move_mask.sum() > 2/3 * window_sec: + if verbose: + print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3") + return False + return True + - -def resp_split(dataset_config, event_mask, event_list): +def resp_split(dataset_config, event_mask, event_list, verbose=False): # 提取体动区间和呼吸低幅值区间 enable_list = event_list["EnableSegment"] + disable_list = event_list["DisableSegment"] # 读取数据集配置 window_sec = dataset_config["window_sec"] stride_sec = dataset_config["stride_sec"] segment_list = [] + disable_segment_list = [] # 遍历每个enable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以enable_end为结尾截取一个窗口 for enable_start, enable_end in enable_list: current_start = enable_start while current_start + window_sec <= enable_end: - segment_list.append((current_start, current_start + window_sec)) + if check_split(event_mask, current_start, window_sec, verbose): + segment_list.append((current_start, current_start + window_sec)) + else: + disable_segment_list.append((current_start, current_start + window_sec)) current_start += stride_sec # 检查最后一个窗口是否需要添加 if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec): - segment_list.append((enable_end - window_sec, enable_end)) + if check_split(event_mask, enable_end - window_sec, window_sec, verbose): + segment_list.append((enable_end - window_sec, enable_end)) + else: + disable_segment_list.append((enable_end - window_sec, enable_end)) - return segment_list + # 遍历每个disable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以disable_end为结尾截取一个窗口 + for disable_start, disable_end in disable_list: + current_start = disable_start + while current_start + window_sec <= disable_end: + 
disable_segment_list.append((current_start, current_start + window_sec)) + current_start += stride_sec + # 检查最后一个窗口是否需要添加 + if (disable_end - current_start >= stride_sec / 2) and (disable_end - current_start >= window_sec): + disable_segment_list.append((disable_end - window_sec, disable_end)) + + + return segment_list, disable_segment_list From 097c9cbf0ba9ffcf29b5887e42f2b82251a818d0 Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 19 Jan 2026 14:27:26 +0800 Subject: [PATCH 27/28] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=A4=84=E7=90=86=E6=A8=A1=E5=9D=97=EF=BC=8C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?PSG=E4=BF=A1=E5=8F=B7=E7=BB=98=E5=9B=BE=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E9=87=8D=E6=9E=84=E9=83=A8=E5=88=86=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E4=BB=A5=E6=8F=90=E9=AB=98=E5=8F=AF=E8=AF=BB=E6=80=A7=E5=92=8C?= =?UTF-8?q?=E5=8F=AF=E7=BB=B4=E6=8A=A4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataset_builder/HYS_PSG_dataset.py | 395 ++++++++++++++++++++++++++ dataset_builder/HYS_dataset.py | 53 ++-- dataset_config/HYS_PSG_config.yaml | 129 +++++++++ draw_tools/__init__.py | 4 +- draw_tools/draw_label.py | 173 ++++++++--- draw_tools/draw_statics.py | 67 +++++ event_mask_process/HYS_PSG_process.py | 183 ++++++++++++ signal_method/__init__.py | 4 +- signal_method/normalize_method.py | 18 +- signal_method/signal_process.py | 47 ++- utils/HYS_FileReader.py | 58 ++++ utils/__init__.py | 6 +- utils/filter_func.py | 47 +++ utils/operation_tools.py | 125 +++++++- utils/split_method.py | 2 - 15 files changed, 1228 insertions(+), 83 deletions(-) create mode 100644 dataset_builder/HYS_PSG_dataset.py create mode 100644 dataset_config/HYS_PSG_config.yaml create mode 100644 event_mask_process/HYS_PSG_process.py diff --git a/dataset_builder/HYS_PSG_dataset.py b/dataset_builder/HYS_PSG_dataset.py new file mode 100644 index 0000000..e061e25 --- /dev/null +++ b/dataset_builder/HYS_PSG_dataset.py @@ -0,0 +1,395 
@@ +import multiprocessing +import sys +from pathlib import Path + +import os +import numpy as np + +from utils import N2Chn + +os.environ['DISPLAY'] = "localhost:10.0" + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import utils +import signal_method +import draw_tools +import shutil +import gc + +def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None): + psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) + + total_seconds = min( + psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak" + ) + for i in N2Chn.values(): + if i == "Rpeak": + continue + length = int(total_seconds * psg_data[i]["fs"]) + psg_data[i]["data"] = psg_data[i]["data"][:length] + psg_data[i]["length"] = length + psg_data[i]["second"] = total_seconds + + psg_data["HR"] = { + "name": "HR", + "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], + psg_data["Rpeak"]["fs"]), + "fs": psg_data["ECG_Sync"]["fs"], + "length": psg_data["ECG_Sync"]["length"], + "second": psg_data["ECG_Sync"]["second"] + } + # 预处理与滤波 + tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"]) + abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"]) + flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"]) + flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"]) + + rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], 
ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100) + + + mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv") + if verbose: + print(f"mask_excel_path: {mask_excel_path}") + + event_mask, event_list = utils.read_mask_execl(mask_excel_path) + + enable_list = [[0, psg_data["Effort Tho"]["second"]]] + normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list) + normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list) + normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list) + normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list) + + + # 都调整至100Hz采样率 + target_fs = 100 + normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs) + normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs) + normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs) + normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs) + spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs) + normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2 + rri = utils.adjust_sample_rate(rri, rri_fs, target_fs) + + # 调整至相同长度 + min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal), len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal) + ,len(rri)) + min_length = min_length - min_length % target_fs # 保证是整数秒 + normalized_tho_signal = normalized_tho_signal[:min_length] + normalized_abd_signal = normalized_abd_signal[:min_length] 
+ normalized_flowp_signal = normalized_flowp_signal[:min_length] + normalized_flowt_signal = normalized_flowt_signal[:min_length] + spo2_data_filt = spo2_data_filt[:min_length] + normalized_effort_signal = normalized_effort_signal[:min_length] + rri = rri[:min_length] + + tho_second = min_length / target_fs + for i in event_mask.keys(): + event_mask[i] = event_mask[i][:int(tho_second)] + + spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt, + spo2_fs=target_fs, + max_fill_duration=30, + min_gap_duration=10,) + + + draw_tools.draw_psg_signal( + samp_id=samp_id, + tho_signal=normalized_tho_signal, + abd_signal=normalized_abd_signal, + flowp_signal=normalized_flowp_signal, + flowt_signal=normalized_flowt_signal, + spo2_signal=spo2_data_filt, + effort_signal=normalized_effort_signal, + rri_signal = rri, + fs=target_fs, + event_mask=event_mask["SA_Label"], + save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png", + show=show + ) + + draw_tools.draw_psg_signal( + samp_id=samp_id, + tho_signal=normalized_tho_signal, + abd_signal=normalized_abd_signal, + flowp_signal=normalized_flowp_signal, + flowt_signal=normalized_flowt_signal, + spo2_signal=spo2_data_filt_fill, + effort_signal=normalized_effort_signal, + rri_signal = rri, + fs=target_fs, + event_mask=event_mask["SA_Label"], + save_path= mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png", + show=show + ) + + spo2_disable_mask = spo2_disable_mask[::target_fs] + min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask)) + + if len(event_mask["Disable_Label"]) != len(spo2_disable_mask): + print(f"Warning: Data length mismatch! 
Truncating to {min_len}.") + event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len] + + event_list = { + "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]), + "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])} + + spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95) + + segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose) + if verbose: + print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}") + + + # 复制mask到processed_Labels文件夹 + save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv" + shutil.copyfile(mask_excel_path, save_mask_excel_path) + + # 复制SA Label_corrected.csv到processed_Labels文件夹 + sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv") + if sa_label_corrected_path.exists(): + save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv" + shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path) + else: + if verbose: + print(f"Warning: {sa_label_corrected_path} does not exist.") + + # 保存处理后的信号和截取的片段列表 + save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz" + save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz" + + # psg_data更新为处理后的信号 + # 用下划线替换键里面的空格 + psg_data = { + "Effort Tho": { + "name": "Effort_Tho", + "data": normalized_tho_signal, + "fs": target_fs, + "length": len(normalized_tho_signal), + "second": len(normalized_tho_signal) / target_fs + }, + "Effort Abd": { + "name": "Effort_Abd", + "data": normalized_abd_signal, + "fs": target_fs, + "length": len(normalized_abd_signal), + "second": len(normalized_abd_signal) / target_fs + }, + "Effort": { + "name": "Effort", + "data": normalized_effort_signal, + "fs": target_fs, + "length": len(normalized_effort_signal), + "second": len(normalized_effort_signal) / target_fs + 
}, + "Flow P": { + "name": "Flow_P", + "data": normalized_flowp_signal, + "fs": target_fs, + "length": len(normalized_flowp_signal), + "second": len(normalized_flowp_signal) / target_fs + }, + "Flow T": { + "name": "Flow_T", + "data": normalized_flowt_signal, + "fs": target_fs, + "length": len(normalized_flowt_signal), + "second": len(normalized_flowt_signal) / target_fs + }, + "SpO2": { + "name": "SpO2", + "data": spo2_data_filt_fill, + "fs": target_fs, + "length": len(spo2_data_filt_fill), + "second": len(spo2_data_filt_fill) / target_fs + }, + "HR": { + "name": "HR", + "data": psg_data["HR"]["data"], + "fs": psg_data["HR"]["fs"], + "length": psg_data["HR"]["length"], + "second": psg_data["HR"]["second"] + }, + "RRI": { + "name": "RRI", + "data": rri, + "fs": target_fs, + "length": len(rri), + "second": len(rri) / target_fs + }, + "5_class": { + "name": "Stage", + "data": psg_data["5_class"]["data"], + "fs": psg_data["5_class"]["fs"], + "length": psg_data["5_class"]["length"], + "second": psg_data["5_class"]["second"] + } + } + + np.savez_compressed(save_signal_path, **psg_data) + np.savez_compressed(save_segment_path, + segment_list=segment_list, + disable_segment_list=disable_segment_list) + if verbose: + print(f"Saved processed signals to: {save_signal_path}") + print(f"Saved segment list to: {save_segment_path}") + + if draw_segment: + total_len = len(segment_list) + len(disable_segment_list) + if verbose: + print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}") + + + draw_tools.draw_psg_label( + psg_data=psg_data, + psg_label=event_mask["SA_Label"], + segment_list=segment_list, + save_path=visual_path / f"{samp_id}" / "enable", + verbose=verbose, + multi_p=multi_p, + multi_task_id=multi_task_id + ) + + draw_tools.draw_psg_label( + psg_data=psg_data, + psg_label=event_mask["SA_Label"], + segment_list=disable_segment_list, + save_path=visual_path / f"{samp_id}" / "disable", + verbose=verbose, + multi_p=multi_p, + 
multi_task_id=multi_task_id + ) + + # 显式删除大型对象 + try: + del psg_data + del normalized_tho_signal, normalized_abd_signal + del normalized_flowp_signal, normalized_flowt_signal + del normalized_effort_signal + del spo2_data_filt, spo2_data_filt_fill + del rri + del event_mask, event_list + del segment_list, disable_segment_list + except: + pass + + # 强制垃圾回收 + gc.collect() + + + +def multiprocess_entry(_progress, task_id, _id): + build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id) + + +def multiprocess_with_tqdm(args_list, n_processes): + from concurrent.futures import ProcessPoolExecutor + from rich import progress + + with progress.Progress( + "[progress.description]{task.description}", + progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + progress.MofNCompleteColumn(), + progress.TimeRemainingColumn(), + progress.TimeElapsedColumn(), + refresh_per_second=1, # bit slower updates + transient=False + ) as progress: + futures = [] + with multiprocessing.Manager() as manager: + _progress = manager.dict() + overall_progress_task = progress.add_task("[green]All jobs progress:") + with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor: + for i_args in range(len(args_list)): + args = args_list[i_args] + task_id = progress.add_task(f"task {i_args}", visible=True) + futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args])) + # monitor the progress: + while (n_finished := sum([future.done() for future in futures])) < len( + futures + ): + progress.update( + overall_progress_task, completed=n_finished, total=len(futures) + ) + for task_id, update_data in _progress.items(): + desc = update_data.get("desc", "") + # update the progress bar for this task: + progress.update( + task_id, + completed=update_data.get("progress", 0), + total=update_data.get("total", 0), + description=desc + ) + + # raise 
any errors: + for future in futures: + future.result() + + +def multiprocess_with_pool(args_list, n_processes): + """使用Pool,每个worker处理固定数量任务后重启""" + from multiprocessing import Pool + + # maxtasksperchild 设置每个worker处理多少任务后重启(释放内存) + with Pool(processes=n_processes, maxtasksperchild=2) as pool: + results = [] + for samp_id in args_list: + result = pool.apply_async( + build_HYS_dataset_segment, + args=(samp_id, False, True, False, None, None) + ) + results.append(result) + + # 等待所有任务完成 + for i, result in enumerate(results): + try: + result.get() + print(f"Completed {i + 1}/{len(args_list)}") + except Exception as e: + print(f"Error processing task {i}: {e}") + + pool.close() + pool.join() + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml" + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + mask_path = Path(conf["mask_save_path"]) + save_path = Path(conf["dataset_config"]["dataset_save_path"]) + visual_path = Path(conf["dataset_config"]["dataset_visual_path"]) + dataset_config = conf["dataset_config"] + + visual_path.mkdir(parents=True, exist_ok=True) + + save_processed_signal_path = save_path / "Signals" + save_processed_signal_path.mkdir(parents=True, exist_ok=True) + + save_segment_list_path = save_path / "Segments_List" + save_segment_list_path.mkdir(parents=True, exist_ok=True) + + save_processed_label_path = save_path / "Labels" + save_processed_label_path.mkdir(parents=True, exist_ok=True) + + # print(f"select_ids: {select_ids}") + # print(f"root_path: {root_path}") + # print(f"save_path: {save_path}") + # print(f"visual_path: {visual_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + psg_signal_root_path = root_path / "PSG_Aligned" + print(select_ids) + + # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # 
build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) + + # multiprocess_with_tqdm(args_list=select_ids, n_processes=8) + multiprocess_with_pool(args_list=select_ids, n_processes=8) \ No newline at end of file diff --git a/dataset_builder/HYS_dataset.py b/dataset_builder/HYS_dataset.py index 20024f5..f944539 100644 --- a/dataset_builder/HYS_dataset.py +++ b/dataset_builder/HYS_dataset.py @@ -33,7 +33,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose) bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose) - normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"]) + normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"]) # 如果signal_data采样率过,进行降采样 @@ -123,7 +123,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose) psg_data["HR"] = { "name": "HR", - "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]), + "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]), "fs": psg_data["ECG_Sync"]["fs"], "length": psg_data["ECG_Sync"]["length"], "second": psg_data["ECG_Sync"]["second"] @@ -136,28 +136,28 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T if verbose: print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}") - draw_tools.draw_psg_bcg_label(psg_data=psg_data, - psg_label=psg_event_mask, - 
bcg_data=bcg_data, - event_mask=event_mask, - segment_list=segment_list, - save_path=visual_path / f"{samp_id}" / "enable", - verbose=verbose, - multi_p=multi_p, - multi_task_id=multi_task_id - ) - - draw_tools.draw_psg_bcg_label( - psg_data=psg_data, - psg_label=psg_event_mask, - bcg_data=bcg_data, - event_mask=event_mask, - segment_list=disable_segment_list, - save_path=visual_path / f"{samp_id}" / "disable", - verbose=verbose, - multi_p=multi_p, - multi_task_id=multi_task_id - ) + # draw_tools.draw_psg_bcg_label(psg_data=psg_data, + # psg_label=psg_event_mask, + # bcg_data=bcg_data, + # event_mask=event_mask, + # segment_list=segment_list, + # save_path=visual_path / f"{samp_id}" / "enable", + # verbose=verbose, + # multi_p=multi_p, + # multi_task_id=multi_task_id + # ) + # + # draw_tools.draw_psg_bcg_label( + # psg_data=psg_data, + # psg_label=psg_event_mask, + # bcg_data=bcg_data, + # event_mask=event_mask, + # segment_list=disable_segment_list, + # save_path=visual_path / f"{samp_id}" / "disable", + # verbose=verbose, + # multi_p=multi_p, + # multi_task_id=multi_task_id + # ) @@ -241,11 +241,12 @@ if __name__ == '__main__': org_signal_root_path = root_path / "OrgBCG_Aligned" psg_signal_root_path = root_path / "PSG_Aligned" + print(select_ids) - build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True) + # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True) # for samp_id in select_ids: # print(f"Processing sample ID: {samp_id}") # build_HYS_dataset_segment(samp_id, show=False, draw_segment=True) - # multiprocess_with_tqdm(args_list=select_ids, n_processes=16) \ No newline at end of file + multiprocess_with_tqdm(args_list=select_ids, n_processes=16) \ No newline at end of file diff --git a/dataset_config/HYS_PSG_config.yaml b/dataset_config/HYS_PSG_config.yaml new file mode 100644 index 0000000..63233df --- /dev/null +++ b/dataset_config/HYS_PSG_config.yaml @@ -0,0 +1,129 @@ +select_ids: + - 54 + - 88 + - 220 + - 221 + - 229 + 
- 282 + - 286 + - 541 + - 579 + - 582 + - 670 + - 671 + - 683 + - 684 + - 735 + - 933 + - 935 + - 950 + - 952 + - 960 + - 962 + - 967 + - 1302 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG + +dataset_config: + window_sec: 180 + stride_sec: 60 + dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset + dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization + + +effort: + downsample_fs: 10 + +effort_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +flow: + downsample_fs: 10 + +flow_filter: + filter_type: bandpass + low_cut: 0.05 + high_cut: 0.5 + order: 3 + +#ecg: +# downsample_fs: 100 +# +#ecg_filter: +# filter_type: bandpass +# low_cut: 0.5 +# high_cut: 40 +# order: 5 + + +#resp: +# downsample_fs_1: None +# downsample_fs_2: 10 +# +#resp_filter: +# filter_type: bandpass +# low_cut: 0.05 +# high_cut: 0.5 +# order: 3 +# +#resp_low_amp: +# window_size_sec: 30 +# stride_sec: +# amplitude_threshold: 3 +# merge_gap_sec: 60 +# min_duration_sec: 60 +# +#resp_movement: +# window_size_sec: 20 +# stride_sec: 1 +# std_median_multiplier: 4 +# compare_intervals_sec: +# - 60 +# - 120 +## - 180 +# interval_multiplier: 3 +# merge_gap_sec: 30 +# min_duration_sec: 1 +# +#resp_movement_revise: +# up_interval_multiplier: 3 +# down_interval_multiplier: 2 +# compare_intervals_sec: 30 +# merge_gap_sec: 10 +# min_duration_sec: 1 +# +#resp_amp_change: +# mav_calc_window_sec: 4 +# threshold_amplitude: 0.25 +# threshold_energy: 0.4 +# +# +#bcg: +# downsample_fs: 100 +# +#bcg_filter: +# filter_type: bandpass +# low_cut: 1 +# high_cut: 10 +# order: 10 +# +#bcg_low_amp: +# window_size_sec: 1 +# stride_sec: +# amplitude_threshold: 8 +# merge_gap_sec: 20 +# min_duration_sec: 3 +# +# +#bcg_movement: +# window_size_sec: 2 +# stride_sec: +# merge_gap_sec: 20 +# min_duration_sec: 4 + + diff --git a/draw_tools/__init__.py 
b/draw_tools/__init__.py index 3386b90..ab3b0ae 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -1,2 +1,2 @@ -from .draw_statics import draw_signal_with_mask -from .draw_label import draw_psg_bcg_label, draw_resp_label \ No newline at end of file +from .draw_statics import draw_signal_with_mask, draw_psg_signal +from .draw_label import draw_psg_bcg_label,draw_psg_label \ No newline at end of file diff --git a/draw_tools/draw_label.py b/draw_tools/draw_label.py index b93cc6c..29daa5d 100644 --- a/draw_tools/draw_label.py +++ b/draw_tools/draw_label.py @@ -7,10 +7,10 @@ import seaborn as sns import numpy as np from tqdm.rich import tqdm import utils - +import gc # 添加with_prediction参数 -psg_chn_name2ax = { +psg_bcg_chn_name2ax = { "SpO2": 0, "Flow T": 1, "Flow P": 2, @@ -24,6 +24,19 @@ psg_chn_name2ax = { "bcg_twinx": 10, } +psg_chn_name2ax = { + "SpO2": 0, + "Flow T": 1, + "Flow P": 2, + "Effort Tho": 3, + "Effort Abd": 4, + "Effort": 5, + "HR": 6, + "RRI": 7, + "Stage": 8, +} + + resp_chn_name2ax = { "resp": 0, "bcg": 1, @@ -39,6 +52,54 @@ def create_psg_bcg_figure(): ax = fig.add_subplot(gs[i]) axes.append(ax) + axes[psg_bcg_chn_name2ax["SpO2"]].grid(True) + # axes[0].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100)) + axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Flow T"]].grid(True) + # axes[1].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Flow P"]].grid(True) + # axes[2].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True) + # axes[3].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True) + # 
axes[4].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["HR"]].grid(True) + axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white") + + axes[psg_bcg_chn_name2ax["resp"]].grid(True) + axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx()) + + axes[psg_bcg_chn_name2ax["bcg"]].grid(True) + # axes[5].xaxis.set_major_formatter(Params.FORMATTER) + axes[psg_bcg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white") + axes.append(axes[psg_bcg_chn_name2ax["bcg"]].twinx()) + + axes[psg_bcg_chn_name2ax["Stage"]].grid(True) + # axes[7].xaxis.set_major_formatter(Params.FORMATTER) + + return fig, axes + + +def create_psg_figure(): + fig = plt.figure(figsize=(12, 8), dpi=200) + gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1]) + fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0) + axes = [] + for i in range(9): + ax = fig.add_subplot(gs[i]) + axes.append(ax) + axes[psg_chn_name2ax["SpO2"]].grid(True) # axes[0].xaxis.set_major_formatter(Params.FORMATTER) axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100)) @@ -60,24 +121,21 @@ def create_psg_bcg_figure(): # axes[4].xaxis.set_major_formatter(Params.FORMATTER) axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["Effort"]].grid(True) + axes[psg_chn_name2ax["Effort"]].tick_params(axis='x', colors="white") + axes[psg_chn_name2ax["HR"]].grid(True) axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white") - axes[psg_chn_name2ax["resp"]].grid(True) - axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white") - axes.append(axes[psg_chn_name2ax["resp"]].twinx()) + axes[psg_chn_name2ax["RRI"]].grid(True) + axes[psg_chn_name2ax["RRI"]].tick_params(axis='x', colors="white") - axes[psg_chn_name2ax["bcg"]].grid(True) - # 
axes[5].xaxis.set_major_formatter(Params.FORMATTER) - axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white") - axes.append(axes[psg_chn_name2ax["bcg"]].twinx()) axes[psg_chn_name2ax["Stage"]].grid(True) - # axes[7].xaxis.set_major_formatter(Params.FORMATTER) + return fig, axes - def create_resp_figure(): fig = plt.figure(figsize=(12, 6), dpi=100) gs = GridSpec(2, 1, height_ratios=[3, 2]) @@ -150,8 +208,8 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev def plt_stage_on_ax(ax, stage_data, segment_start, segment_end): stage_signal = stage_data["data"] stage_fs = stage_data["fs"] - time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start) - ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"]) + time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * stage_fs) + ax.plot(time_axis, stage_signal[segment_start * stage_fs:segment_end * stage_fs], color='black', label=stage_data["name"]) ax.set_ylim(0, 6) ax.set_yticks([1, 2, 3, 4, 5]) ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"]) @@ -162,11 +220,11 @@ def plt_stage_on_ax(ax, stage_data, segment_start, segment_end): def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end): spo2_signal = spo2_data["data"] spo2_fs = spo2_data["fs"] - time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start) - ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"]) + time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * spo2_fs) + ax.plot(time_axis, spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs], color='black', label=spo2_data["name"]) - if spo2_signal[segment_start:segment_end].min() < 85: - ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100)) + if spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() < 
85: + ax.set_ylim((spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() - 5, 100)) else: ax.set_ylim((85, 100)) ax.set_ylabel("SpO2 (%)") @@ -197,6 +255,56 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0) fig, axes = create_psg_bcg_figure() + for i, (segment_start, segment_end) in enumerate(segment_list): + for ax in axes: + ax.cla() + + plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end) + plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, + event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], + ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]]) + plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, + event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], + ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]]) + + + if save_path is not None: + fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png") + # print(f"Saved figure to: {save_path / 
f'Segment_{segment_start}_{segment_end}.png'}") + + if multi_p is not None: + multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"} + + plt.close(fig) + plt.close('all') + gc.collect() + + +def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True, + multi_p=None, multi_task_id=None): + if save_path is not None: + save_path.mkdir(parents=True, exist_ok=True) + + + if multi_p is None: + # 遍历psg_data中所有数据的长度 + for i in range(len(psg_data.keys())): + chn_name = list(psg_data.keys())[i] + print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}") + # psg_label的长度 + print(f"psg_label length: {len(psg_label)}") + + fig, axes = create_psg_figure() for i, (segment_start, segment_end) in enumerate(segment_list): for ax in axes: ax.cla() @@ -211,14 +319,10 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, psg_label, event_codes=[1, 2, 3, 4]) plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end, psg_label, event_codes=[1, 2, 3, 4]) + plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end, + psg_label, event_codes=[1, 2, 3, 4]) plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end) - plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end, - event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], - ax2=axes[psg_chn_name2ax["resp_twinx"]]) - plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end, - event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], - ax2=axes[psg_chn_name2ax["bcg_twinx"]]) - + plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end) if save_path is not None: fig.savefig(save_path / 
f"Segment_{segment_start}_{segment_end}.png") @@ -226,23 +330,8 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, if multi_p is not None: multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"} + plt.close(fig) + plt.close('all') + gc.collect() -def draw_resp_label(resp_data, resp_label, segment_list): - for mask in resp_label.keys(): - if mask.startswith("Resp_"): - resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0) - - # resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"]) - # resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0) - - fig, axes = create_resp_figure() - for segment_start, segment_end in segment_list: - for ax in axes: - ax.cla() - - plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end, - resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4]) - plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end, - resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4]) - plt.show() diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index f94679d..06e39aa 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -247,6 +247,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right') ax1.set_title(f'Sample {samp_id} - Respiration Component') + + ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green') ax2.set_ylabel('Amplitude') @@ -300,5 +302,70 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, plt.show() +def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, 
spo2_signal, effort_signal, rri_signal, event_mask, fs, + show=False, save_path=None): + sa_mask = event_mask.repeat(fs) + fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True) + time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal)) + axs[0].plot(time_axis, tho_signal, label='THO', color='black') + axs[0].set_title(f'Sample {samp_id} - PSG Signal Data') + axs[0].set_ylabel('THO Amplitude') + axs[0].legend(loc='upper right') + + ax0_twin = axs[0].twinx() + ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax0_twin.autoscale(enable=False, axis='y', tight=True) + ax0_twin.set_ylim((-4, 5)) + + axs[1].plot(time_axis, abd_signal, label='ABD', color='black') + axs[1].set_ylabel('ABD Amplitude') + axs[1].legend(loc='upper right') + + ax1_twin = axs[1].twinx() + ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax1_twin.autoscale(enable=False, axis='y', tight=True) + ax1_twin.set_ylim((-4, 5)) + + axs[2].plot(time_axis, effort_signal, label='EFFO', color='black') + axs[2].set_ylabel('EFFO Amplitude') + axs[2].legend(loc='upper right') + + ax2_twin = axs[2].twinx() + ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax2_twin.autoscale(enable=False, axis='y', tight=True) + ax2_twin.set_ylim((-4, 5)) + + axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black') + axs[3].set_ylabel('FLOWP Amplitude') + axs[3].legend(loc='upper right') + + ax3_twin = axs[3].twinx() + ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax3_twin.autoscale(enable=False, axis='y', tight=True) + ax3_twin.set_ylim((-4, 5)) + + axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black') + axs[4].set_ylabel('FLOWT Amplitude') + axs[4].legend(loc='upper right') + + ax4_twin = axs[4].twinx() + ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask') + ax4_twin.autoscale(enable=False, axis='y', tight=True) + ax4_twin.set_ylim((-4, 5)) 
+ + axs[5].plot(time_axis, rri_signal, label='RRI', color='black') + axs[5].set_ylabel('RRI Amplitude') + axs[5].legend(loc='upper right') + axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black') + axs[6].set_ylabel('SPO2 Amplitude') + axs[6].set_xlabel('Time (s)') + axs[6].legend(loc='upper right') + + + if save_path is not None: + plt.savefig(save_path, dpi=300) + if show: + plt.show() + diff --git a/event_mask_process/HYS_PSG_process.py b/event_mask_process/HYS_PSG_process.py new file mode 100644 index 0000000..bd04f63 --- /dev/null +++ b/event_mask_process/HYS_PSG_process.py @@ -0,0 +1,183 @@ +""" +本脚本完成对呼研所数据的处理,包含以下功能: +1. 数据读取与预处理 + 从传入路径中,进行数据和标签的读取,并进行初步的预处理 + 预处理包括为数据进行滤波、去噪等操作 +2. 数据清洗与异常值处理 +3. 输出清晰后的统计信息 +4. 数据保存 + 将处理后的数据保存到指定路径,便于后续使用 + 主要是保存切分后的数据位置和标签 +5. 可视化 + 提供数据处理前后的可视化对比,帮助理解数据变化 + 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 + + + +# 低幅值区间规则标定与剔除 +# 高幅值连续体动规则标定与剔除 +# 手动标定不可用区间提剔除 +""" +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) +project_root_path = Path(__file__).resolve().parent.parent + +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os + + +os.environ['DISPLAY'] = "localhost:10.0" + + +def process_one_signal(samp_id, show=False): + pass + + tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt")) + abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt")) + flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt")) + flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt")) + spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt")) + stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt")) + + if not tho_signal_path: + raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}") + tho_signal_path = tho_signal_path[0] + 
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}") + if not abd_signal_path: + raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}") + abd_signal_path = abd_signal_path[0] + print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}") + if not flowp_signal_path: + raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}") + flowp_signal_path = flowp_signal_path[0] + print(f"Processing Flow P_Sync signal file: {flowp_signal_path}") + if not flowt_signal_path: + raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}") + flowt_signal_path = flowt_signal_path[0] + print(f"Processing Flow T_Sync signal file: {flowt_signal_path}") + if not spo2_signal_path: + raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}") + spo2_signal_path = spo2_signal_path[0] + print(f"Processing SpO2_Sync signal file: {spo2_signal_path}") + if not stage_signal_path: + raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}") + stage_signal_path = stage_signal_path[0] + print(f"Processing 5_class_Sync signal file: {stage_signal_path}") + + + + label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv") + if not label_path: + raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") + label_path = list(label_path)[0] + print(f"Processing Label_corrected file: {label_path}") + # + # # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + save_samp_path.mkdir(parents=True, exist_ok=True) + + # # # 读取信号数据 + tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True) + # abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True) + # flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True) + # flowt_data_raw, flowt_length, flowt_fs, flowt_second = 
utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True) + # spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True) + stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True) + + + # + # # 预处理与滤波 + # tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs) + # abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs) + # flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs) + # flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs) + + # 降采样 + # old_tho_fs = tho_fs + # tho_fs = conf["effort"]["downsample_fs"] + # tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs) + # old_abd_fs = abd_fs + # abd_fs = conf["effort"]["downsample_fs"] + # abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs) + # old_flowp_fs = flowp_fs + # flowp_fs = conf["effort"]["downsample_fs"] + # flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs) + # old_flowt_fs = flowt_fs + # flowt_fs = conf["effort"]["downsample_fs"] + # flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs) + + # spo2不降采样 + # spo2_data_filt = spo2_data_raw + # spo2_fs = spo2_fs + + label_data = utils.read_raw_psg_label(path=label_path) + event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False) + # event_mask > 0 的部分为1,其他为0 + score_mask = 
np.where(event_mask > 0, 1, 0) + + # 根据睡眠分期生成不可用区间 + wake_mask = utils.get_wake_mask(stage_data_raw) + # 剔除短于60秒的觉醒区间 + wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60) + # 合并短于120秒的觉醒区间 + wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60) + + disable_label = wake_mask + + disable_label = disable_label[:tho_second] + + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}_" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + # + # 新建一个dataframe,分别是秒数、SA标签, + save_dict = { + "Second": np.arange(tho_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": disable_label, + "Resp_LowAmp_Label": np.zeros_like(event_mask), + "Resp_Movement_Label": np.zeros_like(event_mask), + "Resp_AmpChange_Label": np.zeros_like(event_mask), + "BCG_LowAmp_Label": np.zeros_like(event_mask), + "BCG_Movement_Label": np.zeros_like(event_mask), + "BCG_AmpChange_Label": np.zeros_like(event_mask) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + +if __name__ == '__main__': + yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml" + # disable_df_path = project_root_path / "排除区间.xlsx" + # + conf = utils.load_dataset_conf(yaml_path) + + root_path = Path(conf["root_path"]) + save_path = Path(conf["mask_save_path"]) + select_ids = conf["select_ids"] + # + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + # + org_signal_root_path = root_path / "PSG_Aligned" + label_root_path = root_path / "PSG_Aligned" + # + # all_samp_disable_df = utils.read_disable_excel(disable_df_path) + # + # process_one_signal(select_ids[0], show=True) + # # + for samp_id in select_ids: + print(f"Processing sample ID: {samp_id}") + 
process_one_signal(samp_id, show=False) + print(f"Finished processing sample ID: {samp_id}\n\n") + pass \ No newline at end of file diff --git a/signal_method/__init__.py b/signal_method/__init__.py index 7ce8cdb..eb9e1e2 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -2,5 +2,5 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3 from .rule_base_event import movement_revise from .time_metrics import calc_mav_by_slide_windows -from .signal_process import signal_filter_split, rpeak2hr -from .normalize_method import normalize_resp_signal \ No newline at end of file +from .signal_process import signal_filter_split, rpeak2hr, psg_effort_filter, rpeak2rri_interpolation +from .normalize_method import normalize_resp_signal_by_segment diff --git a/signal_method/normalize_method.py b/signal_method/normalize_method.py index 8ed89ce..095f16e 100644 --- a/signal_method/normalize_method.py +++ b/signal_method/normalize_method.py @@ -3,7 +3,7 @@ import pandas as pd import numpy as np from scipy import signal -def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list): +def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list): # 根据呼吸信号的幅值改变区间,对每段进行Z-Score标准化 normalized_resp_signal = np.zeros_like(resp_signal) # 全部填成nan @@ -33,4 +33,20 @@ def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enabl raw_segment = resp_signal[enable_start:enable_end] normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std + + #如果enable区间不从0开始,则将前面的部分也进行标准化 + if enable_list[0][0] > 0: + new_enable_start = 0 + enable_start = enable_list[0][0] * resp_fs + enable_end = enable_list[0][1] * resp_fs + segment = resp_signal_no_movement[enable_start:enable_end] + + segment_mean = np.nanmean(segment) + segment_std = np.nanstd(segment) 
+ if segment_std == 0: + raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}") + + raw_segment = resp_signal[new_enable_start:enable_start] + normalized_resp_signal[new_enable_start:enable_start] = (raw_segment - segment_mean) / segment_std + return normalized_resp_signal diff --git a/signal_method/signal_process.py b/signal_method/signal_process.py index c0c6699..a303ca4 100644 --- a/signal_method/signal_process.py +++ b/signal_method/signal_process.py @@ -1,4 +1,5 @@ import numpy as np +from scipy.interpolate import interp1d import utils @@ -44,14 +45,24 @@ def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True): return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs +def psg_effort_filter(conf, effort_data_raw, effort_fs): + # 滤波 + effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"], + low_cut=conf["effort_filter"]["low_cut"], + high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"], + sample_rate=effort_fs) + # 移动平均 + effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20) + return effort_data_raw, effort_data_2, effort_fs -def rpeak2hr(rpeak_indices, signal_length): + +def rpeak2hr(rpeak_indices, signal_length, ecg_fs): hr_signal = np.zeros(signal_length) for i in range(1, len(rpeak_indices)): rri = rpeak_indices[i] - rpeak_indices[i - 1] if rri == 0: continue - hr = 60 * 1000 / rri # 心率,单位:bpm + hr = 60 * ecg_fs / rri # 心率,单位:bpm if hr > 120: hr = 120 elif hr < 30: @@ -62,3 +73,35 @@ def rpeak2hr(rpeak_indices, signal_length): hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]] return hr_signal +def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs): + rri_signal = np.zeros(signal_length) + for i in range(1, len(rpeak_indices)): + rri = rpeak_indices[i] - rpeak_indices[i - 1] + rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri + # 填充最后一个R峰之后的RRI值 + if 
len(rpeak_indices) > 1: + rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]] + + # 遍历异常值 + for i in range(1, len(rpeak_indices)): + rri = rpeak_indices[i] - rpeak_indices[i - 1] + if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs: + rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0 + + return rri_signal + +def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs): + r_peak_time = np.asarray(rpeak_indices) / ecg_fs + rri = np.diff(r_peak_time) + t_rri = r_peak_time[1:] + + mask = (rri > 0.3) & (rri < 2.0) + rri_clean = rri[mask] + t_rri_clean = t_rri[mask] + + t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1/rri_fs) + f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate") + rri_uniform = f(t_uniform) + + return rri_uniform, rri_fs + diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index f41d16a..d2fb03f 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -178,6 +178,55 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: return df +def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame: + """ + Read a CSV file and return it as a pandas DataFrame. + + Args: + path (str | Path): Path to the CSV file. + verbose (bool): + Returns: + pd.DataFrame: The content of the CSV file as a pandas DataFrame. 
+ :param path: + :param verbose: + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # 直接用pandas读取 包含中文 故指定编码 + df = pd.read_csv(path, encoding="gbk") + if verbose: + print(f"Label file read from {path}, number of rows: {len(df)}") + + num_psg_events = np.sum(df["Event type"].notna()) + # 统计事件 + num_psg_hyp = np.sum(df["Event type"] == "Hypopnea") + num_psg_csa = np.sum(df["Event type"] == "Central apnea") + num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea") + num_psg_msa = np.sum(df["Event type"] == "Mixed apnea") + + + + if verbose: + print("Event Statistics:") + # 格式化输出 总计/来自PSG/手动/删除/未标注 指定宽度 + print(f"Type {'Total':^8s}") + print( + f"Hyp: {num_psg_hyp:^8d} ") + print( + f"CSA: {num_psg_csa:^8d} ") + print( + f"OSA: {num_psg_osa:^8d} ") + print( + f"MSA: {num_psg_msa:^8d} ") + print( + f"All: {num_psg_events:^8d}") + + df["Start"] = df["Start"].astype(int) + df["End"] = df["End"].astype(int) + return df + def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame: """ Read an Excel file and return it as a pandas DataFrame. 
def read_psg_mask_excel(path: Union[str, Path]):
    """Load a column-wise mask table from a CSV file.

    Despite the name, the file is parsed with ``pd.read_csv``.

    Args:
        path: Path to the CSV file.

    Returns:
        dict: Maps each column name to a NumPy array holding that column's
        values, in file order.
    """
    columns = pd.read_csv(path).to_dict(orient="list")
    return {name: np.array(values) for name, values in columns.items()}
def upsample_signal(original_signal, original_fs, target_fs):
    """Upsample a signal to a higher sampling rate with ``scipy.signal.resample``.

    Args:
        original_signal: Input samples (array-like).
        original_fs: Sampling rate of the input in Hz.
        target_fs: Desired sampling rate in Hz; must exceed ``original_fs``.

    Returns:
        np.ndarray: The resampled signal with
        ``floor(len(original_signal) * target_fs / original_fs)`` samples.

    Raises:
        ValueError: If either rate is not positive, or if ``target_fs`` is not
            greater than ``original_fs``.
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    # Check positivity first so a non-positive rate is reported as such
    # instead of tripping the ordering check.
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")
    if target_fs <= original_fs:
        raise ValueError("目标采样率必须大于原始采样率")

    # Multiply before dividing and add a tiny epsilon so a ratio that is an
    # integer mathematically but lands just below it in floating point
    # (e.g. 0.3 / 0.1) is not truncated to one sample too few.
    num_output_samples = int(len(original_signal) * target_fs / original_fs + 1e-9)

    return signal.resample(original_signal, num_output_samples)


def adjust_sample_rate(signal_data, original_fs, target_fs):
    """Bring a signal to ``target_fs``, upsampling or downsampling as needed.

    Args:
        signal_data: Input samples (array-like).
        original_fs: Sampling rate of the input in Hz.
        target_fs: Desired sampling rate in Hz.

    Returns:
        The input object itself when the rates are equal; otherwise the result
        of ``downsample_signal_fast`` (rate reduced) or ``upsample_signal``
        (rate increased).
    """
    if original_fs == target_fs:
        return signal_data
    if original_fs > target_fs:
        return downsample_signal_fast(signal_data, original_fs, target_fs)
    return upsample_signal(signal_data, original_fs, target_fs)
def get_wake_mask(sleep_stage_mask):
    """Turn sleep-stage labels into a binary wake mask.

    Args:
        sleep_stage_mask: Array-like of stage labels such as 'W', 'N1', 'N2',
            'N3', 'R'.

    Returns:
        np.ndarray: 0 where the label is one of 'N1', 'N2', 'N3', 'REM', 'R'
        (asleep) and 1 for every other label.
    """
    return np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)


def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
    """Flag SpO2 samples that are out of range, follow a jump, or are NaN.

    Args:
        spo2: SpO2 samples (array-like).
        fs: Sampling rate in Hz. Not used by the detection itself.
        diff_thresh: Largest allowed absolute difference between consecutive
            samples.

    Returns:
        np.ndarray: Boolean array, True where the sample is anomalous. A jump
        is attributed to the sample *after* it, so the first sample following
        an artefact is flagged as well.
    """
    # Work in float: on unsigned integer input np.diff wraps around and turns
    # small decreases into huge false "jumps".
    spo2 = np.asarray(spo2, dtype=float)
    if spo2.size == 0:
        # spo2[0] below would raise IndexError on empty input.
        return np.zeros(0, dtype=bool)

    # Outside the accepted physiological range.
    anomaly = (spo2 < 50) | (spo2 > 100)

    # Abrupt change relative to the previous sample.
    diff = np.abs(np.diff(spo2, prepend=spo2[0]))
    anomaly |= diff > diff_thresh

    # Missing values.
    anomaly |= np.isnan(spo2)

    return anomaly


def merge_close_anomalies(anomaly, fs, min_gap_duration):
    """Join anomalous runs that are separated by a short clean gap.

    Args:
        anomaly: Boolean NumPy array, True where a sample is anomalous.
        fs: Sampling rate in Hz.
        min_gap_duration: Clean gaps shorter than this many seconds that lie
            between two anomalous runs are marked anomalous too.

    Returns:
        np.ndarray: A copy of ``anomaly`` with the short gaps filled.
    """
    min_gap = int(min_gap_duration * fs)
    merged = anomaly.copy()
    n = len(anomaly)

    i = 0
    while i < n:
        if not anomaly[i]:
            i += 1
            continue

        # Skip to the end of the current anomalous run.
        while i < n and anomaly[i]:
            i += 1
        end = i

        # Measure the clean gap that follows it.
        j = end
        while j < n and not anomaly[j]:
            j += 1

        # Only fill when another anomalous run actually follows (j < n).
        if j < n and (j - end) < min_gap:
            merged[end:j] = True

    return merged


def fill_spo2_anomaly(
        spo2_data,
        spo2_fs,
        max_fill_duration,
        min_gap_duration,
):
    """Repair short anomalous stretches of an SpO2 trace.

    Anomalies are found with ``detect_spo2_anomaly`` and merged with
    ``merge_close_anomalies``. Each anomalous run is then handled as follows:

    * longer than ``max_fill_duration`` seconds: set to NaN;
    * at the start of the signal: filled with the first clean value after it;
    * at the end of the signal: filled with the last clean value before it;
    * in the middle: PCHIP-interpolated between the two neighbouring values;
    * covering the whole signal: set to NaN.

    Args:
        spo2_data: SpO2 samples (array-like); the input is not modified.
        spo2_fs: Sampling rate in Hz.
        max_fill_duration: Longest run, in seconds, that will be filled.
        min_gap_duration: Passed to ``merge_close_anomalies``.

    Returns:
        tuple: ``(spo2, valid_mask)``. ``spo2`` is a float copy with the
        repairs applied. ``valid_mask`` is True only for samples that were not
        flagged as anomalous; repaired samples stay False.
    """
    # Always a float copy, whatever array-like comes in.
    spo2 = np.array(spo2_data, dtype=float)
    n = len(spo2)

    anomaly = detect_spo2_anomaly(spo2, spo2_fs)
    anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)

    max_len = int(max_fill_duration * spo2_fs)

    valid_mask = ~anomaly

    i = 0
    while i < n:
        if not anomaly[i]:
            i += 1
            continue

        start = i
        while i < n and anomaly[i]:
            i += 1
        end = i

        # Run too long to be filled credibly.
        if end - start > max_len:
            spo2[start:end] = np.nan
            valid_mask[start:end] = False
            continue

        has_left = start > 0 and valid_mask[start - 1]
        has_right = end < n and valid_mask[end]

        if has_left and has_right:
            # Clean values on both sides: interpolate between them.
            x = np.array([start - 1, end])
            y = np.array([spo2[start - 1], spo2[end]])
            interp = PchipInterpolator(x, y)
            spo2[start:end] = interp(np.arange(start, end))
        elif has_right:
            # Run at the very start: hold the first clean value backwards.
            spo2[start:end] = spo2[end]
        elif has_left:
            # Run at the very end: hold the last clean value forwards.
            spo2[start:end] = spo2[start - 1]
        else:
            # No clean neighbour on either side.
            spo2[start:end] = np.nan
            valid_mask[start:end] = False

    return spo2, valid_mask
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dataset_builder/HYS_PSG_dataset.py | 8 +- dataset_builder/HYS_dataset.py | 28 ++++++- dataset_config/HYS_PSG_config.yaml | 10 +++ dataset_config/SHHS1_config.yaml | 41 ++++++++++ dataset_tools/shhs_annotations_check.py | 100 ++++++++++++++++++++++++ event_mask_process/SHHS1_process.py | 42 ++++++++++ signal_method/shhs_tools.py | 62 +++++++++++++++ signal_method/signal_process.py | 2 +- 8 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 dataset_config/SHHS1_config.yaml create mode 100644 dataset_tools/shhs_annotations_check.py create mode 100644 event_mask_process/SHHS1_process.py create mode 100644 signal_method/shhs_tools.py diff --git a/dataset_builder/HYS_PSG_dataset.py b/dataset_builder/HYS_PSG_dataset.py index e061e25..54e8d05 100644 --- a/dataset_builder/HYS_PSG_dataset.py +++ b/dataset_builder/HYS_PSG_dataset.py @@ -63,7 +63,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T # 都调整至100Hz采样率 - target_fs = 100 + target_fs = conf["target_fs"] normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs) normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs) normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs) @@ -90,8 +90,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt, spo2_fs=target_fs, - max_fill_duration=30, - min_gap_duration=10,) + max_fill_duration=conf["spo2_fill__anomaly"]["max_fill_duration"], + min_gap_duration=conf["spo2_fill__anomaly"]["min_gap_duration"]) draw_tools.draw_psg_signal( @@ -135,7 +135,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]), "DisableSegment": 
def multiprocess_with_pool(args_list, n_processes):
    """Build every sample's dataset segments in a multiprocessing ``Pool``.

    One ``build_HYS_dataset_segment`` task is submitted per entry of
    ``args_list``; the function then waits for them in submission order and
    prints a progress line per finished task. A failing task is reported and
    does not stop the others.

    Args:
        args_list: Sample ids to process.
        n_processes: Number of worker processes.
    """
    from multiprocessing import Pool

    # maxtasksperchild=2: a worker is replaced after two tasks so the memory
    # it accumulated is released.
    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        # NOTE(review): the positional flags presumably map to
        # show / draw_segment / verbose plus two further options; confirm
        # against build_HYS_dataset_segment's signature.
        pending = [
            pool.apply_async(
                build_HYS_dataset_segment,
                args=(samp_id, False, True, False, None, None)
            )
            for samp_id in args_list
        ]

        # Block until every task has finished.
        for idx, task in enumerate(pending):
            try:
                task.get()
                print(f"Completed {idx + 1}/{len(args_list)}")
            except Exception as e:
                print(f"Error processing task {idx}: {e}")

        pool.close()
        pool.join()
import argparse
from pathlib import Path
from lxml import etree
from tqdm import tqdm
from collections import Counter


def main():
    """Count (EventType, EventConcept) pairs across a folder of annotation XMLs.

    Every ``*.xml`` file directly inside the hard-coded target folder is
    parsed, all ``ScoredEvent`` elements are collected, and a table of how
    often each (EventType, EventConcept) combination occurs is printed.
    Files that fail to parse are reported and skipped.
    """
    # Hard-coded input folder; swap the two lines below to scan SHHS1 instead.
    # target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs1"
    target_dir = "/mnt/disk_wd/marques_dataset/shhs/polysomnography/annotations-events-nsrr/shhs2"

    folder_path = Path(target_dir)
    if not folder_path.exists():
        print(f"错误: 路径 '{folder_path}' 不存在。")
        return

    # Flat listing only; sub-directories are not searched.
    xml_files = list(folder_path.glob("*.xml"))
    total_files = len(xml_files)
    if total_files == 0:
        print(f"在 '{folder_path}' 中没有找到 XML 文件。")
        return

    print(f"找到 {total_files} 个 XML 文件,准备开始处理...")

    def text_or_na(event, tag):
        """Return the stripped text of child ``tag``, or "N/A" if absent or empty."""
        node = event.find(tag)
        if node is None or not node.text:
            return "N/A"
        return node.text.strip()

    # Occurrences of each (EventType, EventConcept) pair.
    stats_counter = Counter()

    for xml_file in tqdm(xml_files, desc="Processing XMLs", unit="file"):
        try:
            root = etree.parse(str(xml_file)).getroot()
            # ScoredEvent elements are looked up anywhere below the root.
            for event in root.findall(".//ScoredEvent"):
                pair = (text_or_na(event, "EventType"), text_or_na(event, "EventConcept"))
                stats_counter[pair] += 1
        except etree.XMLSyntaxError:
            print(f"\n[警告] 文件格式错误,跳过: {xml_file.name}")
        except Exception as e:
            print(f"\n[错误] 处理文件 {xml_file.name} 时出错: {e}")

    if not stats_counter:
        print("\n未提取到任何事件数据。")
    else:
        # Column widths follow the longest value, with the header length as a floor.
        max_type_width = max(9, max(len(e_type) for e_type, _ in stats_counter))
        max_conc_width = max(12, max(len(e_concept) for _, e_concept in stats_counter))

        # 10 for the count column, 6 for the two " | " separators.
        total_line_width = max_type_width + max_conc_width + 10 + 6

        print("\n" + "=" * total_line_width)
        print(f"{'EventType':<{max_type_width}} | {'EventConcept':<{max_conc_width}} | {'Count':>10}")
        print("-" * total_line_width)

        # Tuples sort by EventType first, then by EventConcept.
        for (e_type, e_concept), count in sorted(stats_counter.items()):
            print(f"{e_type:<{max_type_width}} | {e_concept:<{max_conc_width}} | {count:>10}")

        print("=" * total_line_width)

    print("=" * 90)
    print(f"统计完成。共扫描 {total_files} 个文件。")


if __name__ == "__main__":
    main()
import xml.etree.ElementTree as ET

# Sleep-stage EventConcept strings and the numeric stage code assigned to each.
ANNOTATION_MAP = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 4,
    "REM sleep|5": 5,
    "Unscored|9": 9,
    "Movement|6": 6
}

# Respiratory event names (the part of EventConcept before '|') that are kept.
SA_EVENTS = ['Central apnea', 'Hypopnea', 'Obstructive apnea']


def _load_scored_events(annotation_path):
    """Return all ScoredEvent elements of the file, or None if it cannot be read."""
    try:
        root = ET.parse(annotation_path).getroot()
    except (ET.ParseError, OSError, TypeError, ValueError):
        return None
    return root.findall('.//ScoredEvent')


def _start_and_duration(scored_event):
    """Return (start, duration) as floats, or None if either is missing or non-numeric."""
    try:
        return (float(scored_event.findtext('Start')),
                float(scored_event.findtext('Duration')))
    except (TypeError, ValueError):
        return None


def parse_sleep_annotations(annotation_path):
    """Parse the sleep-stage annotations of an annotation XML file.

    Only ScoredEvent elements whose EventType is "Stages|Stages" and whose
    EventConcept is a key of ``ANNOTATION_MAP`` are kept. An event with a
    missing child or a non-numeric Start/Duration is skipped on its own, so
    one malformed event no longer discards the whole file.

    Args:
        annotation_path: Path to the annotation XML file.

    Returns:
        list[dict] | None: One dict per stage event with keys ``onset``,
        ``duration``, ``description`` and ``stage``; ``None`` when the file
        cannot be opened or parsed.
    """
    scored_events = _load_scored_events(annotation_path)
    if scored_events is None:
        return None

    events = []
    for scored_event in scored_events:
        if scored_event.findtext('EventType') != "Stages|Stages":
            continue
        description = scored_event.findtext('EventConcept')
        if description not in ANNOTATION_MAP:
            continue
        timing = _start_and_duration(scored_event)
        if timing is None:
            continue
        start, duration = timing
        events.append({
            'onset': start,
            'duration': duration,
            'description': description,
            'stage': ANNOTATION_MAP[description]
        })
    return events


def extract_osa_events(annotation_path):
    """Extract sleep-apnea events lasting at least 10 seconds.

    The event name is the part of EventConcept before the first '|'; only
    names listed in ``SA_EVENTS`` are kept. Malformed events are skipped
    individually.

    Args:
        annotation_path: Path to the annotation XML file.

    Returns:
        list[dict]: One dict per event with keys ``start``, ``duration``,
        ``end`` and ``type``; an empty list when the file cannot be opened or
        parsed.
    """
    scored_events = _load_scored_events(annotation_path)
    if scored_events is None:
        return []

    events = []
    for scored_event in scored_events:
        event_concept = scored_event.findtext('EventConcept')
        if not event_concept:
            continue
        event_type = event_concept.split('|')[0].strip()
        if event_type not in SA_EVENTS:
            continue
        timing = _start_and_duration(scored_event)
        if timing is None:
            continue
        start, duration = timing
        if duration >= 10:
            events.append({
                'start': start,
                'duration': duration,
                'end': start + duration,
                'type': event_type
            })
    return events
a/signal_method/signal_process.py b/signal_method/signal_process.py index a303ca4..168de8f 100644 --- a/signal_method/signal_process.py +++ b/signal_method/signal_process.py @@ -52,7 +52,7 @@ def psg_effort_filter(conf, effort_data_raw, effort_fs): high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"], sample_rate=effort_fs) # 移动平均 - effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20) + effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=conf["average_filter"]["window_size_sec"]) return effort_data_raw, effort_data_2, effort_fs