Initialize multiple files and implement basic reading functionality
This commit is contained in:
parent f79f42fae7
commit 40aad46d6f
@@ -24,7 +24,7 @@ from pathlib import Path
from typing import Union
import utils
import numpy as np

import signal_method

@@ -50,13 +50,40 @@ def process_one_signal(samp_id):
    signal_second = signal_length // signal_fs
    print(f"signal_second: {signal_second}")

    # Filtering
    # 50 Hz notch filter
    # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
    resp_data = utils.butterworth(data=signal_data, _type=conf["resp"]["filter_type"], low_cut=conf["resp"]["low_cut"],
                                  high_cut=conf["resp"]["high_cut"], order=conf["resp"]["order"], sample_rate=signal_fs)

    label_data = utils.read_label_csv(label_path)
    label_mask = utils.generate_event_mask(signal_second, label_data)
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg"]["filter_type"], low_cut=conf["bcg"]["low_cut"],
                                 high_cut=conf["bcg"]["high_cut"], order=conf["bcg"]["order"], sample_rate=signal_fs)

    manual_disable_mask = utils.generate_disable_mask(signal_second, all_samp_disable_df[all_samp_disable_df["id"] == samp_id])

    label_data = utils.read_label_csv(path=label_path)
    label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)

    manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # Analyze low-amplitude intervals of the Resp signal
    resp_low_amp_conf = conf.get("resp_low_amp")  # conf is a dict from yaml.safe_load, so use .get() rather than getattr()
    if resp_low_amp_conf is not None:
        resp_low_amp_mask = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=signal_fs,
            window_size_sec=resp_low_amp_conf["window_size_sec"],
            stride_sec=resp_low_amp_conf["stride_sec"],
            amplitude_threshold=resp_low_amp_conf["amplitude_threshold"],
            merge_gap_sec=resp_low_amp_conf["merge_gap_sec"],
            min_duration_sec=resp_low_amp_conf["min_duration_sec"]
        )
    else:
        resp_low_amp_mask = None

    # Analyze high-amplitude (movement artifact) intervals of the Resp signal
    resp_move

@@ -69,7 +96,10 @@ if __name__ == '__main__':
    yaml_path = Path("./dataset_config/HYS_config.yaml")
    disable_df_path = Path("./排除区间.xlsx")

    select_ids, root_path = utils.load_dataset_info(yaml_path)
    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])

    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")

@@ -1,4 +1,4 @@
select_id:
select_ids:
- 1302
- 286
- 950
@@ -10,4 +10,24 @@ select_id:
- 684
- 960

root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS

resp_filter:
  filter_type: bandpass
  low_cut: 0.01
  high_cut: 0.7
  order: 10

resp_low_amp:
  window_size_sec: 1
  stride_sec: null
  amplitude_threshold: 50
  merge_gap_sec: 10
  min_duration_sec: 5

bcg_filter:
  filter_type: bandpass
  low_cut: 1
  high_cut: 10
  order: 10

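For reference, a minimal editorial sketch (not part of the commit) of what this config yields once parsed, mirroring the load_dataset_conf change later in this diff; yaml.safe_load returns a plain dict:

from pathlib import Path

import yaml

with open(Path("./dataset_config/HYS_config.yaml"), "r", encoding="utf-8") as f:
    conf = yaml.safe_load(f)

print(conf["select_ids"])                  # [1302, 286, 950, ...]
print(conf["root_path"])                   # /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
print(conf["resp_filter"])                 # {'filter_type': 'bandpass', 'low_cut': 0.01, 'high_cut': 0.7, 'order': 10}
print(conf["resp_low_amp"]["stride_sec"])  # None (YAML null)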
@@ -0,0 +1 @@
from signal_method.rule_base_event import detect_low_amplitude_signal
@@ -3,6 +3,175 @@ import numpy as np
from utils.operation_tools import merge_short_gaps, remove_short_durations


@timing_decorator()
def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=None,
                    std_median_multiplier=4.5, compare_intervals_sec=[30, 60],
                    interval_multiplier=2.5,
                    merge_gap_sec=10, min_duration_sec=5,
                    low_amplitude_periods=None):
    """
    Detect body-movement states in the signal by combining two methods: a standard-deviation
    comparison and a before/after window amplitude comparison.
    Signal boundaries are handled with reflection padding.

    Parameters:
    - signal_data: numpy array, the input signal
    - sampling_rate: int, sampling rate of the signal (Hz)
    - window_size_sec: float, analysis window length in seconds, default 2
    - stride_sec: float, window stride in seconds, default None (equals window_size_sec, i.e. no overlap)
    - std_median_multiplier: float, multiplier applied to the median window std, default 4.5
    - compare_intervals_sec: list, time intervals (seconds) used for comparison, default [30, 60]
    - interval_multiplier: float, upper-bound multiplier on the interval median, default 2.5
    - merge_gap_sec: float, maximum gap (seconds) between movement states to merge, default 10
    - min_duration_sec: float, minimum duration (seconds) of a movement state to keep, default 5
    - low_amplitude_periods: numpy array, mask of low-amplitude periods (1 = low amplitude), default None

    Returns:
    - movement_mask: numpy array, mask of movement states (1 = movement, 0 = sleep)
    """
    # Window size in samples
    window_samples = int(window_size_sec * sampling_rate)

    # If no stride is given, use the window size (no overlap)
    if stride_sec is None:
        stride_sec = window_size_sec

    # Stride in samples
    stride_samples = int(stride_sec * sampling_rate)

    # Ensure the stride is at least 1
    stride_samples = max(1, stride_samples)

    # Maximum padding needed (based on the comparison intervals)
    max_interval_samples = int(max(compare_intervals_sec) * sampling_rate)

    # Apply reflection padding so the boundaries are handled correctly;
    # the pad size equals the largest comparison interval so the edges have enough context
    pad_size = max_interval_samples
    padded_signal = np.pad(signal_data, pad_size, mode='reflect')

    # Number of windows over the padded signal
    num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)

    # Per-window standard deviation
    window_std = np.zeros(num_windows)
    # Compute the standard deviation window by window
    for i in range(num_windows):
        start_idx = i * stride_samples
        end_idx = min(start_idx + window_samples, len(padded_signal))

        # Handle every window, including a possibly incomplete last window
        window_data = padded_signal[start_idx:end_idx]
        if len(window_data) > 0:
            window_std[i] = np.std(window_data, ddof=1)
        else:
            window_std[i] = 0

    # Window index range corresponding to the original signal
    # (after padding, the original signal starts at pad_size)
    orig_start_window = pad_size // stride_samples
    if stride_sec == 1:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
    else:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1

    # Keep only the window stds that correspond to the original signal
    original_window_std = window_std[orig_start_window:orig_end_window]
    num_original_windows = len(original_window_std)

    # Time points in seconds
    time_points = np.arange(num_original_windows) * stride_sec

    # # If a low-amplitude mask is provided, exclude those periods when computing the global median
    # if low_amplitude_periods is not None and len(low_amplitude_periods) == num_original_windows:
    #     valid_std = original_window_std[low_amplitude_periods == 0]
    #     if len(valid_std) == 0:  # all windows fall inside low-amplitude periods
    #         valid_std = original_window_std  # fall back to all windows
    # else:
    #     valid_std = original_window_std

    valid_std = original_window_std  # changed on 2025-04-18

    # ---------------------- Method 1: STD-based movement detection ---------------------- #
    # Median of all valid window standard deviations
    median_std = np.median(valid_std)

    # A window whose std exceeds the median times the multiplier is flagged as movement
    std_movement = np.where(original_window_std > median_std * std_median_multiplier, 1, 0)

    # ------------------ Method 2: movement detection based on amplitude change in surrounding windows ------------------ #
    amplitude_movement = np.zeros(num_original_windows, dtype=int)

    # Comparison-interval indices expressed in window strides
    compare_intervals_idx = [int(interval // stride_sec) for interval in compare_intervals_sec]

    # Check window by window
    for win_idx in range(num_original_windows):
        # Global index (in the padded window array)
        global_win_idx = win_idx + orig_start_window

        # Check each comparison interval
        for interval_idx in compare_intervals_idx:
            # End index of the comparison range (in the padded window array)
            end_idx = min(global_win_idx + interval_idx, len(window_std))

            # Extract the std values within that time range
            if global_win_idx < end_idx:
                interval_std = window_std[global_win_idx:end_idx]

                # Median of this interval
                interval_median = np.median(interval_std)

                # Upper threshold
                upper_threshold = interval_median * interval_multiplier

                # If the current window exceeds the threshold, mark it as movement
                if window_std[global_win_idx] > upper_threshold:
                    amplitude_movement[win_idx] = 1
                    break  # once flagged as movement, the remaining intervals need not be checked

    # Combine the two methods: a window is movement if either method flags it
    movement_mask = np.logical_or(std_movement, amplitude_movement).astype(int)
    raw_movement_mask = movement_mask

    # Merge movement states separated by short gaps, if requested
    if merge_gap_sec > 0:
        movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec)

    # Remove short movement states, if requested
    if min_duration_sec > 0:
        movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec)

    # Convert raw_movement_mask and movement_mask back to per-second values instead of sample points
    raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
    movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]


    # Compare the removed movements: if the region of a removed movement contains amplitude above 3 std, do not remove it
    removed_movement_mask = (raw_movement_mask - movement_mask) > 0
    removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0]
    removed_movement_end = np.where(np.diff(np.concatenate([removed_movement_mask, [0]])) == -1)[0]

    for start, end in zip(removed_movement_start, removed_movement_end):
        # print(start, end)
        # Amplitude of the removed movement region
        if np.nanmax(signal_data[start*sampling_rate:(end+1)*sampling_rate]) > median_std * std_median_multiplier:
            movement_mask[start:end+1] = 1

    # Raw movement start/end positions: [[start, end], [start, end], ...]
    raw_movement_start = np.where(np.diff(np.concatenate([[0], raw_movement_mask])) == 1)[0]
    raw_movement_end = np.where(np.diff(np.concatenate([raw_movement_mask, [0]])) == -1)[0] + 1
    raw_movement_position_list = [[start, end] for start, end in zip(raw_movement_start, raw_movement_end)]

    # Merged movement start/end positions: [[start, end], [start, end], ...]
    movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0]
    movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + 1
    movement_position_list = [[start, end] for start, end in zip(movement_start, movement_end)]

    return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list


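As a quick editorial illustration of the interface documented in the docstring above (not part of the commit; the synthetic signal and sampling rate are made up, and detect_movement is assumed importable from signal_method.rule_base_event, as the package __init__ hunk suggests for its sibling function):

import numpy as np

from signal_method.rule_base_event import detect_movement

fs = 100
sig = np.random.randn(fs * 600) * 10                          # 10 minutes of synthetic quiet signal
sig[fs * 120:fs * 130] += np.random.randn(fs * 10) * 200      # a 10 s burst standing in for body movement

raw_mask, mask, raw_positions, positions = detect_movement(
    signal_data=sig,
    sampling_rate=fs,
    window_size_sec=2,
    std_median_multiplier=4.5,
    compare_intervals_sec=[30, 60],
    interval_multiplier=2.5,
    merge_gap_sec=10,
    min_duration_sec=5,
)
print(mask.shape)    # roughly one value per second of input (see the repeat(stride_sec) step above)
print(positions)     # [[start, end], ...] for each merged movement interval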
@timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
                                amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):

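The body of detect_low_amplitude_signal is cut off by this hunk. Purely as a sketch of the windowed-amplitude idea its parameters suggest (an editorial illustration, not the committed implementation; the function name below is hypothetical):

import numpy as np

def low_amplitude_mask_sketch(signal_data, sampling_rate, window_size_sec=1, amplitude_threshold=50):
    """Illustrative only: flag windows whose peak-to-peak amplitude stays below a threshold."""
    window = int(window_size_sec * sampling_rate)
    n_windows = len(signal_data) // window
    mask = np.zeros(n_windows, dtype=int)
    for i in range(n_windows):
        seg = signal_data[i * window:(i + 1) * window]
        if np.ptp(seg) < amplitude_threshold:    # peak-to-peak below threshold -> low amplitude
            mask[i] = 1
    # The real function presumably also merges short gaps (merge_gap_sec) and drops short
    # detections (min_duration_sec) via merge_short_gaps / remove_short_durations, as detect_movement does.
    return mask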
@@ -29,7 +29,7 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray:

    if HAS_POLARS:
        df = pl.read_csv(path, has_header=False, infer_schema_length=0)
        return df[:, 0].to_numpy()
        return df[:, 0].to_numpy().astype(float)
    else:
        df = pd.read_csv(path, header=None, dtype=float)
        return df.iloc[:, 0].to_numpy()
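Presumably the added .astype(float) is needed because infer_schema_length=0 makes polars read every column as strings; a small editorial demonstration of that behavior:

import polars as pl

s = pl.Series(["1.0", "2.5", "3.25"])       # what a column looks like when the schema is not inferred (Utf8)
print(s.to_numpy())                          # string/object array
print(s.to_numpy().astype(float))            # array([1.  , 2.5 , 3.25]) after the explicit cast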
@@ -41,8 +41,11 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:

    Args:
        path (str | Path): Path to the CSV file.
        verbose (bool):
    Returns:
        pd.DataFrame: The content of the CSV file as a pandas DataFrame.
    :param path:
    :param verbose:
    """
    path = Path(path)
    if not path.exists():

@@ -1,3 +1,4 @@
from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel
from utils.operation_tools import load_dataset_info, generate_disable_mask, generate_event_mask
from utils.event_map import E2N
from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask
from utils.event_map import E2N
from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter
@@ -47,85 +47,6 @@ def read_auto(file_path):
    else:
        raise ValueError('This file type is not supported; write a custom reader for it')

@timing_decorator()
def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):

    if type == "lowpass":  # Low-pass filtering
        sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif type == "bandpass":  # Band-pass filtering
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif type == "highpass":  # High-pass filtering
        sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    else:  # The filter type must be one of the above
        raise ValueError("Please choose a type of filter")

@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """
    Efficient integer-factor downsampling of long signals (8 hours or more), processed in chunks
    to reduce memory use and improve speed.

    Parameters:
        original_signal : array-like, the original signal
        original_fs : float, original sampling rate (Hz)
        target_fs : float, target sampling rate (Hz)
        chunk_size : int, number of samples processed per chunk, default 100000

    Returns:
        downsampled_signal : array-like, the downsampled signal
    """
    # Input validation
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("The target sampling rate must be lower than the original sampling rate")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("Sampling rates must be positive")

    # Compute the downsampling factor (must be an integer)
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("The downsampling factor must be an integer")
    downsample_factor = int(downsample_factor)

    # Total output length
    total_length = len(original_signal)
    output_length = total_length // downsample_factor

    # Initialize the output array
    downsampled_signal = np.zeros(output_length)

    # Process in chunks
    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        # Integer-factor downsampling with decimate
        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        # Output position
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            chunk_downsampled = chunk_downsampled[:output_length - out_start]

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal

@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
    kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate)
    filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    convolve_filter_signal = raw_data - filtered
    return convolve_filter_signal




def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """
@@ -252,14 +173,14 @@ def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100,
    return values_nan, values


def load_dataset_info(yaml_path):
def load_dataset_conf(yaml_path):
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    select_ids = config.get('select_id', [])
    root_path = config.get('root_path', None)
    data_path = Path(root_path)
    return select_ids, data_path
    # select_ids = config.get('select_id', [])
    # root_path = config.get('root_path', None)
    # data_path = Path(root_path)
    return config


def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
@@ -285,36 +206,3 @@ def generate_event_mask(signal_second: int, event_df):
        score_mask[start:end] = row["score"]
    return event_mask, score_mask


def slide_window_segment(signal_second: int, window_second, step_second, event_mask, score_mask, disable_mask):
    # Slide windows over the signal while avoiding disabled regions
    # When a window reaches a disabled region: if the disabled part on one side of the window is no more
    # than 1/2 of window_second, keep sliding and fill with reflection
    # If the disabled interval is larger than 1/2 of window_second, skip it and continue sliding
    # TODO: for short, strong movement intervals, consider padding or masking them out
    #
    half_window_second = window_second // 2
    for start_second in range(0, signal_second - window_second + 1, step_second):
        end_second = start_second + window_second

        # Check whether the current window contains a disabled region
        windows_middle_second = (start_second + end_second) // 2
        if np.sum(disable_mask[start_second:end_second] > 1) > half_window_second:
            # If more than half of the window is disabled, skip this window
            continue

        if disable_mask[start_second:end_second] > half_window_second:




        # Make sure the new start position does not exceed the signal length
        if start_second + window_second > signal_second:
            break

        window_event = event_mask[start_second:end_second]
        window_score = score_mask[start_second:end_second]
        window_disable = disable_mask[start_second:end_second]

        yield start_second, end_second, window_event, window_score, window_disable

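The removed generator above was left unfinished; purely as an editorial sketch of the windowing rule its comments describe (skip any window where more than half the seconds are disabled), independent of the committed code and with a hypothetical function name:

import numpy as np

def slide_window_segment_sketch(signal_second, window_second, step_second,
                                event_mask, score_mask, disable_mask):
    """Illustrative restatement of the rule in the comments above, not the project's implementation."""
    half_window = window_second // 2
    for start in range(0, signal_second - window_second + 1, step_second):
        end = start + window_second
        # Same test as the removed code: skip if more than half the window is disabled
        if np.sum(disable_mask[start:end] > 1) > half_window:
            continue
        yield start, end, event_mask[start:end], score_mask[start:end], disable_mask[start:end]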
utils/signal_process.py (new file, +92 lines)
@@ -0,0 +1,92 @@
from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage


@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):

    if _type == "lowpass":  # Low-pass filtering
        sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "bandpass":  # Band-pass filtering
        low = low_cut / (sample_rate * 0.5)
        high = high_cut / (sample_rate * 0.5)
        sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif _type == "highpass":  # High-pass filtering
        sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    else:  # The filter type must be one of the above
        raise ValueError("Please choose a type of filter")

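For orientation (editorial example, not in the commit), a call that mirrors the commented-out band-pass line in the processing script, run on a synthetic signal at the function's default 1000 Hz:

import numpy as np

from utils.signal_process import butterworth

fs = 1000
t = np.arange(0, 10, 1 / fs)
raw = np.sin(2 * np.pi * 1.0 * t) + 0.3 * np.random.randn(t.size)   # 1 Hz tone plus noise

filtered = butterworth(data=raw, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=fs)
print(filtered.shape)   # same length as the input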
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """
    Efficient integer-factor downsampling of long signals, processed in chunks to reduce
    memory use and improve speed.

    Parameters:
        original_signal : array-like, the original signal
        original_fs : float, original sampling rate (Hz)
        target_fs : float, target sampling rate (Hz)
        chunk_size : int, number of samples processed per chunk, default 100000

    Returns:
        downsampled_signal : array-like, the downsampled signal
    """
    # Input validation
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("The target sampling rate must be lower than the original sampling rate")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("Sampling rates must be positive")

    # Compute the downsampling factor (must be an integer)
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("The downsampling factor must be an integer")
    downsample_factor = int(downsample_factor)

    # Total output length
    total_length = len(original_signal)
    output_length = total_length // downsample_factor

    # Initialize the output array
    downsampled_signal = np.zeros(output_length)

    # Process in chunks
    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        # Integer-factor downsampling with decimate
        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        # Output position
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            chunk_downsampled = chunk_downsampled[:output_length - out_start]

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal

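A quick editorial usage sketch; the sampling rates are made up but respect the integer-factor requirement enforced above:

import numpy as np

from utils.signal_process import downsample_signal_fast

original_fs, target_fs = 1000, 100            # factor of 10, so the integer check passes
signal_1h = np.random.randn(original_fs * 3600)

downsampled = downsample_signal_fast(signal_1h, original_fs, target_fs, chunk_size=100000)
print(len(signal_1h), len(downsampled))       # 3600000 -> 360000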
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
    kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate)
    filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    convolve_filter_signal = raw_data - filtered
    return convolve_filter_signal


# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    nyquist = 0.5 * sample_rate
    norm_notch_freq = notch_freq / nyquist
    b, a = signal.iirnotch(norm_notch_freq, quality_factor)
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data
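For completeness (editorial example, not in the commit), removing 50 Hz mains interference as hinted by the commented-out notch line in the processing script:

import numpy as np

from utils.signal_process import notch_filter

fs = 1000
t = np.arange(0, 5, 1 / fs)
contaminated = np.sin(2 * np.pi * 1.0 * t) + 0.5 * np.sin(2 * np.pi * 50.0 * t)  # signal plus 50 Hz hum

clean = notch_filter(contaminated, notch_freq=50.0, quality_factor=30.0, sample_rate=fs)
print(clean.shape)   # same length as the input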