diff --git a/HYS_process.py b/HYS_process.py index 8dec227..9f66a51 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:10.0" +os.environ['DISPLAY'] = "localhost:14.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -127,33 +127,23 @@ def process_one_signal(samp_id): ) print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: - resp_movement_mask = None + resp_movement_mask, resp_movement_position_list = None, None print("resp_movement_mask is None") - if resp_movement_mask is not None: - # 左右翻转resp_data - reverse_resp_data = resp_data[::-1] - _, resp_movement_mask_reverse, _, resp_movement_position_list_reverse = signal_method.detect_movement( - signal_data=reverse_resp_data, + resp_movement_revise_conf = conf.get("resp_movement_revise", None) + if resp_movement_mask is not None and resp_movement_revise_conf is not None: + resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, sampling_rate=resp_fs, - **resp_movement_conf + **resp_movement_revise_conf ) - print(f"resp_movement_mask_reverse_shape: {resp_movement_mask_reverse.shape}, num_movement_reverse: {np.sum(resp_movement_mask_reverse == 1)}, count_movement_positions_reverse: {len(resp_movement_position_list_reverse)}") - # 将resp_movement_mask_reverse翻转回来 - resp_movement_mask_reverse = resp_movement_mask_reverse[::-1] + print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") else: - resp_movement_mask_reverse = None - print("resp_movement_mask_reverse is None") + 
print("resp_movement_mask revise is skipped") - # 取交集 - if resp_movement_mask is not None and resp_movement_mask_reverse is not None: - combined_resp_movement_mask = np.logical_and(resp_movement_mask, resp_movement_mask_reverse).astype(int) - resp_movement_mask = combined_resp_movement_mask - print(f"combined_resp_movement_mask_shape: {combined_resp_movement_mask.shape}, num_combined_movement: {np.sum(combined_resp_movement_mask == 1)}") - else: - print("combined_resp_movement_mask is None") - # 分析Resp的幅值突变区间 if resp_movement_mask is not None: @@ -246,4 +236,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[5]) + process_one_signal(select_ids[0]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 605c010..50e0c93 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -32,13 +32,21 @@ resp_low_amp: resp_movement: window_size_sec: 20 stride_sec: 1 - std_median_multiplier: 3.5 + std_median_multiplier: 5 compare_intervals_sec: - 60 - - 90 + - 120 + - 180 interval_multiplier: 3.5 merge_gap_sec: 30 - min_duration_sec: 2 + min_duration_sec: 1 + +resp_movement_revise: + up_interval_multiplier: 3 + down_interval_multiplier: 1.5 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 bcg: downsample_fs: 100 diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py index 5d4efe2..281cc34 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -1 +1 @@ -from draw_tools.draw_statics import draw_signal_with_mask \ No newline at end of file +from .draw_statics import draw_signal_with_mask \ No newline at end of file diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index adbc3d9..f44cde0 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -222,7 +222,10 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax1 = fig.add_subplot(3, 1, 2, 
sharex=ax0) - ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') + ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='gray', alpha=0.5) + resp_data_no_movement = resp_data.copy() + resp_data_no_movement[resp_movement_mask.repeat(int(resp_fs)) == 1] = np.nan + ax1.plot(np.linspace(0, len(resp_data_no_movement) // resp_fs, len(resp_data_no_movement)), resp_data_no_movement, color='orange') ax1.set_ylabel('Amplitude') # ax1.set_xticklabels([]) ax1_twin = ax1.twinx() diff --git a/signal_method/__init__.py b/signal_method/__init__.py index a1d61f2..ab44ea9 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1 +1,3 @@ -from signal_method.rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 \ No newline at end of file +from .rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 +from .rule_base_event import movement_revise +from .time_metrics import calc_mav \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 76b6302..072c375 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -1,6 +1,7 @@ from utils.operation_tools import timing_decorator import numpy as np -from utils import merge_short_gaps, remove_short_durations, event_mask_2_list +from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values +from signal_method.time_metrics import calc_mav @timing_decorator() @@ -90,16 +91,16 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No # else: # valid_std = original_window_std - valid_std = original_window_std ##20250418新修改 + valid_std = original_window_std ##20250418新修改 - #---------------------- 方法一:基于STD的体动判定 ----------------------# + # ---------------------- 方法一:基于STD的体动判定 ----------------------# # 
计算所有有效窗口标准差的中位数 median_std = np.median(valid_std) # 当窗口标准差大于中位数的倍数,判定为体动状态 - std_movement = np.where(original_window_std > median_std * std_median_multiplier, 1, 0) + std_movement = np.where((original_window_std > (median_std * std_median_multiplier)), 1, 0) - #------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# + # ------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# amplitude_movement = np.zeros(num_original_windows, dtype=int) # 定义基于时间粒度的比较间隔索引 @@ -146,7 +147,6 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] - # 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除 removed_movement_mask = (raw_movement_mask - movement_mask) > 0 removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0] @@ -155,8 +155,8 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No for start, end in zip(removed_movement_start, removed_movement_end): # print(start ,end) # 计算剔除的体动区域的幅值 - if np.nanmax(signal_data[start*sampling_rate:(end+1)*sampling_rate]) > median_std * std_median_multiplier: - movement_mask[start:end+1] = 1 + if np.nanmax(signal_data[start * sampling_rate:(end + 1) * sampling_rate]) > median_std * std_median_multiplier: + movement_mask[start:end + 1] = 1 # raw体动起止位置 [[start, end], [start, end], ...] raw_movement_position_list = event_mask_2_list(raw_movement_mask) @@ -164,25 +164,70 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No # merge体动起止位置 [[start, end], [start, end], ...] 
def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float,
                    down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec):
    """Refine an existing movement mask using a local MAV (mean amplitude) baseline.

    Intended as a fine-grained correction pass after a coarse movement detection.
    For every ``[start, end]`` pair in ``movement_list`` a reference amplitude is
    built from the median MAV of up to ``compare_intervals_sec`` values on the
    left of ``start`` and on the right of ``end + 5`` (the mean of the two
    medians when both sides are available, otherwise whichever side exists).
    Every index from ``start`` up to ``end + 5`` is then re-labelled:

    * MAV above ``reference * up_interval_multiplier``   -> movement (1)
    * MAV below ``reference * down_interval_multiplier`` -> no movement (0)
    * otherwise the existing label is kept.

    Args:
        signal_data: numpy array with the input signal; forwarded to ``calc_mav``.
        sampling_rate: int, sampling rate of ``signal_data`` in Hz.
        movement_mask: numpy array, existing movement mask (1 = movement,
            0 = sleep). Presumably one entry per second, matching the MAV
            series -- TODO confirm against the caller.
        movement_list: iterable of ``[start, end]`` index pairs into the mask.
        up_interval_multiplier: multiplier of the reference MAV above which an
            index is marked as movement.
        down_interval_multiplier: multiplier of the reference MAV below which
            an index is cleared.
        compare_intervals_sec: maximum number of MAV values collected on each
            side of an event to build the reference.
        merge_gap_sec: forwarded to ``merge_short_gaps`` when > 0.
        min_duration_sec: forwarded to ``remove_short_durations`` when > 0.

    Returns:
        tuple: ``(revised_movement_mask, revised_movement_list)`` where the
        list is rebuilt from the revised mask with ``event_mask_2_list``.
        The input ``movement_mask`` is left untouched; use the returned mask.
    """
    # MAV windows advance one second at a time (step_second=1 below).
    stride_size = sampling_rate
    # Each event is re-examined up to this many indices past its end.
    tail_sec = 5

    # NOTE(review): this has one entry per *sample*, while the MAV series has
    # one entry per second -- confirm this is what merge_short_gaps /
    # remove_short_durations expect.
    time_points = np.arange(len(signal_data))

    compare_size = int(compare_intervals_sec // (stride_size / sampling_rate))

    _, mav = calc_mav(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate,
                      window_second=2, step_second=1,
                      inner_window_second=1)

    # Work on a copy so the caller's array is not modified behind its back.
    revised_mask = np.array(movement_mask, copy=True)
    usable_length = min(len(mav), len(revised_mask))

    for start, end in movement_list:
        # Reference amplitude: median MAV of the neighbours on each side.
        left_values = collect_values(arr=mav, index=start - 1, step=-1, limit=compare_size, mask=revised_mask)
        right_values = collect_values(arr=mav, index=end + tail_sec, step=1, limit=compare_size, mask=revised_mask)
        left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0
        right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0
        if left_value_metrics == 0:
            value_metrics = right_value_metrics
        elif right_value_metrics == 0:
            value_metrics = left_value_metrics
        else:
            value_metrics = np.mean([left_value_metrics, right_value_metrics])

        # Clamp the revised range so events close to the end of the recording
        # cannot index past the MAV series or the mask.
        stop = min(end + tail_sec, usable_length)
        segment = mav[start:stop]
        labels = revised_mask[start:stop]  # view: writes land in revised_mask
        # Clear first, then set, so the "movement" rule wins if both thresholds
        # match (same precedence as an if / elif chain).
        labels[segment < (value_metrics * down_interval_multiplier)] = 0
        labels[segment > (value_metrics * up_interval_multiplier)] = 1

    # Optionally merge movement episodes separated by a short gap.
    if merge_gap_sec > 0:
        revised_mask = merge_short_gaps(revised_mask, time_points, merge_gap_sec)

    # Optionally drop movement episodes that are too short.
    if min_duration_sec > 0:
        revised_mask = remove_short_durations(revised_mask, time_points, min_duration_sec)

    movement_list = event_mask_2_list(revised_mask)
    return revised_mask, movement_list
sampling_rat # 新的end - start确保为200的整数倍 if (end - start) % (mav_calc_window_sec * sampling_rate) != 0: end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * ( - mav_calc_window_sec * sampling_rate) + mav_calc_window_sec * sampling_rate) # 计算每个片段的幅值指标 mav = np.nanmean( - np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.nanmean( + np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), + axis=0)) - np.nanmean( np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) segment_average_amplitude.append(mav) diff --git a/signal_method/time_metrics.py b/signal_method/time_metrics.py index a6a2a94..6d8c196 100644 --- a/signal_method/time_metrics.py +++ b/signal_method/time_metrics.py @@ -5,10 +5,14 @@ import numpy as np @timing_decorator() def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): - assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" - assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" - # print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}") - processed_mask = movement_mask.copy() + if movement_mask is not None: + assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" + # assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" + # print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}") + processed_mask = movement_mask.copy() + else: + processed_mask = None + def mav_func(x): return 
np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2 mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate, diff --git a/utils/__init__.py b/utils/__init__.py index 51a18ba..c89b90c 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,5 +1,6 @@ -from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list -from utils.operation_tools import merge_short_gaps, remove_short_durations -from utils.event_map import E2N -from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file +from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel +from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list +from .operation_tools import merge_short_gaps, remove_short_durations +from .operation_tools import collect_values +from .event_map import E2N +from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 4feabdf..0229a06 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -125,8 +125,8 @@ def remove_short_durations(state_sequence, time_points, min_duration_sec): @timing_decorator() def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None): # 处理标志位长度与 signal_data 对齐 - if calc_mask is None: - calc_mask = np.zeros(len(signal_data), dtype=bool) + # if calc_mask is None: + # calc_mask = np.zeros(len(signal_data), dtype=bool) if step_second is None: step_second = window_second @@ -157,18 +157,21 @@ def 
def event_mask_2_list(mask):
    """Convert a binary event mask into a list of ``[start, end]`` index pairs.

    ``start`` is the index of the first non-zero element of an event and
    ``end`` is the index one past its last non-zero element, so
    ``mask[start:end]`` selects exactly the event.

    Args:
        mask: 1-D sequence of 0/1 (or boolean) flags.

    Returns:
        list: ``[[start, end], ...]`` in ascending order; empty when the mask
        contains no event.
    """
    # Cast to int so boolean masks produce signed edges (+1 / -1) under diff.
    flags = (np.asarray(mask) != 0).astype(int)
    # Pad both ends with 0 so events touching either boundary still yield a
    # rising and a falling edge; starts and ends then always pair up.
    edges = np.diff(flags, prepend=0, append=0)
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    return [[int(start), int(end)] for start, end in zip(starts, ends)]


