import time
from functools import wraps
from pathlib import Path

import numpy as np
import pandas as pd
import yaml
from matplotlib import pyplot as plt
from scipy import ndimage, signal

# Matplotlib configuration so CJK axis labels and minus signs render correctly.
plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can display Chinese labels
plt.rcParams['axes.unicode_minus'] = False    # keep the minus sign renderable with that font


class Config:
    """Global runtime configuration flags."""
    # When True, timing_decorator prints each decorated call's wall-clock time.
    time_verbose = False


def timing_decorator():
    """Decorator factory: time each call of the wrapped function.

    The elapsed time is printed only when ``Config.time_verbose`` is True,
    checked at call time (not at decoration time), so verbosity can be
    toggled globally at runtime.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            if Config.time_verbose:  # read the global flag at call time
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒")
            return result
        return wrapper
    return decorator


@timing_decorator()
def read_auto(file_path):
    """Load a 1-D signal from ``file_path``, dispatching on the file suffix.

    Args:
        file_path: ``pathlib.Path`` to the data file.

    Returns:
        numpy.ndarray: the loaded signal.
            - ``.txt``  : one value per row, read via pandas, flattened to 1-D.
            - ``.npy``  : NumPy binary file.
            - ``.base64``: plain text, one integer per line.
              NOTE(review): despite the suffix, the content is parsed as plain
              integers, not base64-decoded — presumably a project convention;
              confirm against the data producer.

    Raises:
        ValueError: for any unsupported suffix.
    """
    suffix = file_path.suffix
    if suffix == '.txt':
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(str(file_path))
    if suffix == '.base64':
        with open(file_path) as f:
            lines = f.readlines()
        # int() tolerates the trailing newline on each line.
        return np.array(lines, dtype=int)
    raise ValueError('这个文件类型不支持,需要自己写读取程序')


@timing_decorator()
def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """Zero-phase Butterworth filtering using second-order sections.

    Args:
        data: 1-D signal (array-like).
        type: filter kind — ``"lowpass"`` (cutoff = ``low_cut``),
            ``"bandpass"`` (``low_cut``..``high_cut``) or ``"highpass"``
            (cutoff = ``high_cut``).  The parameter name shadows the builtin
            but is kept for caller compatibility.
        low_cut: lower cutoff frequency in Hz.
        high_cut: upper cutoff frequency in Hz.
        order: Butterworth filter order.
        sample_rate: sampling rate of ``data`` in Hz.

    Returns:
        numpy.ndarray: filtered signal, same length as the input
        (``sosfiltfilt`` applies the filter forward and backward, so the
        result has zero phase distortion).

    Raises:
        ValueError: when ``type`` is not one of the three supported kinds.
    """
    nyquist = sample_rate * 0.5  # cutoffs are normalized to the Nyquist frequency
    if type == "lowpass":
        sos = signal.butter(order, low_cut / nyquist, btype='lowpass', output='sos')
    elif type == "bandpass":
        sos = signal.butter(order, [low_cut / nyquist, high_cut / nyquist],
                            btype='bandpass', output='sos')
    elif type == "highpass":
        sos = signal.butter(order, high_cut / nyquist, btype='highpass', output='sos')
    else:
        # FIX: corrected typo in the error message ("fliter" -> "filter").
        raise ValueError("Please choose a type of filter")
    return signal.sosfiltfilt(sos, np.array(data))


@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """Integer-factor downsampling of a long signal (8h+), processed in chunks.

    Args:
        original_signal: array-like, the raw signal.
        original_fs: original sampling rate in Hz.
        target_fs: target sampling rate in Hz; ``original_fs / target_fs``
            must be an exact integer.
        chunk_size: samples per processing chunk (memory/speed trade-off);
            internally rounded down to a multiple of the decimation factor.

    Returns:
        numpy.ndarray: the downsampled signal of length
        ``len(original_signal) // factor``.

    Raises:
        ValueError: if the rates are non-positive, the target rate is not
            smaller than the original, or the ratio is not an integer.

    Note:
        Each chunk is filtered independently by ``scipy.signal.decimate``
        (IIR, zero-phase), so small edge artifacts are possible at chunk
        boundaries; acceptable for long-signal screening use.
    """
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("目标采样率必须小于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")

    # The decimation factor must be an exact integer.
    downsample_factor = original_fs / target_fs
    if not float(downsample_factor).is_integer():
        raise ValueError("降采样因子必须为整数倍")
    downsample_factor = int(downsample_factor)

    # BUGFIX: chunk boundaries must be multiples of the decimation factor.
    # Previously a chunk_size not divisible by the factor (e.g. 100000 with
    # factor 3) made decimate return ceil(len/factor) samples while
    # ``start // factor`` truncated, so each chunk overwrote the last sample
    # of the previous one and the output drifted out of alignment.
    chunk_size = max(downsample_factor,
                     (chunk_size // downsample_factor) * downsample_factor)

    total_length = len(original_signal)
    output_length = total_length // downsample_factor
    downsampled_signal = np.zeros(output_length)

    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]
        chunk_downsampled = signal.decimate(chunk, downsample_factor,
                                            ftype='iir', zero_phase=True)
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            # Final partial chunk: decimate rounds up, so drop the excess.
            chunk_downsampled = chunk_downsampled[:output_length - out_start]
            out_end = output_length
        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal


@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
    """Remove the slow baseline from a signal via moving-average subtraction.

    Args:
        raw_data: 1-D signal array.
        sample_rate: sampling rate in Hz.
        window_size: moving-average window length in seconds.

    Returns:
        numpy.ndarray: ``raw_data`` minus its ``window_size``-second moving
        average (reflect-padded at the edges), i.e. the detrended residual.
    """
    kernel_length = window_size * sample_rate
    kernel = np.ones(kernel_length) / kernel_length
    baseline = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    return raw_data - baseline


def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """Merge 1-runs in a binary sequence separated by gaps shorter than a threshold.

    Args:
        state_sequence: numpy array of 0/1 states.
        time_points: numpy array, the time stamp of each state point.
        max_gap_sec: float, gaps of at most this many seconds are filled with 1.

    Returns:
        numpy.ndarray: merged state sequence (input is not modified).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    merged = state_sequence.copy()

    # Pad with zeros so runs touching either edge still produce transitions:
    # +1 marks a run start, -1 marks one past a run end.
    transitions = np.diff(np.concatenate([[0], merged, [0]]))
    run_starts = np.where(transitions == 1)[0]
    run_ends = np.where(transitions == -1)[0] - 1  # inclusive last index of each run

    # Fill the gap between each pair of consecutive runs when short enough.
    for i in range(len(run_starts) - 1):
        if run_ends[i] < len(time_points) and run_starts[i + 1] < len(time_points):
            gap_duration = time_points[run_starts[i + 1]] - time_points[run_ends[i]]
            if gap_duration <= max_gap_sec:
                merged[run_ends[i]:run_starts[i + 1]] = 1

    return merged


