From 998890377b93219e551c23a1769a0dd58e525547 Mon Sep 17 00:00:00 2001 From: marques Date: Wed, 5 Nov 2025 10:29:24 +0800 Subject: [PATCH] Update HYS_config.yaml and HYS_process.py for signal processing parameters and add movement revision function --- HYS_process.py | 36 +++++++++++++++++++++++++++----- dataset_config/HYS_config.yaml | 12 +++++------ signal_method/rule_base_event.py | 35 +++++++++++++++++++++++++++---- 3 files changed, 68 insertions(+), 15 deletions(-) diff --git a/HYS_process.py b/HYS_process.py index 4f53640..8dec227 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -43,8 +43,8 @@ def process_one_signal(samp_id): label_path = list(label_path)[0] print(f"Processing Label_corrected file: {label_path}") - signal_data = utils.read_signal_txt(signal_path) - signal_length = len(signal_data) + signal_data_raw = utils.read_signal_txt(signal_path) + signal_length = len(signal_data_raw) print(f"signal_length: {signal_length}") signal_fs = int(signal_path.stem.split("_")[-1]) print(f"signal_fs: {signal_fs}") @@ -52,13 +52,13 @@ def process_one_signal(samp_id): print(f"signal_second: {signal_second}") # 根据采样率进行截断 - signal_data = signal_data[:signal_second * signal_fs] + signal_data_raw = signal_data_raw[:signal_second * signal_fs] # 滤波 # 50Hz陷波滤波器 # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) print("Applying 50Hz notch filter...") - signal_data = utils.notch_filter(data=signal_data, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) resp_fs = conf["resp"]["downsample_fs_1"] @@ -130,6 +130,30 @@ def process_one_signal(samp_id): resp_movement_mask = None print("resp_movement_mask is None") + if resp_movement_mask is not None: + # 左右翻转resp_data + reverse_resp_data = resp_data[::-1] + _, resp_movement_mask_reverse, _, resp_movement_position_list_reverse = signal_method.detect_movement( + signal_data=reverse_resp_data, + sampling_rate=resp_fs, + **resp_movement_conf + ) + print(f"resp_movement_mask_reverse_shape: {resp_movement_mask_reverse.shape}, num_movement_reverse: {np.sum(resp_movement_mask_reverse == 1)}, count_movement_positions_reverse: {len(resp_movement_position_list_reverse)}") + # 将resp_movement_mask_reverse翻转回来 + resp_movement_mask_reverse = resp_movement_mask_reverse[::-1] + else: + resp_movement_mask_reverse = None + print("resp_movement_mask_reverse is None") + + + # 取交集 + if resp_movement_mask is not None and resp_movement_mask_reverse is not None: + combined_resp_movement_mask = np.logical_and(resp_movement_mask, resp_movement_mask_reverse).astype(int) + resp_movement_mask = combined_resp_movement_mask + print(f"combined_resp_movement_mask_shape: {combined_resp_movement_mask.shape}, num_combined_movement: {np.sum(combined_resp_movement_mask == 1)}") + else: + print("combined_resp_movement_mask is None") + # 分析Resp的幅值突变区间 if resp_movement_mask is not None: @@ -143,6 +167,7 @@ def process_one_signal(samp_id): print("amp_change_mask is None") + # 分析Bcg的低幅值区间 bcg_low_amp_conf = conf.get("bcg_low_amp", None) if bcg_low_amp_conf is not None: @@ -182,6 +207,7 @@ def process_one_signal(samp_id): # 如果signal_data采样率过,进行降采样 if signal_fs == 1000: signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) signal_fs = 100 draw_tools.draw_signal_with_mask(samp_id=samp_id, @@ -220,4 +246,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[2]) + process_one_signal(select_ids[5]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 1c70389..605c010 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -26,19 +26,19 @@ resp_low_amp: window_size_sec: 30 stride_sec: amplitude_threshold: 3 - merge_gap_sec: 180 - min_duration_sec: 30 + merge_gap_sec: 60 + min_duration_sec: 60 resp_movement: window_size_sec: 20 - stride_sec: 5 - std_median_multiplier: 5 + stride_sec: 1 + std_median_multiplier: 3.5 compare_intervals_sec: - 60 - 90 interval_multiplier: 3.5 merge_gap_sec: 30 - min_duration_sec: 5 + min_duration_sec: 2 bcg: downsample_fs: 100 @@ -52,7 +52,7 @@ bcg_filter: bcg_low_amp: window_size_sec: 1 stride_sec: - amplitude_threshold: 5 + amplitude_threshold: 8 merge_gap_sec: 20 min_duration_sec: 3 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 1253f06..76b6302 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -168,6 +168,24 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No + +def movement_revise(signal_data, sampling_rate, movement_mask, std_median_multiplier=4.5): + """ + 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置修正 + + 参数: + - signal_data: numpy array,输入的信号数据 + - sampling_rate: int,信号的采样率(Hz) + - movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠) + - std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5 + + 返回: + - revised_movement_mask: numpy array,修正后的体动掩码 + """ + pass + + + @timing_decorator() def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): @@ -394,6 +412,15 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat segment_average_amplitude = [] segment_average_energy = [] + signal_data_no_movement = signal_data.copy() + for start, end in zip(movement_start, movement_end): + signal_data_no_movement[start * sampling_rate:end * sampling_rate] = np.nan + + # from matplotlib import pyplot as plt + # plt.plot(signal_data, alpha=0.3, color='gray') + # plt.plot(signal_data_no_movement, color='blue', linewidth=1) + # plt.show() + for start, end in zip(valid_starts, valid_ends): start *= sampling_rate end *= sampling_rate @@ -407,12 +434,12 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat mav_calc_window_sec * sampling_rate) # 计算每个片段的幅值指标 - mav = np.mean( - np.max(signal_data[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.mean( - np.min(signal_data[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) + mav = np.nanmean( + np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.nanmean( + np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) segment_average_amplitude.append(mav) - energy = np.sum(np.abs(signal_data[start:end] ** 2)) + energy = np.nansum(np.abs(signal_data_no_movement[start:end] ** 2)) segment_average_energy.append(energy) position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)