DataPrepare/utils/operation_tools.py

286 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import yaml
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
from scipy import ndimage, signal
from functools import wraps
# 全局配置
class Config:
    """Global runtime configuration flags shared across this module."""
    # When True, timing_decorator prints each wrapped call's elapsed time.
    time_verbose: bool = False
def timing_decorator():
    """Decorator factory: wrap a function and report its elapsed wall time.

    The timing line is only printed when the global ``Config.time_verbose``
    flag is True *at call time*, so verbosity can be toggled at runtime
    without re-decorating anything.

    Returns:
        A decorator that preserves the wrapped function's metadata
        (via ``functools.wraps``) and passes its return value through.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution, unlike
            # time.time(), so measurements cannot go negative or jump
            # when the system clock is adjusted.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_time = time.perf_counter() - start_time
            if Config.time_verbose:  # checked at call time (runtime toggle)
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f}")
            return result
        return wrapper
    return decorator
@timing_decorator()
def read_auto(file_path):
    """Load a 1-D signal from a .txt, .npy or .base64 file by extension.

    Args:
        file_path: path to the data file; accepts ``str`` or ``pathlib.Path``.

    Returns:
        numpy.ndarray: file contents flattened to one dimension
        (.txt / .npy), or one integer parsed per line (.base64).

    Raises:
        ValueError: if the file extension is not supported.
    """
    file_path = Path(file_path)  # allow plain strings as well as Path objects
    suffix = file_path.suffix
    if suffix == '.txt':
        # pandas' C csv parser is fast; flatten to a 1-D vector
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    elif suffix == '.npy':
        # np.load accepts Path objects directly; no str() conversion needed
        return np.load(file_path)
    elif suffix == '.base64':
        # NOTE(review): despite the extension this parses one integer per
        # line, not base64-decoded bytes — confirm with the data producer.
        with open(file_path) as f:
            lines = f.readlines()
        return np.array(lines, dtype=int)
    else:
        raise ValueError('这个文件类型不支持,需要自己写读取程序')
@timing_decorator()
def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Zero-phase Butterworth filtering using second-order sections.

    Args:
        data: 1-D array-like signal to filter.
        type: 'lowpass', 'bandpass' or 'highpass'.
        low_cut: low cutoff in Hz (lowpass cutoff / bandpass lower edge).
        high_cut: high cutoff in Hz (highpass cutoff / bandpass upper edge).
        order: filter order passed to ``signal.butter``.
        sample_rate: sampling rate of ``data`` in Hz.

    Returns:
        numpy.ndarray: the filtered signal, same length as the input.

    Raises:
        ValueError: if ``type`` is not one of the three supported modes.
    """
    nyquist = sample_rate * 0.5  # butter() takes cutoffs normalized to Nyquist
    if type == "lowpass":
        sos = signal.butter(order, low_cut / nyquist, btype='lowpass', output='sos')
    elif type == "bandpass":
        band = [low_cut / nyquist, high_cut / nyquist]
        sos = signal.butter(order, band, btype='bandpass', output='sos')
    elif type == "highpass":
        sos = signal.butter(order, high_cut / nyquist, btype='highpass', output='sos')
    else:
        raise ValueError("Please choose a type of filter")
    # forward-backward filtering gives zero phase distortion
    return signal.sosfiltfilt(sos, np.array(data))
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
"""
高效整数倍降采样长信号适合8小时以上分段处理以优化内存和速度。
参数:
original_signal : array-like, 原始信号数组
original_fs : float, 原始采样率 (Hz)
target_fs : float, 目标采样率 (Hz)
chunk_size : int, 每段处理的样本数默认100000
返回:
downsampled_signal : array-like, 降采样后的信号
"""
# 输入验证
if not isinstance(original_signal, np.ndarray):
original_signal = np.array(original_signal)
if original_fs <= target_fs:
raise ValueError("目标采样率必须小于原始采样率")
if target_fs <= 0 or original_fs <= 0:
raise ValueError("采样率必须为正数")
# 计算降采样因子(必须为整数)
downsample_factor = original_fs / target_fs
if not downsample_factor.is_integer():
raise ValueError("降采样因子必须为整数倍")
downsample_factor = int(downsample_factor)
# 计算总输出长度
total_length = len(original_signal)
output_length = total_length // downsample_factor
# 初始化输出数组
downsampled_signal = np.zeros(output_length)
# 分段处理
for start in range(0, total_length, chunk_size):
end = min(start + chunk_size, total_length)
chunk = original_signal[start:end]
# 使用decimate进行整数倍降采样
chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)
# 计算输出位置
out_start = start // downsample_factor
out_end = out_start + len(chunk_downsampled)
if out_end > output_length:
chunk_downsampled = chunk_downsampled[:output_length - out_start]
downsampled_signal[out_start:out_end] = chunk_downsampled
return downsampled_signal
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
    """Remove the slow baseline by subtracting a moving-average trend.

    The trend is estimated with a boxcar kernel spanning ``window_size``
    seconds (``window_size * sample_rate`` samples), convolved with
    reflected edges, and the detrended signal (raw minus trend) is
    returned.
    """
    kernel_length = window_size * sample_rate
    boxcar = np.full(kernel_length, 1.0 / kernel_length)
    baseline = ndimage.convolve1d(raw_data, boxcar, mode='reflect')
    return raw_data - baseline
