diff --git a/HYS_process.py b/HYS_process.py index edb4065..4f53640 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -28,7 +28,7 @@ import numpy as np import signal_method import os from matplotlib import pyplot as plt -os.environ['DISPLAY'] = "localhost:11.0" +os.environ['DISPLAY'] = "localhost:10.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -112,7 +112,7 @@ def process_one_signal(samp_id): sampling_rate=resp_fs, **resp_low_amp_conf ) - print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") + print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") else: resp_low_amp_mask, resp_low_amp_position_list = None, None print("resp_low_amp_mask is None") @@ -125,7 +125,7 @@ def process_one_signal(samp_id): sampling_rate=resp_fs, **resp_movement_conf ) - print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") + print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: resp_movement_mask = None print("resp_movement_mask is None") @@ -137,7 +137,7 @@ def process_one_signal(samp_id): signal_data=resp_data, movement_mask=resp_movement_mask, sampling_rate=resp_fs) - print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}") + print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: resp_amp_change_mask = None print("amp_change_mask is None") @@ -151,7 +151,7 @@ def process_one_signal(samp_id): sampling_rate=bcg_fs, **bcg_low_amp_conf ) - print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") + print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") else: bcg_low_amp_mask, bcg_low_amp_position_list = None, None print("bcg_low_amp_mask is None") @@ -163,7 +163,7 @@ def process_one_signal(samp_id): sampling_rate=bcg_fs, **bcg_movement_conf ) - print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") + print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") else: bcg_movement_mask = None print("bcg_movement_mask is None") @@ -173,7 +173,7 @@ def process_one_signal(samp_id): signal_data=bcg_data, movement_mask=bcg_movement_mask, sampling_rate=bcg_fs) - print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}") + print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") else: bcg_amp_change_mask = None print("bcg_amp_change_mask is None") @@ -220,4 +220,4 @@ if __name__ == '__main__': all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[0]) + process_one_signal(select_ids[2]) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index bc07225..1c70389 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -25,7 +25,7 @@ resp_filter: resp_low_amp: window_size_sec: 30 stride_sec: - amplitude_threshold: 5 + amplitude_threshold: 3 merge_gap_sec: 180 min_duration_sec: 30 @@ -52,7 +52,7 @@ bcg_filter: bcg_low_amp: window_size_sec: 1 stride_sec: - amplitude_threshold: 10 + amplitude_threshold: 5 merge_gap_sec: 20 min_duration_sec: 3 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index eb8db1d..1253f06 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -1,6 +1,6 @@ from utils.operation_tools import timing_decorator import numpy as np -from utils.operation_tools import merge_short_gaps, remove_short_durations +from utils import merge_short_gaps, remove_short_durations, event_mask_2_list @timing_decorator() @@ -159,14 +159,10 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No movement_mask[start:end+1] = 1 # raw体动起止位置 [[start, end], [start, end], ...] - raw_movement_start = np.where(np.diff(np.concatenate([[0], raw_movement_mask])) == 1)[0] - raw_movement_end = np.where(np.diff(np.concatenate([raw_movement_mask, [0]])) == -1)[0] + 1 - raw_movement_position_list = [[start, end] for start, end in zip(raw_movement_start, raw_movement_end)] + raw_movement_position_list = event_mask_2_list(raw_movement_mask) # merge体动起止位置 [[start, end], [start, end], ...] - movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] - movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + 1 - movement_position_list = [[start, end] for start, end in zip(movement_start, movement_end)] + movement_position_list = event_mask_2_list(movement_mask) return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list @@ -201,7 +197,7 @@ def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, s stride_samples = int(stride_sec * sampling_rate) # 确保步长至少为1 - stride_samples = max(1, stride_samples) + stride_samples = max(sampling_rate, stride_samples) # 处理信号边界,使用反射填充 pad_size = window_samples // 2 @@ -255,9 +251,7 @@ def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, s low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] # 低幅值状态起止位置 [[start, end], [start, end], ...] - low_amplitude_start = np.where(np.diff(np.concatenate([[0], low_amplitude_mask])) == 1)[0] - low_amplitude_end = np.where(np.diff(np.concatenate([low_amplitude_mask, [0]])) == -1)[0] - low_amplitude_position_list = [[start, end] for start, end in zip(low_amplitude_start, low_amplitude_end)] + low_amplitude_position_list = event_mask_2_list(low_amplitude_mask) return low_amplitude_mask, low_amplitude_position_list diff --git a/utils/__init__.py b/utils/__init__.py index ae2ee06..51a18ba 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,5 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask +from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list +from utils.operation_tools import merge_short_gaps, remove_short_durations from utils.event_map import E2N from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 75b6e75..4feabdf 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -206,3 +206,9 @@ def generate_event_mask(signal_second: int, event_df): score_mask[start:end] = row["score"] return event_mask, score_mask + +def event_mask_2_list(mask): + mask_start = np.where(np.diff(mask, append=0) == 1)[0] + mask_end = np.where(np.diff(mask, append=0) == -1)[0] + 1 + event_list =[[start, end] for start, end in zip(mask_start, mask_end)] + return event_list \ No newline at end of file diff --git a/utils/signal_process.py b/utils/signal_process.py index c657d3f..e690c33 100644 --- a/utils/signal_process.py +++ b/utils/signal_process.py @@ -4,8 +4,7 @@ from scipy import signal, ndimage @timing_decorator() -def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): - +def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000): if _type == "lowpass": # 低通滤波处理 sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') return signal.sosfiltfilt(sos, np.array(data)) @@ -90,6 +89,7 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1 return downsampled_signal + @timing_decorator() def average_filter(raw_data, sample_rate, window_size_sec=20): kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)