109 lines
4.4 KiB
Python
109 lines
4.4 KiB
Python
from utils.operation_tools import timing_decorator
|
||
import numpy as np
|
||
from scipy import signal, ndimage
|
||
|
||
|
||
@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Apply a zero-phase Butterworth filter to ``data``.

    The filter is designed as second-order sections and applied forwards and
    backwards with ``scipy.signal.sosfiltfilt``, so the output has no phase
    shift (the effective magnitude response is the design response squared).

    Parameters
    ----------
    data : array-like
        Input signal (filtered along its last axis).
    _type : {"lowpass", "bandpass", "highpass"}
        Filter type.
    low_cut : float
        Cutoff in Hz for "lowpass"; lower band edge for "bandpass".
    high_cut : float
        Cutoff in Hz for "highpass"; upper band edge for "bandpass".
    order : int
        Butterworth design order (default 10).
    sample_rate : float
        Sampling rate in Hz (default 1000).

    Returns
    -------
    numpy.ndarray
        The filtered signal.

    Raises
    ------
    ValueError
        If ``_type`` is not a supported filter type. ``scipy.signal.butter``
        also raises ValueError for cutoffs outside (0, sample_rate / 2).
    """
    nyquist = sample_rate * 0.5

    # signal.butter expects critical frequencies normalised to Nyquist.
    if _type == "lowpass":
        wn = low_cut / nyquist
    elif _type == "bandpass":
        wn = [low_cut / nyquist, high_cut / nyquist]
    elif _type == "highpass":
        wn = high_cut / nyquist
    else:
        raise ValueError(
            "Please choose a type of filter: 'lowpass', 'bandpass' or "
            f"'highpass' (got {_type!r})"
        )

    # _type is exactly one of scipy's btype names at this point.
    sos = signal.butter(order, wn, btype=_type, output='sos')
    return signal.sosfiltfilt(sos, np.asarray(data))
|
||
|
||
|
||
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
    """Apply a zero-phase digital Bessel filter to ``data``.

    The filter is designed with ``norm='mag'`` (cutoffs are -3 dB points). It is
    built as second-order sections and applied forwards and backwards with
    ``scipy.signal.sosfiltfilt``, so the output has no phase shift.

    SOS form is used instead of a single (b, a) transfer function. The
    high-order polynomial form loses precision badly for bandpass designs
    and for cutoffs that are small relative to the sampling rate. SOS also
    matches ``butterworth``.

    Parameters
    ----------
    data : array-like
        Input signal (filtered along its last axis).
    _type : {"lowpass", "bandpass", "highpass"}
        Filter type.
    low_cut : float
        Cutoff in Hz for "lowpass"; lower band edge for "bandpass".
    high_cut : float
        Cutoff in Hz for "highpass"; upper band edge for "bandpass".
    order : int
        Bessel design order (default 4).
    sample_rate : float
        Sampling rate in Hz (default 1000).

    Returns
    -------
    numpy.ndarray
        The filtered signal.

    Raises
    ------
    ValueError
        If ``_type`` is not a supported filter type. scipy also raises
        ValueError for cutoffs outside (0, sample_rate / 2).
    """
    nyquist = sample_rate * 0.5

    # signal.bessel expects critical frequencies normalised to Nyquist.
    if _type == "lowpass":
        wn = low_cut / nyquist
    elif _type == "bandpass":
        wn = [low_cut / nyquist, high_cut / nyquist]
    elif _type == "highpass":
        wn = high_cut / nyquist
    else:
        raise ValueError(
            "Please choose a type of filter: 'lowpass', 'bandpass' or "
            f"'highpass' (got {_type!r})"
        )

    sos = signal.bessel(order, wn, btype=_type, analog=False, norm='mag', output='sos')
    return signal.sosfiltfilt(sos, np.asarray(data))
|
||
|
||
|
||
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """Downsample a long 1-D signal by an integer factor, chunk by chunk.

    The signal is split into chunks that start on the global decimation grid
    (sample indices 0, q, 2q, ...).

    Each chunk is passed to ``scipy.signal.decimate`` together with up to
    ``100 * q`` samples of real neighbouring data on each side. Only the
    output samples that belong to the chunk itself are kept. This prevents:

    * misaligned or overlapping output when ``chunk_size`` is not a multiple
      of the factor,
    * edge transients of the zero-phase anti-aliasing filter at every chunk
      seam, and
    * a crash when the final chunk is too short for the filter's padding.

    Parameters
    ----------
    original_signal : array-like
        Input signal (treated as 1-D).
    original_fs : float
        Original sampling rate in Hz.
    target_fs : float
        Target sampling rate in Hz; ``original_fs / target_fs`` must be an
        integer greater than 1.
    chunk_size : int
        Approximate number of input samples processed per chunk (default
        100000). It is rounded down to a multiple of the decimation factor,
        with a minimum of one factor.

    Returns
    -------
    numpy.ndarray
        float64 array of ``len(original_signal) // factor`` samples.

    Raises
    ------
    ValueError
        If a sampling rate is not positive, ``target_fs`` is not below
        ``original_fs``, the rate ratio is not an integer, or ``chunk_size``
        is not positive. scipy may also raise ValueError if the whole signal
        is too short for the decimation filter's padding.
    """
    # Input validation. Positivity is checked first so that e.g. (0, 0)
    # reports the real problem.
    original_signal = np.asarray(original_signal)
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")
    if original_fs <= target_fs:
        raise ValueError("目标采样率必须小于原始采样率")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # The decimation factor must be an integer.
    ratio = original_fs / target_fs
    if not float(ratio).is_integer():
        raise ValueError("降采样因子必须为整数倍")
    q = int(ratio)

    total_length = len(original_signal)
    output_length = total_length // q
    downsampled_signal = np.zeros(output_length)
    if output_length == 0:
        return downsampled_signal

    # Only the first output_length * q input samples own an output sample.
    usable = output_length * q
    # Chunk length is a multiple of q, so every chunk starts on the global grid.
    step = max(q, (int(chunk_size) // q) * q)
    # Real-data context on each side of a chunk (also a multiple of q).
    # decimate's default IIR filter has a cutoff of 0.8/q of Nyquist, so its
    # transient length grows roughly in proportion to q. 100 output periods
    # is assumed to be ample for it to decay; increase it if exact agreement
    # with whole-signal decimation at the seams matters.
    margin = 100 * q

    for start in range(0, usable, step):
        stop = min(start + step, usable)
        lo = max(0, start - margin)
        hi = min(total_length, stop + margin)

        decimated = signal.decimate(original_signal[lo:hi], q, ftype='iir', zero_phase=True)

        # decimated[k] corresponds to original_signal[lo + k * q]; lo is a
        # multiple of q, so the chunk's own samples start at index (start - lo) // q.
        first = (start - lo) // q
        count = (stop - start) // q
        out_start = start // q
        downsampled_signal[out_start:out_start + count] = decimated[first:first + count]

    return downsampled_signal
|
||
|
||
|
||
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
    """Remove the slow baseline of a signal by subtracting its moving average.

    Despite the name, the return value is the residual
    ``raw_data - moving_average(raw_data)``. The moving average (window of
    ``window_size_sec`` seconds, ``'reflect'`` boundary handling) serves as a
    baseline estimate that is subtracted.

    Parameters
    ----------
    raw_data : array-like
        Input samples (filtered along the last axis).
    sample_rate : float
        Sampling rate in Hz.
    window_size_sec : float
        Moving-average window length in seconds (default 20).

    Returns
    -------
    numpy.ndarray
        Baseline-removed signal with the same shape as ``raw_data``. Integer
        input is converted to float64 so the average is not truncated.

    Raises
    ------
    ValueError
        If the window covers fewer than one sample.
    """
    data = np.asarray(raw_data)
    # ndimage keeps the input dtype, so integer input would give a truncated
    # (integer) moving average.
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(np.float64)

    # The window must be a whole number of samples, even when
    # window_size_sec * sample_rate is fractional.
    window = int(round(window_size_sec * sample_rate))
    if window < 1:
        raise ValueError("window_size_sec * sample_rate must cover at least one sample")

    # uniform_filter1d is a running-sum box average: O(n) regardless of the
    # window length, equivalent to convolving with ones(window) / window.
    # convolve1d places an even-length kernel one sample differently
    # (origin -1 in correlation terms), so use that origin to keep the
    # previous alignment.
    origin = 0 if window % 2 else -1
    baseline = ndimage.uniform_filter1d(data, window, mode='reflect', origin=origin)

    return data - baseline
|
||
|
||
|
||
# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    """Suppress a narrow band around ``notch_freq`` with a zero-phase IIR notch.

    Parameters
    ----------
    data : array-like
        Input signal, passed directly to ``scipy.signal.filtfilt``.
    notch_freq : float
        Frequency to reject, in Hz (default 50.0).
    quality_factor : float
        Quality factor of the notch, passed to ``scipy.signal.iirnotch``
        (default 30.0).
    sample_rate : float
        Sampling rate in Hz (default 1000).

    Returns
    -------
    numpy.ndarray
        The filtered signal.
    """
    # iirnotch takes the notch frequency as a fraction of the Nyquist rate.
    w0 = notch_freq / (sample_rate * 0.5)
    numerator, denominator = signal.iirnotch(w0, quality_factor)
    # Forward-backward filtering cancels the notch's phase response.
    return signal.filtfilt(numerator, denominator, data)
|