def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """Fill gaps between consecutive 1-runs that are shorter than a threshold.

    Parameters:
    - state_sequence: numpy array of 0/1 states
    - time_points: numpy array, timestamp of each state sample
    - max_gap_sec: float, the largest gap (in time units) that gets merged

    Returns:
    - numpy array: the state sequence with short gaps filled with 1s
    """
    if len(state_sequence) <= 1:
        return state_sequence
    merged = state_sequence.copy()
    # Locate run boundaries: +1 marks a 0->1 edge, -1 marks a 1->0 edge.
    edges = np.diff(np.concatenate([[0], merged, [0]]))
    run_starts = np.where(edges == 1)[0]
    run_ends = np.where(edges == -1)[0] - 1
    # Walk each pair of adjacent 1-runs and close small gaps between them.
    for prev_end, next_start in zip(run_ends[:-1], run_starts[1:]):
        if prev_end >= len(time_points) or next_start >= len(time_points):
            continue
        gap = time_points[next_start] - time_points[prev_end]
        if gap <= max_gap_sec:
            merged[prev_end:next_start] = 1
    return merged
def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """Zero out 1-runs whose duration falls below a minimum threshold.

    Parameters:
    - state_sequence: numpy array of 0/1 states
    - time_points: numpy array, timestamp of each state sample
    - min_duration_sec: float, minimum duration a 1-run must have to survive

    Returns:
    - numpy array: the filtered state sequence
    """
    if len(state_sequence) <= 1:
        return state_sequence
    filtered = state_sequence.copy()
    # Locate run boundaries: +1 marks a 0->1 edge, -1 marks a 1->0 edge.
    edges = np.diff(np.concatenate([[0], filtered, [0]]))
    run_starts = np.where(edges == 1)[0]
    run_ends = np.where(edges == -1)[0] - 1
    for run_start, run_end in zip(run_starts, run_ends):
        if run_start >= len(time_points) or run_end >= len(time_points):
            continue
        duration = time_points[run_end] - time_points[run_start]
        # The final window has no successor, so credit it one sample period.
        if run_end == len(time_points) - 1 and len(time_points) > 1:
            duration += time_points[1] - time_points[0]
        if duration < min_duration_sec:
            filtered[run_start:run_end + 1] = 0
    return filtered
@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Apply ``func`` to sliding windows of a signal, one result per second.

    The signal is reflect-padded by half a window so every step evaluates a
    full-length, centered window. Per-window results are repeated to one
    value per second, seconds flagged in ``calc_mask`` are invalidated to
    NaN (e.g. body-movement artifacts), and a second copy has those NaNs
    filled back in by linear interpolation.

    Args:
        func: callable mapping a 1-D window array to a scalar.
        signal_data: 1-D array-like signal.
        calc_mask: per-second boolean mask of seconds to invalidate, or None
            for no masking.
        sampling_rate: samples per second of ``signal_data``.
        window_second: window length in (integer) seconds.
        step_second: hop length in seconds; defaults to ``window_second``.

    Returns:
        (values_nan, values): per-second arrays; ``values_nan`` keeps NaN at
        masked seconds, ``values`` has them linearly interpolated (or stays
        all-NaN when every second is masked).
    """
    if calc_mask is None:
        # no mask supplied -> nothing gets invalidated
        calc_mask = np.zeros(len(signal_data), dtype=bool)
    if step_second is None:
        step_second = window_second
    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate
    origin_seconds = len(signal_data) // sampling_rate
    # reflect-pad half a window on each side so windows are centered
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')
    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)
    for i in range(num_segments):
        start = int(i * step_length)
        values_nan[i] = func(data[start:start + window_length])
    # one value per second, trimmed to the original duration
    values_nan = values_nan.repeat(step_second)[:origin_seconds]
    # invalidate masked seconds (e.g. segments containing body movement)
    for i in range(len(values_nan)):
        if calc_mask[i]:
            values_nan[i] = np.nan
    values = values_nan.copy()
    valid_mask = ~np.isnan(values)
    if valid_mask.any():
        # linearly interpolate NaN runs from neighboring valid seconds
        t = np.arange(len(values))
        values = np.interp(t, t[valid_mask], values[valid_mask])
    # else: every second is masked — return all-NaN instead of crashing
    # inside np.interp with an empty set of sample points.
    return values_nan, values
def load_dataset_info(yaml_path):
    """Read dataset selection info from a YAML config file.

    Expected keys:
    - ``select_id``: list of record ids (defaults to ``[]`` if absent).
    - ``root_path``: dataset root directory (required).

    Returns:
        (select_ids, data_path): the id list and the root as a ``Path``.

    Raises:
        ValueError: if ``root_path`` is missing from the config, so the
            error names the offending file instead of surfacing as a
            confusing ``Path(None)`` TypeError.
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    select_ids = config.get('select_id', [])
    root_path = config.get('root_path', None)
    if root_path is None:
        raise ValueError(f"'root_path' missing in config: {yaml_path}")
    return select_ids, Path(root_path)
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    """Build a per-second validity mask: 1 = usable, 0 = disabled.

    Each row of ``disable_df`` must carry integer ``start``/``end`` second
    columns; every half-open [start, end) span is zeroed out.
    """
    mask = np.ones(signal_second, dtype=int)
    for _, interval in disable_df.iterrows():
        mask[interval["start"]:interval["end"]] = 0
    return mask
def generate_event_mask(signal_second: int, event_df) -> np.ndarray:
    """Build a per-second event mask: 1 inside an event span, 0 elsewhere.

    Each row of ``event_df`` must carry integer ``start``/``end`` second
    columns; every half-open [start, end) span is set to 1.
    """
    mask = np.zeros(signal_second, dtype=int)
    for _, interval in event_df.iterrows():
        mask[interval["start"]:interval["end"]] = 1
    return mask