From 40aad46d6f042d03226a3032a462888d859a0ed7 Mon Sep 17 00:00:00 2001 From: marques Date: Thu, 23 Oct 2025 15:43:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E5=A4=9A=E4=B8=AA?= =?UTF-8?q?=E6=96=87=E4=BB=B6=EF=BC=8C=E5=AE=8C=E6=88=90=E5=9F=BA=E6=9C=AC?= =?UTF-8?q?=E8=AF=BB=E5=8F=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HYS_process.py | 40 +++++++- dataset_config/HYS_config.yaml | 24 ++++- signal_method/__init__.py | 1 + signal_method/rule_base_event.py | 169 +++++++++++++++++++++++++++++++ utils/HYS_FileReader.py | 5 +- utils/__init__.py | 5 +- utils/operation_tools.py | 122 +--------------------- utils/signal_process.py | 92 +++++++++++++++++ 8 files changed, 331 insertions(+), 127 deletions(-) create mode 100644 utils/signal_process.py diff --git a/HYS_process.py b/HYS_process.py index c0f4ddb..9422d9c 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -24,7 +24,7 @@ from pathlib import Path from typing import Union import utils import numpy as np - +import signal_method @@ -50,13 +50,40 @@ def process_one_signal(samp_id): signal_second = signal_length // signal_fs print(f"signal_second: {signal_second}") + # 滤波 + # 50Hz陷波滤波器 + # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) + resp_data = utils.butterworth(data=signal_data, _type=conf["resp"]["filter_type"], low_cut=conf["resp"]["low_cut"], + high_cut=conf["resp"]["high_cut"], order=conf["resp"]["order"], sample_rate=signal_fs) - label_data = utils.read_label_csv(label_path) - label_mask = utils.generate_event_mask(signal_second, label_data) + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg"]["filter_type"], low_cut=conf["bcg"]["low_cut"], + high_cut=conf["bcg"]["high_cut"], order=conf["bcg"]["order"], sample_rate=signal_fs) - manual_disable_mask = utils.generate_disable_mask(signal_second, all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) + + label_data = utils.read_label_csv(path=label_path) + label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + + manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") + # 分析Resp的低幅值区间 + resp_low_amp_conf = getattr(conf, "resp_low_amp", None) + if resp_low_amp_conf is not None: + resp_low_amp_mask = signal_method.detect_low_amplitude_signal( + signal_data=resp_data, + sampling_rate=signal_fs, + window_size_sec=resp_low_amp_conf["window_size_sec"], + stride_sec=resp_low_amp_conf["stride_sec"], + amplitude_threshold=resp_low_amp_conf["amplitude_threshold"], + merge_gap_sec=resp_low_amp_conf["merge_gap_sec"], + min_duration_sec=resp_low_amp_conf["min_duration_sec"] + ) + else: + resp_low_amp_mask = None + + # 分析Resp的高幅值伪迹区间 + resp_move + @@ -69,7 +96,10 @@ if __name__ == '__main__': yaml_path = Path("./dataset_config/HYS_config.yaml") disable_df_path = Path("./排除区间.xlsx") - select_ids, root_path = utils.load_dataset_info(yaml_path) + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index d30264f..dfff364 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -1,4 +1,4 @@ -select_id: +select_ids: - 1302 - 286 - 950 @@ -10,4 +10,24 @@ select_id: - 684 - 960 -root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS \ No newline at end of file +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS + +resp_filter: + filter_type: bandpass + low_cut: 0.01 + high_cut: 0.7 + order: 10 + +resp_low_amp: + windows_size_sec: 1 + stride_sec: None + amplitude_threshold: 50 + merge_gap_sec: 10 + min_duration_sec: 5 + +bcg_filter: + filter_type: bandpass + low_cut: 1 + high_cut: 10 + order: 10 + diff --git a/signal_method/__init__.py b/signal_method/__init__.py index e69de29..46eac36 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -0,0 +1 @@ +from signal_method.rule_base_event import detect_low_amplitude_signal \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 95a540f..8de49da 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -3,6 +3,175 @@ import numpy as np from utils.operation_tools import merge_short_gaps, remove_short_durations +@timing_decorator() +def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=None, + std_median_multiplier=4.5, compare_intervals_sec=[30, 60], + interval_multiplier=2.5, + merge_gap_sec=10, min_duration_sec=5, + low_amplitude_periods=None): + """ + 检测信号中的体动状态,结合两种方法:标准差比较和前后窗口幅值对比。 + 使用反射填充处理信号边界。 + + 参数: + - signal_data: numpy array,输入的信号数据 + - sampling_rate: int,信号的采样率(Hz) + - window_size_sec: float,分析窗口的时长(秒),默认值为 2 秒 + - stride_sec: float,窗口滑动步长(秒),默认值为None(等于window_size_sec,无重叠) + - std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5 + - compare_intervals_sec: list,用于比较的时间间隔列表(秒),默认为 [30, 60] + - interval_multiplier: float,间隔中位数的上限乘数,默认值为 2.5 + - merge_gap_sec: float,要合并的体动状态之间的最大间隔(秒),默认值为 10 秒 + - min_duration_sec: float,要保留的体动状态的最小持续时间(秒),默认值为 5 秒 + - low_amplitude_periods: numpy array,低幅值期间的掩码(1表示低幅值期间),默认为None + + 返回: + - movement_mask: numpy array,体动状态的掩码(1表示体动,0表示睡眠) + """ + # 计算窗口大小(样本数) + window_samples = int(window_size_sec * sampling_rate) + + # 如果未指定步长,设置为窗口大小(无重叠) + if stride_sec is None: + stride_sec = window_size_sec + + # 计算步长(样本数) + stride_samples = int(stride_sec * sampling_rate) + + # 确保步长至少为1 + stride_samples = max(1, stride_samples) + + # 计算需要的最大填充大小(基于比较间隔) + max_interval_samples = int(max(compare_intervals_sec) * sampling_rate) + + # 应用反射填充以正确处理边界 + # 填充大小为最大比较间隔的一半,以确保边界有足够的上下文 + pad_size = max_interval_samples + padded_signal = np.pad(signal_data, pad_size, mode='reflect') + + # 计算填充后的窗口数量 + num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1) + + # 初始化窗口标准差数组 + window_std = np.zeros(num_windows) + # 计算每个窗口的标准差 + # 分窗计算标准差 + for i in range(num_windows): + start_idx = i * stride_samples + end_idx = min(start_idx + window_samples, len(padded_signal)) + + # 处理窗口,包括可能不完整的最后一个窗口 + window_data = padded_signal[start_idx:end_idx] + if len(window_data) > 0: + window_std[i] = np.std(window_data, ddof=1) + else: + window_std[i] = 0 + + # 计算原始信号对应的窗口索引范围 + # 填充后,原始信号从pad_size开始 + orig_start_window = pad_size // stride_samples + if stride_sec == 1: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + else: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1 + + # 只保留原始信号对应的窗口标准差 + original_window_std = window_std[orig_start_window:orig_end_window] + num_original_windows = len(original_window_std) + + # 创建时间点数组(秒) + time_points = np.arange(num_original_windows) * stride_sec + + # # 如果提供了低幅值期间的掩码,则在计算全局中位数时排除这些期间 + # if low_amplitude_periods is not None and len(low_amplitude_periods) == num_original_windows: + # valid_std = original_window_std[low_amplitude_periods == 0] + # if len(valid_std) == 0: # 如果所有窗口都在低幅值期间 + # valid_std = original_window_std # 使用全部窗口 + # else: + # valid_std = original_window_std + + valid_std = original_window_std ##20250418新修改 + + #---------------------- 方法一:基于STD的体动判定 ----------------------# + # 计算所有有效窗口标准差的中位数 + median_std = np.median(valid_std) + + # 当窗口标准差大于中位数的倍数,判定为体动状态 + std_movement = np.where(original_window_std > median_std * std_median_multiplier, 1, 0) + + #------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------# + amplitude_movement = np.zeros(num_original_windows, dtype=int) + + # 定义基于时间粒度的比较间隔索引 + compare_intervals_idx = [int(interval // stride_sec) for interval in compare_intervals_sec] + + # 逐窗口判断 + for win_idx in range(num_original_windows): + # 全局索引(在填充后的窗口数组中) + global_win_idx = win_idx + orig_start_window + + # 对每个比较间隔进行检查 + for interval_idx in compare_intervals_idx: + # 确定比较范围的结束索引(在填充后的窗口数组中) + end_idx = min(global_win_idx + interval_idx, len(window_std)) + + # 提取相应时间范围内的标准差值 + if global_win_idx < end_idx: + interval_std = window_std[global_win_idx:end_idx] + + # 计算该间隔的中位数 + interval_median = np.median(interval_std) + + # 计算上下阈值 + upper_threshold = interval_median * interval_multiplier + + # 检查当前窗口是否超出阈值范围,如果超出则直接标记为体动 + if window_std[global_win_idx] > upper_threshold: + amplitude_movement[win_idx] = 1 + break # 一旦确定为体动就不需要继续检查其他间隔 + + # 将两种方法的结果合并:只要其中一种方法判定为体动,最终结果就是体动 + movement_mask = np.logical_or(std_movement, amplitude_movement).astype(int) + raw_movement_mask = movement_mask + + # 如果需要合并间隔小的体动状态 + if merge_gap_sec > 0: + movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec) + + # 如果需要移除短时体动状态 + if min_duration_sec > 0: + movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec) + + # raw_movement_mask, movement_mask恢复对应秒数,而不是点数 + raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + + + # 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除 + removed_movement_mask = (raw_movement_mask - movement_mask) > 0 + removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0] + removed_movement_end = np.where(np.diff(np.concatenate([removed_movement_mask, [0]])) == -1)[0] + + for start, end in zip(removed_movement_start, removed_movement_end): + # print(start ,end) + # 计算剔除的体动区域的幅值 + if np.nanmax(signal_data[start*sampling_rate:(end+1)*sampling_rate]) > median_std * std_median_multiplier: + movement_mask[start:end+1] = 1 + + # raw体动起止位置 [[start, end], [start, end], ...] + raw_movement_start = np.where(np.diff(np.concatenate([[0], raw_movement_mask])) == 1)[0] + raw_movement_end = np.where(np.diff(np.concatenate([raw_movement_mask, [0]])) == -1)[0] + 1 + raw_movement_position_list = [[start, end] for start, end in zip(raw_movement_start, raw_movement_end)] + + # merge体动起止位置 [[start, end], [start, end], ...] + movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] + movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + 1 + movement_position_list = [[start, end] for start, end in zip(movement_start, movement_end)] + + return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list + + + @timing_decorator() def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py index 5812da8..dd65bab 100644 --- a/utils/HYS_FileReader.py +++ b/utils/HYS_FileReader.py @@ -29,7 +29,7 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray: if HAS_POLARS: df = pl.read_csv(path, has_header=False, infer_schema_length=0) - return df[:, 0].to_numpy() + return df[:, 0].to_numpy().astype(float) else: df = pd.read_csv(path, header=None, dtype=float) return df.iloc[:, 0].to_numpy() @@ -41,8 +41,11 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: Args: path (str | Path): Path to the CSV file. + verbose (bool): Returns: pd.DataFrame: The content of the CSV file as a pandas DataFrame. + :param path: + :param verbose: """ path = Path(path) if not path.exists(): diff --git a/utils/__init__.py b/utils/__init__.py index 87cc5b9..faaebaf 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,4 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel -from utils.operation_tools import load_dataset_info, generate_disable_mask, generate_event_mask -from utils.event_map import E2N \ No newline at end of file +from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask +from utils.event_map import E2N +from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index c83c82e..f775d73 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -47,85 +47,6 @@ def read_auto(file_path): else: raise ValueError('这个文件类型不支持,需要自己写读取程序') -@timing_decorator() -def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): - - if type == "lowpass": # 低通滤波处理 - sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') - return signal.sosfiltfilt(sos, np.array(data)) - elif type == "bandpass": # 带通滤波处理 - low = low_cut / (sample_rate * 0.5) - high = high_cut / (sample_rate * 0.5) - sos = signal.butter(order, [low, high], btype='bandpass', output='sos') - return signal.sosfiltfilt(sos, np.array(data)) - elif type == "highpass": # 高通滤波处理 - sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos') - return signal.sosfiltfilt(sos, np.array(data)) - else: # 警告,滤波器类型必须有 - raise ValueError("Please choose a type of fliter") - -@timing_decorator() -def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): - """ - 高效整数倍降采样长信号(适合8小时以上),分段处理以优化内存和速度。 - - 参数: - original_signal : array-like, 原始信号数组 - original_fs : float, 原始采样率 (Hz) - target_fs : float, 目标采样率 (Hz) - chunk_size : int, 每段处理的样本数,默认100000 - - 返回: - downsampled_signal : array-like, 降采样后的信号 - """ - # 输入验证 - if not isinstance(original_signal, np.ndarray): - original_signal = np.array(original_signal) - if original_fs <= target_fs: - raise ValueError("目标采样率必须小于原始采样率") - if target_fs <= 0 or original_fs <= 0: - raise ValueError("采样率必须为正数") - - # 计算降采样因子(必须为整数) - downsample_factor = original_fs / target_fs - if not downsample_factor.is_integer(): - raise ValueError("降采样因子必须为整数倍") - downsample_factor = int(downsample_factor) - - # 计算总输出长度 - total_length = len(original_signal) - output_length = total_length // downsample_factor - - # 初始化输出数组 - downsampled_signal = np.zeros(output_length) - - # 分段处理 - for start in range(0, total_length, chunk_size): - end = min(start + chunk_size, total_length) - chunk = original_signal[start:end] - - # 使用decimate进行整数倍降采样 - chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True) - - # 计算输出位置 - out_start = start // downsample_factor - out_end = out_start + len(chunk_downsampled) - if out_end > output_length: - chunk_downsampled = chunk_downsampled[:output_length - out_start] - - downsampled_signal[out_start:out_end] = chunk_downsampled - - return downsampled_signal - -@timing_decorator() -def average_filter(raw_data, sample_rate, window_size=20): - kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) - filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') - convolve_filter_signal = raw_data - filtered - return convolve_filter_signal - - - def merge_short_gaps(state_sequence, time_points, max_gap_sec): """ @@ -252,14 +173,14 @@ def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, return values_nan, values -def load_dataset_info(yaml_path): +def load_dataset_conf(yaml_path): with open(yaml_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) - select_ids = config.get('select_id', []) - root_path = config.get('root_path', None) - data_path = Path(root_path) - return select_ids, data_path + # select_ids = config.get('select_id', []) + # root_path = config.get('root_path', None) + # data_path = Path(root_path) + return config def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: @@ -285,36 +206,3 @@ def generate_event_mask(signal_second: int, event_df): score_mask[start:end] = row["score"] return event_mask, score_mask - -def slide_window_segment(signal_second: int, window_second, step_second, event_mask, score_mask, disable_mask, ): - # 避开不可用区域进行滑窗分割 - # 滑动到不可用区域时,如果窗口内一侧的不可用区域不超过1/2 windows_second,则继续滑动, 用reflect填充 - # 如果不可用区间大于1/2的window_second,则跳过该不可用区间,继续滑动 - # TODO 对于短时强体动区间 考虑填充或者掩码覆盖 - # - half_window_second = window_second // 2 - for start_second in range(0, signal_second - window_second + 1, step_second): - end_second = start_second + window_second - - # 检查当前窗口是否包含不可用区域 - windows_middle_second = (start_second + end_second) // 2 - if np.sum(disable_mask[start_second:end_second] > 1) > half_window_second: - # 如果窗口内不可用区域超过一半,跳过该窗口 - continue - - if disable_mask[start_second:end_second] > half_window_second: - - - - - - # 确保新的起始位置不超过信号长度 - if start_second + window_second > signal_second: - break - - window_event = event_mask[start_second:end_second] - window_score = score_mask[start_second:end_second] - window_disable = disable_mask[start_second:end_second] - - yield start_second, end_second, window_event, window_score, window_disable - diff --git a/utils/signal_process.py b/utils/signal_process.py new file mode 100644 index 0000000..dea0ea1 --- /dev/null +++ b/utils/signal_process.py @@ -0,0 +1,92 @@ +from utils.operation_tools import timing_decorator +import numpy as np +from scipy import signal, ndimage + + +@timing_decorator() +def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): + + if _type == "lowpass": # 低通滤波处理 + sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + elif _type == "bandpass": # 带通滤波处理 + low = low_cut / (sample_rate * 0.5) + high = high_cut / (sample_rate * 0.5) + sos = signal.butter(order, [low, high], btype='bandpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + elif _type == "highpass": # 高通滤波处理 + sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + else: # 警告,滤波器类型必须有 + raise ValueError("Please choose a type of fliter") + + +@timing_decorator() +def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): + """ + 高效整数倍降采样长信号,分段处理以优化内存和速度。 + + 参数: + original_signal : array-like, 原始信号数组 + original_fs : float, 原始采样率 (Hz) + target_fs : float, 目标采样率 (Hz) + chunk_size : int, 每段处理的样本数,默认100000 + + 返回: + downsampled_signal : array-like, 降采样后的信号 + """ + # 输入验证 + if not isinstance(original_signal, np.ndarray): + original_signal = np.array(original_signal) + if original_fs <= target_fs: + raise ValueError("目标采样率必须小于原始采样率") + if target_fs <= 0 or original_fs <= 0: + raise ValueError("采样率必须为正数") + + # 计算降采样因子(必须为整数) + downsample_factor = original_fs / target_fs + if not downsample_factor.is_integer(): + raise ValueError("降采样因子必须为整数倍") + downsample_factor = int(downsample_factor) + + # 计算总输出长度 + total_length = len(original_signal) + output_length = total_length // downsample_factor + + # 初始化输出数组 + downsampled_signal = np.zeros(output_length) + + # 分段处理 + for start in range(0, total_length, chunk_size): + end = min(start + chunk_size, total_length) + chunk = original_signal[start:end] + + # 使用decimate进行整数倍降采样 + chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True) + + # 计算输出位置 + out_start = start // downsample_factor + out_end = out_start + len(chunk_downsampled) + if out_end > output_length: + chunk_downsampled = chunk_downsampled[:output_length - out_start] + + downsampled_signal[out_start:out_end] = chunk_downsampled + + return downsampled_signal + +@timing_decorator() +def average_filter(raw_data, sample_rate, window_size=20): + kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) + filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') + convolve_filter_signal = raw_data - filtered + return convolve_filter_signal + + +# 陷波滤波器 +@timing_decorator() +def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000): + nyquist = 0.5 * sample_rate + norm_notch_freq = notch_freq / nyquist + b, a = signal.iirnotch(norm_notch_freq, quality_factor) + filtered_data = signal.filtfilt(b, a, data) + return filtered_data