def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """Zero out 1-runs in a binary sequence that last shorter than a threshold.

    Args:
        state_sequence: numpy array of 0/1 states.
        time_points: numpy array, the time stamp of each state point.
        min_duration_sec: float, runs shorter than this many seconds are removed.

    Returns:
        numpy.ndarray: filtered state sequence (input is not modified).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    filtered = state_sequence.copy()

    # Same transition bookkeeping as merge_short_gaps.
    transitions = np.diff(np.concatenate([[0], filtered, [0]]))
    run_starts = np.where(transitions == 1)[0]
    run_ends = np.where(transitions == -1)[0] - 1  # inclusive last index of each run

    for i in range(len(run_starts)):
        if run_starts[i] < len(time_points) and run_ends[i] < len(time_points):
            duration = time_points[run_ends[i]] - time_points[run_starts[i]]
            if run_ends[i] == len(time_points) - 1:
                # The last window has no successor; credit it one window length
                # (assumes uniform spacing of time_points — inferred from usage).
                duration += time_points[1] - time_points[0] if len(time_points) > 1 else 0
            if duration < min_duration_sec:
                filtered[run_starts[i]:run_ends[i] + 1] = 0

    return filtered


@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100,
                               window_second=20, step_second=None):
    """Apply ``func`` over sliding windows and expand to a per-second series.

    Args:
        func: callable mapping a 1-D window of samples to a scalar.
        signal_data: 1-D sample array.
        calc_mask: per-position truthy mask (e.g. body-movement flags); the
            first ``len(signal_data) // sampling_rate`` entries are consumed,
            one per output second — truthy seconds become NaN.  ``None``
            disables masking.
        sampling_rate: samples per second of ``signal_data``.
        window_second: window length in seconds.
        step_second: hop between windows in seconds; defaults to
            ``window_second`` (non-overlapping windows).

    Returns:
        tuple: ``(values_nan, values)`` — the per-second series with NaN at
        masked seconds, and the same series with NaNs linearly interpolated.
    """
    if calc_mask is None:
        calc_mask = np.zeros(len(signal_data), dtype=bool)
    if step_second is None:
        step_second = window_second

    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate
    origin_seconds = len(signal_data) // sampling_rate

    # Reflect-pad so each window is centered on its step position.
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    padded = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)
    for i in range(num_segments):
        start = int(i * step_length)
        values_nan[i] = func(padded[start:start + window_length])

    # One value per second, trimmed to the original duration, then blank out
    # masked seconds (vectorized; replaces the former per-element loop).
    values_nan = values_nan.repeat(step_second)[:origin_seconds]
    second_mask = np.asarray(calc_mask[:len(values_nan)], dtype=bool)
    values_nan[second_mask] = np.nan

    def _interpolate_nans(x, t):
        # Linear interpolation across NaN positions using the valid samples.
        valid = ~np.isnan(x)
        return np.interp(t, t[valid], x[valid])

    values = _interpolate_nans(values_nan.copy(), np.arange(len(values_nan)))
    return values_nan, values


def load_dataset_info(yaml_path):
    """Read a dataset config YAML and return its selection list and root path.

    Args:
        yaml_path: path to a YAML file with keys ``select_id`` (list) and
            ``root_path`` (string).

    Returns:
        tuple: ``(select_ids, data_path)`` where ``data_path`` is a
        ``pathlib.Path``.

    Raises:
        ValueError: when ``root_path`` is missing (previously this surfaced
            as an opaque ``TypeError`` from ``Path(None)``).
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    select_ids = config.get('select_id', [])
    root_path = config.get('root_path', None)
    if root_path is None:
        raise ValueError(f"'root_path' missing in config file: {yaml_path}")
    return select_ids, Path(root_path)


def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    """Build a per-second validity mask: 1 = usable, 0 = disabled.

    Args:
        signal_second: total signal duration in seconds.
        disable_df: DataFrame with integer ``start``/``end`` second columns;
            each row zeroes the half-open range ``[start, end)``.

    Returns:
        numpy.ndarray of shape ``(signal_second,)`` with dtype int.
    """
    disable_mask = np.ones(signal_second, dtype=int)
    for _, row in disable_df.iterrows():
        disable_mask[row["start"]:row["end"]] = 0
    return disable_mask


def generate_event_mask(signal_second: int, event_df) -> np.ndarray:
    """Build a per-second event mask: 1 = event present, 0 = no event.

    Args:
        signal_second: total signal duration in seconds.
        event_df: DataFrame with integer ``start``/``end`` second columns;
            each row sets the half-open range ``[start, end)`` to 1.

    Returns:
        numpy.ndarray of shape ``(signal_second,)`` with dtype int.
    """
    event_mask = np.zeros(signal_second, dtype=int)
    for _, row in event_df.iterrows():
        event_mask[row["start"]:row["end"]] = 1
    return event_mask