Refactor signal processing functions in HYS_process.py and rule_base_event.py, update imports in __init__.py, and enhance event mask handling in operation_tools.py

This commit is contained in:
marques 2025-11-07 16:52:31 +08:00
parent fd7941a80a
commit 265fcd958a
5 changed files with 129 additions and 19 deletions

View File

@ -132,7 +132,6 @@ def process_one_signal(samp_id):
resp_movement_revise_conf = conf.get("resp_movement_revise", None) resp_movement_revise_conf = conf.get("resp_movement_revise", None)
if resp_movement_mask is not None and resp_movement_revise_conf is not None: if resp_movement_mask is not None and resp_movement_revise_conf is not None:
print(resp_movement_position_list)
resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
signal_data=resp_data, signal_data=resp_data,
movement_mask=resp_movement_mask, movement_mask=resp_movement_mask,
@ -140,7 +139,7 @@ def process_one_signal(samp_id):
sampling_rate=resp_fs, sampling_rate=resp_fs,
**resp_movement_revise_conf **resp_movement_revise_conf
) )
print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
else: else:
print("resp_movement_mask revise is skipped") print("resp_movement_mask revise is skipped")
@ -148,9 +147,10 @@ def process_one_signal(samp_id):
# 分析Resp的幅值突变区间 # 分析Resp的幅值突变区间
if resp_movement_mask is not None: if resp_movement_mask is not None:
resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2( resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
signal_data=resp_data, signal_data=resp_data,
movement_mask=resp_movement_mask, movement_mask=resp_movement_mask,
movement_list=resp_movement_position_list,
sampling_rate=resp_fs) sampling_rate=resp_fs)
print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
else: else:

View File

@ -1,3 +1,4 @@
from .rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 from .rule_base_event import detect_low_amplitude_signal, detect_movement
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
from .rule_base_event import movement_revise from .rule_base_event import movement_revise
from .time_metrics import calc_mav from .time_metrics import calc_mav_by_slide_windows

View File

