286 lines
9.6 KiB
Python
286 lines
9.6 KiB
Python
import time
|
||
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import pandas as pd
|
||
from matplotlib import pyplot as plt
|
||
import yaml
|
||
|
||
# Use the SimHei font so CJK (Chinese) axis labels and titles render correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
# Keep ASCII hyphen-minus for minus signs; SimHei lacks the Unicode minus glyph
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
from scipy import ndimage, signal
|
||
from functools import wraps
|
||
|
||
# Global configuration shared across this module
class Config:
    # When True, timing_decorator prints each wrapped function's elapsed time;
    # read at call time, so it can be toggled after functions are decorated.
    time_verbose = False
||
|
||
def timing_decorator():
    """Decorator factory: wrap a function so its wall-clock runtime is printed.

    The message is only emitted when ``Config.time_verbose`` is True; the flag
    is checked on every call, so it can be toggled at runtime.

    Returns:
        A decorator that preserves the wrapped function's metadata (via
        ``functools.wraps``) and its return value.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic — unlike time.time(), the measured
            # interval cannot go negative if the system clock is adjusted.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_time = time.perf_counter() - start_time
            if Config.time_verbose:  # runtime check of the global switch
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒")
            return result
        return wrapper
    return decorator
|
||
|
||
@timing_decorator()
|
||
def read_auto(file_path):
    """Load a 1-D numeric array from a file, dispatching on its suffix.

    Args:
        file_path: path to the data file (``str`` or ``pathlib.Path``);
            supported suffixes are ``.txt``, ``.npy`` and ``.base64``.

    Returns:
        numpy.ndarray with the file's contents.

    Raises:
        ValueError: if the suffix is not one of the supported types.
    """
    file_path = Path(file_path)  # accept plain strings as well as Path objects
    suffix = file_path.suffix
    if suffix == '.txt':
        # One value per line; pandas' C parser is faster than np.loadtxt here.
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    elif suffix == '.npy':
        # np.load accepts Path objects directly — no str() conversion needed.
        return np.load(file_path)
    elif suffix == '.base64':
        # NOTE(review): despite the extension this reads one integer per line
        # as plain text — confirm these files really hold that format.
        with open(file_path) as f:
            lines = f.readlines()
        return np.array(lines, dtype=int)
    else:
        raise ValueError('这个文件类型不支持,需要自己写读取程序')
|
||
|
||
@timing_decorator()
|
||
def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Zero-phase Butterworth filtering using second-order sections.

    Args:
        data: array-like 1-D signal.
        type: 'lowpass', 'bandpass' or 'highpass'.
        low_cut: low cutoff frequency in Hz ('lowpass' and 'bandpass').
        high_cut: high cutoff frequency in Hz ('highpass' and 'bandpass').
        order: filter order passed to scipy.signal.butter.
        sample_rate: sampling rate of ``data`` in Hz.

    Returns:
        numpy.ndarray: filtered signal, same length as the input.  sosfiltfilt
        runs the filter forward and backward, so there is no phase shift.

    Raises:
        ValueError: if ``type`` is not one of the three supported modes.
    """
    nyquist = sample_rate * 0.5  # cutoffs must be normalized to Nyquist
    signal_array = np.asarray(data)

    if type == "lowpass":
        sos = signal.butter(order, low_cut / nyquist, btype='lowpass', output='sos')
    elif type == "bandpass":
        sos = signal.butter(order, [low_cut / nyquist, high_cut / nyquist],
                            btype='bandpass', output='sos')
    elif type == "highpass":
        # NOTE(review): the high-pass cutoff comes from ``high_cut`` while the
        # low-pass one comes from ``low_cut`` — asymmetric but preserved, as
        # existing callers depend on it.
        sos = signal.butter(order, high_cut / nyquist, btype='highpass', output='sos')
    else:
        raise ValueError("Please choose a type of filter")
    return signal.sosfiltfilt(sos, signal_array)
