diff --git a/HYS_process.py b/HYS_process.py index 85fb7a3..edb4065 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:10.0" +os.environ['DISPLAY'] = "localhost:11.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -98,7 +98,7 @@ def process_one_signal(samp_id): bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) label_data = utils.read_label_csv(path=label_path) - label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ all_samp_disable_df["id"] == samp_id]) @@ -110,11 +110,7 @@ def process_one_signal(samp_id): resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( signal_data=resp_data, sampling_rate=resp_fs, - window_size_sec=resp_low_amp_conf["window_size_sec"], - stride_sec=resp_low_amp_conf["stride_sec"], - amplitude_threshold=resp_low_amp_conf["amplitude_threshold"], - merge_gap_sec=resp_low_amp_conf["merge_gap_sec"], - min_duration_sec=resp_low_amp_conf["min_duration_sec"] + **resp_low_amp_conf ) print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") else: @@ -127,13 +123,7 @@ def process_one_signal(samp_id): raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( signal_data=resp_data, sampling_rate=resp_fs, - window_size_sec=resp_movement_conf["window_size_sec"], - stride_sec=resp_movement_conf["stride_sec"], - std_median_multiplier=resp_movement_conf["std_median_multiplier"], - compare_intervals_sec=resp_movement_conf["compare_intervals_sec"], - interval_multiplier=resp_movement_conf["interval_multiplier"], - merge_gap_sec=resp_movement_conf["merge_gap_sec"], - min_duration_sec=resp_movement_conf["min_duration_sec"] + **resp_movement_conf ) print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") else: @@ -159,11 +149,7 @@ def process_one_signal(samp_id): bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( signal_data=bcg_data, sampling_rate=bcg_fs, - window_size_sec=bcg_low_amp_conf["window_size_sec"], - stride_sec=bcg_low_amp_conf["stride_sec"], - amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"], - merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"], - min_duration_sec=bcg_low_amp_conf["min_duration_sec"] + **bcg_low_amp_conf ) print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") else: @@ -175,13 +161,7 @@ def process_one_signal(samp_id): raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( signal_data=bcg_data, sampling_rate=bcg_fs, - window_size_sec=bcg_movement_conf["window_size_sec"], - stride_sec=bcg_movement_conf["stride_sec"], - std_median_multiplier=bcg_movement_conf["std_median_multiplier"], - compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"], - interval_multiplier=bcg_movement_conf["interval_multiplier"], - merge_gap_sec=bcg_movement_conf["merge_gap_sec"], - min_duration_sec=bcg_movement_conf["min_duration_sec"] + **bcg_movement_conf ) print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") else: @@ -215,7 +195,7 @@ def process_one_signal(samp_id): resp_low_amp_mask=resp_low_amp_mask, resp_movement_mask=resp_movement_mask, resp_change_mask=resp_amp_change_mask, - resp_sa_mask=None, + resp_sa_mask=event_mask, bcg_low_amp_mask=bcg_low_amp_mask, bcg_movement_mask=bcg_movement_mask, bcg_change_mask=bcg_amp_change_mask) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index c30c3c5..2d50926 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -20,25 +20,25 @@ resp_filter: filter_type: bandpass low_cut: 0.01 high_cut: 0.7 - order: 2 + order: 3 resp_low_amp: - window_size_sec: 1 + window_size_sec: 30 stride_sec: - amplitude_threshold: 20 - merge_gap_sec: 10 - min_duration_sec: 5 + amplitude_threshold: 5 + merge_gap_sec: 180 + min_duration_sec: 30 resp_movement: - window_size_sec: 2 - stride_sec: - std_median_multiplier: 4.5 + window_size_sec: 30 + stride_sec: 5 + std_median_multiplier: 5 compare_intervals_sec: - - 30 - 60 - interval_multiplier: 2.5 - merge_gap_sec: 10 - min_duration_sec: 5 + - 90 + interval_multiplier: 3.5 + merge_gap_sec: 45 + min_duration_sec: 10 bcg: downsample_fs: 100 @@ -49,3 +49,17 @@ bcg_filter: high_cut: 10 order: 10 +bcg_low_amp: + window_size_sec: 1 + stride_sec: + amplitude_threshold: 10 + merge_gap_sec: 20 + min_duration_sec: 3 + + +bcg_movement: + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index b4e401a..ad74a3e 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -188,8 +188,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, if mask is None: return np.full_like(ref, np.nan) else: - # 将mask中的0替换为nan,1替换为1 - mask = np.where(mask == 0, np.nan, 1) + # 将mask中的0替换为nan,其他的保持 + mask = np.where(mask == 0, np.nan, mask) return mask signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data) @@ -224,7 +224,7 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') ax1.set_ylabel('Amplitude') - ax1.set_xticklabels([]) + # ax1.set_xticklabels([]) ax1_twin = ax1.twinx() ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1, color='blue', alpha=0.5, label='Low Amplitude Mask') diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index a8e2480..eb8db1d 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -421,7 +421,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat energy = np.sum(np.abs(signal_data[start:end] ** 2)) segment_average_energy.append(energy) - position_changes = [] + position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) position_change_times = [] # 判断是否存在显著变化 (可根据实际情况调整阈值) threshold_amplitude = 0.1 # 幅值变化阈值 @@ -440,9 +440,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat if significant_change: # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 - position_changes.append(1) + position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1 position_change_times.append((movement_start[i - 1], movement_end[i - 1])) - else: - position_changes.append(0) # 0表示不存在姿势变化 - return np.array(position_changes), position_change_times + return position_changes, position_change_times diff --git a/utils/operation_tools.py b/utils/operation_tools.py index f775d73..75b6e75 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -184,12 +184,12 @@ def load_dataset_conf(yaml_path): def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: - disable_mask = np.ones(signal_second, dtype=int) + disable_mask = np.zeros(signal_second, dtype=int) for _, row in disable_df.iterrows(): start = row["start"] end = row["end"] - disable_mask[start:end] = 0 + disable_mask[start:end] = 1 return disable_mask