@ -1,7 +1,8 @@
import utils
from utils.operation_tools import timing_decorator from utils.operation_tools import timing_decorator
import numpy as np import numpy as np
from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values
from signal_method.time_metrics import calc_mav from signal_method.time_metrics import calc_mav_by_slide_windows
@timing_decorator() @timing_decorator()
@ -187,9 +188,9 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up
compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) compare_size = int(compare_intervals_sec // (stride_size / sampling_rate))
_, mav = calc_mav(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, _, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate,
window_second=2, step_second=1, window_second=2, step_second=1,
inner_window_second=1) inner_window_second=1)
# 往左右两边取compare_size个点的mav取平均值 # 往左右两边取compare_size个点的mav取平均值
for start, end in movement_list: for start, end in movement_list:
@ -207,6 +208,8 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up
# 逐秒遍历mav判断是否需要修正 # 逐秒遍历mav判断是否需要修正
# print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}")
for i in range(start, end + 5): for i in range(start, end + 5):
if i < 0 or i >= len(mav):
continue
# print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}")
if mav[i] > (value_metrics * up_interval_multiplier): if mav[i] > (value_metrics * up_interval_multiplier):
movement_mask[i] = 1 movement_mask[i] = 1
@ -229,8 +232,6 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up
return movement_mask, movement_list return movement_mask, movement_list
@timing_decorator() @timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):
@ -511,3 +512,107 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
position_change_times.append((movement_start[i - 1], movement_end[i - 1])) position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
return position_changes, position_change_times return position_changes, position_change_times
def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate=100):
"""
:param movement_list:
:param signal_data:
:param movement_mask: mask的采样率为1Hz
:param sampling_rate:
:param window_size_sec:
:return:
"""
mav_calc_window_sec = 1 # 计算mav的窗口大小单位秒
# 判断是否存在显著变化 (可根据实际情况调整阈值)
threshold_amplitude = 0.1 # 幅值变化阈值
threshold_energy = 0.1 # 能量变化阈值
# 获取有效片段起止位置
valid_list = utils.event_mask_2_list(movement_mask, event_true=False)
segment_average_amplitude = []
segment_average_energy = []
signal_data_no_movement = signal_data.copy()
for start, end in movement_list:
signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan
# from matplotlib import pyplot as plt
# plt.plot(signal_data, alpha=0.3, color='gray')
# plt.plot(signal_data_no_movement, color='blue', linewidth=1)
# plt.show()
if len(valid_list) < 2:
return [], []
def clac_mav(data_segment):
mav = np.nanmean(
np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate),
axis=0)) - np.nanmean(
np.nanmin(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
return mav
def clac_energy(data_segment):
energy = np.nansum(np.abs(data_segment ** 2))
return energy
position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
position_change_list = []
pre_valid_start = valid_list[0][0] * sampling_rate
pre_valid_end = valid_list[0][1] * sampling_rate
print(f"Total movement segments to analyze: {len(movement_list)}")
print(f"Total valid segments available: {len(valid_list)}")
for i in range(len(movement_list)):
print(f"Analyzing movement segment {i + 1}/{len(movement_list)}")
if i + 1 >= len(valid_list):
print("No more valid segments to compare. Ending analysis.")
break
next_valid_start = valid_list[i + 1][0] * sampling_rate
next_valid_end = valid_list[i + 1][1] * sampling_rate
movement_start = movement_list[i][0]
movement_end = movement_list[i][1]
# 避免过短的片段
if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑
print(f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}")
continue
# 计算前后片段的幅值和能量
left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end])
right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end])
left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end])
right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end])
# 计算幅值指标的变化率
amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6)
# 计算能量指标的变化率
energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6)
significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy)
if significant_change:
print(
f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}")
# 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示
position_changes[movement_start:movement_end] = 1
position_change_list.append(movement_list[i])
# 更新前后片段
pre_valid_start = next_valid_start
pre_valid_end = next_valid_end
else:
print(
f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}")
# 仅更新前片段
pre_valid_start = pre_valid_start
pre_valid_end = next_valid_end
return position_changes, position_change_list

View File

@ -4,7 +4,7 @@ import numpy as np
@timing_decorator() @timing_decorator()
def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): def calc_mav_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2):
if movement_mask is not None: if movement_mask is not None:
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
# assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" # assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"
@ -21,7 +21,7 @@ def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window
return mav_nan, mav return mav_nan, mav
@timing_decorator() @timing_decorator()
def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): def calc_wavefactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"
@ -33,7 +33,7 @@ def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100)
return wavefactor_nan, wavefactor return wavefactor_nan, wavefactor
@timing_decorator() @timing_decorator()
def calc_peakfactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): def calc_peakfactor_by_slide_windows(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

View File

@ -5,6 +5,8 @@ import numpy as np
import pandas as pd import pandas as pd
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import yaml import yaml
from numpy.ma.core import append
from utils.event_map import E2N from utils.event_map import E2N
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
@ -212,13 +214,15 @@ def generate_event_mask(signal_second: int, event_df):
def event_mask_2_list(mask, event_true=True): def event_mask_2_list(mask, event_true=True):
if event_true: if event_true:
event_2_normal = 1
normal_2_event = -1
else:
event_2_normal = -1 event_2_normal = -1
normal_2_event = 1 normal_2_event = 1
mask_start = np.where(np.diff(mask, append=0) == normal_2_event)[0] _append = 0
mask_end = np.where(np.diff(mask, append=0) == normal_2_event)[0] + 1 else:
event_2_normal = 1
normal_2_event = -1
_append = 1
mask_start = np.where(np.diff(mask, prepend=_append, append=_append) == normal_2_event)[0]
mask_end = np.where(np.diff(mask, prepend=_append, append=_append) == event_2_normal)[0] + 1
event_list =[[start, end] for start, end in zip(mask_start, mask_end)] event_list =[[start, end] for start, end in zip(mask_start, mask_end)]
return event_list return event_list