From 265fcd958ab8449f58dcd7f43feece37cb02e156 Mon Sep 17 00:00:00 2001 From: marques Date: Fri, 7 Nov 2025 16:52:31 +0800 Subject: [PATCH] Refactor signal processing functions in HYS_process.py and rule_base_event.py, update imports in __init__.py, and enhance event mask handling in operation_tools.py --- HYS_process.py | 6 +- signal_method/__init__.py | 5 +- signal_method/rule_base_event.py | 117 +++++++++++++++++++++++++++++-- signal_method/time_metrics.py | 6 +- utils/operation_tools.py | 14 ++-- 5 files changed, 129 insertions(+), 19 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 4873572..0c3627c 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -132,7 +132,6 @@ def process_one_signal(samp_id): resp_movement_revise_conf = conf.get("resp_movement_revise", None) if resp_movement_mask is not None and resp_movement_revise_conf is not None: - print(resp_movement_position_list) resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( signal_data=resp_data, movement_mask=resp_movement_mask, @@ -140,7 +139,7 @@ def process_one_signal(samp_id): sampling_rate=resp_fs, **resp_movement_revise_conf ) - print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") + print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: print("resp_movement_mask revise is skipped") @@ -148,9 +147,10 @@ def process_one_signal(samp_id): # 分析Resp的幅值突变区间 if resp_movement_mask is not None: - resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2( + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( signal_data=resp_data, movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, sampling_rate=resp_fs) print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: diff --git a/signal_method/__init__.py b/signal_method/__init__.py index ab44ea9..eaea6ea 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1,3 +1,4 @@ -from .rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 +from .rule_base_event import detect_low_amplitude_signal, detect_movement +from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3 from .rule_base_event import movement_revise -from .time_metrics import calc_mav \ No newline at end of file +from .time_metrics import calc_mav_by_slide_windows \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 072c375..161f378 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -1,7 +1,8 @@ +import utils from utils.operation_tools import timing_decorator import numpy as np from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values -from signal_method.time_metrics import calc_mav +from signal_method.time_metrics import calc_mav_by_slide_windows @timing_decorator() @@ -187,9 +188,9 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) - _, mav = calc_mav(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, - window_second=2, step_second=1, - inner_window_second=1) + _, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, + window_second=2, step_second=1, + inner_window_second=1) # 往左右两边取compare_size个点的mav,取平均值 for start, end in movement_list: @@ -207,6 +208,8 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up # 逐秒遍历mav,判断是否需要修正 # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") for i in range(start, end + 5): + if i < 0 or i >= len(mav): + continue # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") if mav[i] > (value_metrics * up_interval_multiplier): movement_mask[i] = 1 @@ -229,8 +232,6 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up return movement_mask, movement_list - - @timing_decorator() def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): @@ -511,3 +512,107 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat position_change_times.append((movement_start[i - 1], movement_end[i - 1])) return position_changes, position_change_times + + +def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate=100): + """ + + :param movement_list: + :param signal_data: + :param movement_mask: mask的采样率为1Hz + :param sampling_rate: + :param window_size_sec: + :return: + """ + mav_calc_window_sec = 1 # 计算mav的窗口大小,单位秒 + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + + # 获取有效片段起止位置 + + valid_list = utils.event_mask_2_list(movement_mask, event_true=False) + + segment_average_amplitude = [] + segment_average_energy = [] + + signal_data_no_movement = signal_data.copy() + for start, end in movement_list: + signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan + + # from matplotlib import pyplot as plt + # plt.plot(signal_data, alpha=0.3, color='gray') + # plt.plot(signal_data_no_movement, color='blue', linewidth=1) + # plt.show() + + if len(valid_list) < 2: + return [], [] + + def clac_mav(data_segment): + mav = np.nanmean( + np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.nanmean( + np.nanmin(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + return mav + + def clac_energy(data_segment): + energy = np.nansum(np.abs(data_segment ** 2)) + return energy + + position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) + position_change_list = [] + + pre_valid_start = valid_list[0][0] * sampling_rate + pre_valid_end = valid_list[0][1] * sampling_rate + + print(f"Total movement segments to analyze: {len(movement_list)}") + print(f"Total valid segments available: {len(valid_list)}") + + for i in range(len(movement_list)): + print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") + + if i + 1 >= len(valid_list): + print("No more valid segments to compare. Ending analysis.") + break + + next_valid_start = valid_list[i + 1][0] * sampling_rate + next_valid_end = valid_list[i + 1][1] * sampling_rate + + movement_start = movement_list[i][0] + movement_end = movement_list[i][1] + + # 避免过短的片段 + if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑 + print(f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") + continue + + # 计算前后片段的幅值和能量 + left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) + left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + + # 计算幅值指标的变化率 + amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6) + # 计算能量指标的变化率 + energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) + + significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + if significant_change: + print( + f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 + position_changes[movement_start:movement_end] = 1 + position_change_list.append(movement_list[i]) + # 更新前后片段 + pre_valid_start = next_valid_start + pre_valid_end = next_valid_end + + else: + print( + f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # 仅更新前片段 + pre_valid_start = pre_valid_start + pre_valid_end = next_valid_end + + return position_changes, position_change_list diff --git a/signal_method/time_metrics.py b/signal_method/time_metrics.py index 6d8c196..7895c3d 100644 --- a/signal_method/time_metrics.py +++ b/signal_method/time_metrics.py @@ -4,7 +4,7 @@ import numpy as np @timing_decorator() -def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): +def calc_mav_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): if movement_mask is not None: assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" # assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" @@ -21,7 +21,7 @@ def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window return mav_nan, mav @timing_decorator() -def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): +def calc_wavefactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100): assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" @@ -33,7 +33,7 @@ def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100) return wavefactor_nan, wavefactor @timing_decorator() -def calc_peakfactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): +def calc_peakfactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100): assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" diff --git a/utils/operation_tools.py b/utils/operation_tools.py index ffd1b9f..5739097 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -5,6 +5,8 @@ import numpy as np import pandas as pd from matplotlib import pyplot as plt import yaml +from numpy.ma.core import append + from utils.event_map import E2N plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 @@ -212,13 +214,15 @@ def generate_event_mask(signal_second: int, event_df): def event_mask_2_list(mask, event_true=True): if event_true: - event_2_normal = 1 - normal_2_event = -1 - else: event_2_normal = -1 normal_2_event = 1 - mask_start = np.where(np.diff(mask, append=0) == normal_2_event)[0] - mask_end = np.where(np.diff(mask, append=0) == normal_2_event)[0] + 1 + _append = 0 + else: + event_2_normal = 1 + normal_2_event = -1 + _append = 1 + mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0] + mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0] + 1 event_list =[[start, end] for start, end in zip(mask_start, mask_end)] return event_list