|
||
|
||
@timing_decorator()
|
||
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
|
||
"""
|
||
高效整数倍降采样长信号(适合8小时以上),分段处理以优化内存和速度。
|
||
|
||
参数:
|
||
original_signal : array-like, 原始信号数组
|
||
original_fs : float, 原始采样率 (Hz)
|
||
target_fs : float, 目标采样率 (Hz)
|
||
chunk_size : int, 每段处理的样本数,默认100000
|
||
|
||
返回:
|
||
downsampled_signal : array-like, 降采样后的信号
|
||
"""
|
||
# 输入验证
|
||
if not isinstance(original_signal, np.ndarray):
|
||
original_signal = np.array(original_signal)
|
||
if original_fs <= target_fs:
|
||
raise ValueError("目标采样率必须小于原始采样率")
|
||
if target_fs <= 0 or original_fs <= 0:
|
||
raise ValueError("采样率必须为正数")
|
||
|
||
# 计算降采样因子(必须为整数)
|
||
downsample_factor = original_fs / target_fs
|
||
if not downsample_factor.is_integer():
|
||
raise ValueError("降采样因子必须为整数倍")
|
||
downsample_factor = int(downsample_factor)
|
||
|
||
# 计算总输出长度
|
||
total_length = len(original_signal)
|
||
output_length = total_length // downsample_factor
|
||
|
||
# 初始化输出数组
|
||
downsampled_signal = np.zeros(output_length)
|
||
|
||
# 分段处理
|
||
for start in range(0, total_length, chunk_size):
|
||
end = min(start + chunk_size, total_length)
|
||
chunk = original_signal[start:end]
|
||
|
||
# 使用decimate进行整数倍降采样
|
||
chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)
|
||
|
||
# 计算输出位置
|
||
out_start = start // downsample_factor
|
||
out_end = out_start + len(chunk_downsampled)
|
||
if out_end > output_length:
|
||
chunk_downsampled = chunk_downsampled[:output_length - out_start]
|
||
|
||
downsampled_signal[out_start:out_end] = chunk_downsampled
|
||
|
||
return downsampled_signal
|
||
|
||
@timing_decorator()
|
||
def average_filter(raw_data, sample_rate, window_size=20):
    """Detrend a signal by subtracting its moving average.

    The averaging window spans ``window_size`` seconds, i.e.
    ``window_size * sample_rate`` samples; reflect padding keeps the output
    the same length as the input.

    Returns the input minus its moving-average baseline.
    """
    window_points = window_size * sample_rate
    box_kernel = np.full(window_points, 1.0 / window_points)
    baseline = ndimage.convolve1d(raw_data, box_kernel, mode='reflect')
    return raw_data - baseline
|
||
|
||
|
||
|
||
|
||
def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """Fill short 0-gaps between consecutive 1-runs of a binary sequence.

    Parameters:
    - state_sequence: numpy array of 0/1 states
    - time_points: numpy array, timestamp of each state sample
    - max_gap_sec: float, gaps of at most this duration (seconds) are filled

    Returns:
    - numpy array: copy of ``state_sequence`` with qualifying gaps set to 1
    """
    if len(state_sequence) <= 1:
        return state_sequence

    result = state_sequence.copy()

    # Signed edges of the zero-padded sequence mark every run of 1s
    edges = np.diff(np.concatenate([[0], result, [0]]))
    run_starts = np.where(edges == 1)[0]     # first index of each 1-run
    run_ends = np.where(edges == -1)[0] - 1  # last index of each 1-run

    # Bridge each pair of adjacent runs whose separation is short enough
    for left_end, right_start in zip(run_ends[:-1], run_starts[1:]):
        if left_end >= len(time_points) or right_start >= len(time_points):
            continue
        gap_duration = time_points[right_start] - time_points[left_end]
        if gap_duration <= max_gap_sec:
            result[left_end:right_start] = 1

    return result
|
||
|
||
|
||
def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """Zero out 1-runs whose duration is below a threshold.

    Parameters:
    - state_sequence: numpy array of 0/1 states
    - time_points: numpy array, timestamp of each state sample
    - min_duration_sec: float, minimum duration (seconds) a run must last

    Returns:
    - numpy array: copy of ``state_sequence`` with short runs set to 0
    """
    if len(state_sequence) <= 1:
        return state_sequence

    result = state_sequence.copy()

    # Signed edges of the zero-padded sequence mark every run of 1s
    edges = np.diff(np.concatenate([[0], result, [0]]))
    run_starts = np.where(edges == 1)[0]     # first index of each 1-run
    run_ends = np.where(edges == -1)[0] - 1  # last index of each 1-run

    for run_start, run_end in zip(run_starts, run_ends):
        if run_start >= len(time_points) or run_end >= len(time_points):
            continue
        duration = time_points[run_end] - time_points[run_start]
        # A run touching the final window extends one full window beyond its
        # start timestamp, so credit it with one window length
        if run_end == len(time_points) - 1 and len(time_points) > 1:
            duration += time_points[1] - time_points[0]

        # Drop runs shorter than the threshold
        if duration < min_duration_sec:
            result[run_start:run_end + 1] = 0

    return result
