DataPrepare/utils/signal_process.py

109 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from utils.operation_tools import timing_decorator
import numpy as np
from scipy import signal, ndimage
@timing_decorator()
def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Apply a zero-phase Butterworth filter to ``data``.

    The filter is designed as second-order sections and applied forward and
    backward with ``scipy.signal.sosfiltfilt``.

    Args:
        data: Array-like signal to filter.
        _type: One of ``"lowpass"``, ``"bandpass"`` or ``"highpass"``.
        low_cut: Cutoff in Hz for ``"lowpass"`` and lower band edge for
            ``"bandpass"``. Not used by ``"highpass"``.
        high_cut: Cutoff in Hz for ``"highpass"`` and upper band edge for
            ``"bandpass"``. Not used by ``"lowpass"``.
        order: Filter order handed to ``scipy.signal.butter``.
        sample_rate: Sampling rate of ``data`` in Hz.

    Returns:
        The filtered signal as returned by ``scipy.signal.sosfiltfilt``.

    Raises:
        ValueError: If ``_type`` is not a supported filter type, if a cutoff
            that the chosen type uses is not strictly between 0 and the
            Nyquist frequency, or if the band edges are not increasing.
    """
    # Which of the two cutoff arguments each filter type actually consumes.
    if _type == "lowpass":
        used_cutoffs = (("low_cut", low_cut),)
    elif _type == "bandpass":
        used_cutoffs = (("low_cut", low_cut), ("high_cut", high_cut))
    elif _type == "highpass":
        used_cutoffs = (("high_cut", high_cut),)
    else:
        raise ValueError(
            "Please choose a type of filter: 'lowpass', 'bandpass' or "
            f"'highpass' (got {_type!r})"
        )

    nyquist = sample_rate * 0.5
    normalized = []
    for arg_name, cutoff_hz in used_cutoffs:
        wn = cutoff_hz / nyquist
        # Fail early with the argument name instead of SciPy's generic
        # "critical frequencies" message; a forgotten cutoff stays at 0.0.
        if not 0.0 < wn < 1.0:
            raise ValueError(
                f"{arg_name}={cutoff_hz!r} Hz must lie strictly between 0 and "
                f"the Nyquist frequency ({nyquist!r} Hz) for a {_type} filter"
            )
        normalized.append(wn)

    if len(normalized) == 2:
        if normalized[0] >= normalized[1]:
            raise ValueError(
                "low_cut must be smaller than high_cut for a bandpass filter"
            )
        critical = normalized
    else:
        critical = normalized[0]

    sos = signal.butter(order, critical, btype=_type, output='sos')
    return signal.sosfiltfilt(sos, np.asarray(data))
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
    """Apply a zero-phase digital Bessel filter to ``data``.

    The filter is designed with ``norm='mag'`` and applied forward and
    backward. It is built as second-order sections (rather than a single
    ``b, a`` transfer function) because that form stays numerically
    well-behaved as the order grows.

    Args:
        data: Array-like signal to filter.
        _type: One of ``"lowpass"``, ``"bandpass"`` or ``"highpass"``.
        low_cut: Cutoff in Hz for ``"lowpass"`` and lower band edge for
            ``"bandpass"``. Not used by ``"highpass"``.
        high_cut: Cutoff in Hz for ``"highpass"`` and upper band edge for
            ``"bandpass"``. Not used by ``"lowpass"``.
        order: Filter order handed to ``scipy.signal.bessel``.
        sample_rate: Sampling rate of ``data`` in Hz.

    Returns:
        The filtered signal as returned by ``scipy.signal.sosfiltfilt``.

    Raises:
        ValueError: If ``_type`` is not a supported filter type, if a cutoff
            that the chosen type uses is not strictly between 0 and the
            Nyquist frequency, or if the band edges are not increasing.
    """
    # Which of the two cutoff arguments each filter type actually consumes.
    if _type == "lowpass":
        used_cutoffs = (("low_cut", low_cut),)
    elif _type == "bandpass":
        used_cutoffs = (("low_cut", low_cut), ("high_cut", high_cut))
    elif _type == "highpass":
        used_cutoffs = (("high_cut", high_cut),)
    else:
        raise ValueError(
            "Please choose a type of filter: 'lowpass', 'bandpass' or "
            f"'highpass' (got {_type!r})"
        )

    nyquist = sample_rate * 0.5
    normalized = []
    for arg_name, cutoff_hz in used_cutoffs:
        wn = cutoff_hz / nyquist
        # Fail early with the argument name instead of SciPy's generic
        # "critical frequencies" message; a forgotten cutoff stays at 0.0.
        if not 0.0 < wn < 1.0:
            raise ValueError(
                f"{arg_name}={cutoff_hz!r} Hz must lie strictly between 0 and "
                f"the Nyquist frequency ({nyquist!r} Hz) for a {_type} filter"
            )
        normalized.append(wn)

    if len(normalized) == 2:
        if normalized[0] >= normalized[1]:
            raise ValueError(
                "low_cut must be smaller than high_cut for a bandpass filter"
            )
        critical = normalized
    else:
        critical = normalized[0]

    sos = signal.bessel(
        order, critical, btype=_type, analog=False, output='sos', norm='mag'
    )
    return signal.sosfiltfilt(sos, np.asarray(data))
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """Downsample a long signal by an integer factor, one chunk at a time.

    Each chunk is passed separately to ``scipy.signal.decimate`` (IIR
    anti-alias filter, zero phase) and the results are written into a
    pre-allocated output array.

    Args:
        original_signal: Array-like input signal (indexed along its first
            axis; a 1-D signal is assumed).
        original_fs: Original sampling rate in Hz.
        target_fs: Target sampling rate in Hz; ``original_fs / target_fs``
            must be an integer >= 2.
        chunk_size: Requested number of input samples per chunk. The
            effective chunk length is rounded down to a multiple of the
            decimation factor (and never below a small minimum) so that
            every chunk starts on an output-sample boundary.

    Returns:
        numpy.ndarray of length ``len(original_signal) // factor`` holding
        the downsampled signal.

    Raises:
        ValueError: If the sampling rates are not positive, the target rate
            is not lower than the original rate, the ratio is not an
            integer, or ``chunk_size`` is not positive. ``decimate`` itself
            may also raise for inputs too short to filter.
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        # Message: "the target sampling rate must be lower than the original".
        raise ValueError("目标采样率必须小于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        # Message: "sampling rates must be positive".
        raise ValueError("采样率必须为正数")
    chunk_size = int(chunk_size)
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of samples")

    # Compare against the nearest integer with a tolerance so that ratios
    # such as 0.3 / 0.1 (= 2.9999999999999996 in floating point) are accepted.
    ratio = float(original_fs) / float(target_fs)
    downsample_factor = int(round(ratio))
    if downsample_factor < 2 or not np.isclose(
        ratio, downsample_factor, rtol=1e-9, atol=0.0
    ):
        # Message: "the downsampling factor must be an integer".
        raise ValueError("降采样因子必须为整数倍")

    total_length = len(original_signal)
    output_length = total_length // downsample_factor
    downsampled_signal = np.zeros(output_length)

    # A trailing fragment shorter than this is folded into the previous chunk:
    # the zero-phase filter inside decimate rejects very short inputs.
    min_tail = 32 * downsample_factor
    # Chunk starts must be multiples of the factor, otherwise
    # ``start // factor`` no longer matches the sample that decimate keeps.
    step = max((chunk_size // downsample_factor) * downsample_factor, min_tail)

    # NOTE: chunks are filtered independently, so filter edge effects can
    # still appear at chunk joins.
    start = 0
    while start < total_length:
        end = start + step
        if total_length - end < min_tail:
            end = total_length
        chunk_downsampled = signal.decimate(
            original_signal[start:end], downsample_factor,
            ftype='iir', zero_phase=True,
        )
        out_start = start // downsample_factor
        # decimate rounds the chunk length up; the output array rounds down.
        n_keep = min(len(chunk_downsampled), output_length - out_start)
        downsampled_signal[out_start:out_start + n_keep] = chunk_downsampled[:n_keep]
        start = end
    return downsampled_signal
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
    """Subtract a moving average from the signal.

    A box kernel of ``window_size_sec * sample_rate`` samples is convolved
    with the input (``mode='reflect'`` at the edges) and the result is
    subtracted from the input.

    Args:
        raw_data: Array-like input signal.
        sample_rate: Sampling rate in Hz.
        window_size_sec: Length of the averaging window in seconds.

    Returns:
        numpy.ndarray with the moving average removed.

    Raises:
        ValueError: If the window is shorter than one sample.
    """
    samples = np.asarray(raw_data)
    # ndimage.convolve1d produces output in the input dtype, so integer data
    # would yield a truncated integer average; promote it to float first.
    if not np.issubdtype(samples.dtype, np.inexact):
        samples = samples.astype(float)

    # Round so that float arguments (e.g. sample_rate=250.0) give a valid
    # integer kernel length.
    window_len = int(round(window_size_sec * sample_rate))
    if window_len < 1:
        raise ValueError(
            "window_size_sec * sample_rate must be at least one sample"
        )

    kernel = np.full(window_len, 1.0 / window_len)
    moving_average = ndimage.convolve1d(samples, kernel, mode='reflect')
    return samples - moving_average
# Notch filter
@timing_decorator()
def notch_filter(data, notch_freq=50.0, quality_factor=30.0, sample_rate=1000):
    """Remove a narrow band around ``notch_freq`` with a zero-phase IIR notch.

    Args:
        data: Array-like signal to filter.
        notch_freq: Centre frequency of the notch in Hz.
        quality_factor: Quality factor passed to ``scipy.signal.iirnotch``.
        sample_rate: Sampling rate of ``data`` in Hz.

    Returns:
        The filtered signal as returned by ``scipy.signal.filtfilt``.
    """
    # iirnotch is given the centre frequency as a fraction of Nyquist.
    w0 = notch_freq / (0.5 * sample_rate)
    numerator, denominator = signal.iirnotch(w0, quality_factor)
    # Forward-backward filtering.
    return signal.filtfilt(numerator, denominator, data)