def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list:
    """Collect up to ``limit`` non-NaN values of ``arr`` walking from ``index``.

    The walk moves by ``step`` (negative values walk left) and stops when
    ``limit`` values were collected or the index leaves the valid range.
    A position is skipped when ``mask`` is NaN there or when the value of
    ``arr`` itself is NaN.

    NOTE(review): only NaN entries of ``mask`` exclude a position. An integer
    0/1 mask never contains NaN and therefore filters nothing -- confirm that
    this is what callers passing such a mask intend.

    Args:
        arr: sequence of values to collect from.
        index: first position to inspect.
        step: non-zero increment applied to ``index`` after each inspection.
        limit: maximum number of values to return.
        mask: optional sequence consulted for NaN; defaults to ``arr``.

    Returns:
        list: the collected values in visiting order.

    Raises:
        ValueError: if ``step`` is 0 (the walk would never terminate).
    """
    if step == 0:
        raise ValueError("step must be non-zero")

    gate = arr if mask is None else mask
    # Bound by both sequences: the gate is tested and arr is read at each index.
    upper = min(len(arr), len(gate))

    values = []
    while len(values) < limit and 0 <= index < upper:
        if not np.isnan(gate[index]) and not np.isnan(arr[index]):
            values.append(arr[index])
        index += step
    return values