|
||
|
||
|
||
@timing_decorator()
|
||
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Apply ``func`` over sliding windows of ``signal_data`` and return two
    per-second series.

    Parameters:
        func: callable taking one 1-D window array and returning a scalar.
        signal_data: 1-D array sampled at ``sampling_rate`` Hz.
        calc_mask: mask of seconds to blank out with NaN (e.g. body-movement
            artifacts); ``None`` disables masking.
        sampling_rate: samples per second of ``signal_data``.
        window_second: window length in seconds.
        step_second: hop between windows in seconds; defaults to
            ``window_second`` (non-overlapping windows).

    Returns:
        (values_nan, values): two arrays of length
        ``len(signal_data) // sampling_rate`` — one value per second.
        ``values_nan`` keeps NaN at masked seconds; ``values`` has those NaNs
        linearly interpolated.
    """
    # Build an all-False mask aligned with signal_data when none is given
    if calc_mask is None:
        calc_mask = np.zeros(len(signal_data), dtype=bool)

    if step_second is None:
        step_second = window_second

    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate

    origin_seconds = len(signal_data) // sampling_rate
    total_samples = len(signal_data)

    # Reflect-pad half a window on each side so each window is centered on
    # its segment instead of starting at it
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)

    # print(f"num_segments: {num_segments}, step_length: {step_length}, window_length: {window_length}")
    for i in range(num_segments):
        # Evaluate func on each (padded) window
        start = int(i * step_length)
        end = start + window_length
        segment = data[start:end]
        values_nan[i] = func(segment)

    # Expand one value per step to one value per second, trimmed to the
    # original signal duration
    values_nan = values_nan.repeat(step_second)[:origin_seconds]

    # Blank masked seconds to NaN.
    # NOTE(review): calc_mask is indexed per second here, but the default
    # mask above is built per sample — works because origin_seconds <=
    # len(signal_data), yet callers should pass a per-second mask; confirm.
    for i in range(len(values_nan)):
        if calc_mask[i]:
            values_nan[i] = np.nan

    values = values_nan.copy()

    # Linearly interpolate across the NaN (masked) stretches.
    # NOTE(review): np.interp fails if every value is NaN — assumes at least
    # one unmasked second exists.
    def interpolate_nans(x, t):
        valid_mask = ~np.isnan(x)
        return np.interp(t, t[valid_mask], x[valid_mask])

    values = interpolate_nans(values, np.arange(len(values)))

    return values_nan, values
|
||
|
||
|
||
def load_dataset_info(yaml_path):
    """Read a dataset-description YAML file.

    Returns a tuple ``(select_ids, data_path)``: the list stored under the
    ``select_id`` key (empty list if absent) and the ``root_path`` value
    wrapped in a ``pathlib.Path``.
    """
    with open(yaml_path, 'r', encoding='utf-8') as handle:
        cfg = yaml.safe_load(handle)

    ids = cfg.get('select_id', [])
    root = cfg.get('root_path', None)
    return ids, Path(root)
|
||
|
||
|
||
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    """Build a per-second validity mask: 1 = usable, 0 = inside a disabled span.

    ``disable_df`` must provide integer ``start``/``end`` columns; each row
    zeroes the half-open second range [start, end).
    """
    mask = np.ones(signal_second, dtype=int)
    for _, interval in disable_df.iterrows():
        mask[interval["start"]:interval["end"]] = 0
    return mask
|
||
|
||
|
||
def generate_event_mask(signal_second: int, event_df) -> np.ndarray:
    """Build a per-second event mask: 1 inside any event interval, else 0.

    ``event_df`` must provide integer ``start``/``end`` columns; each row
    sets the half-open second range [start, end) to 1.
    """
    mask = np.zeros(signal_second, dtype=int)
    for _, interval in event_df.iterrows():
        mask[interval["start"]:interval["end"]] = 1
    return mask
|
||
|
||
|
||
|