import time
from functools import wraps
from pathlib import Path

import numpy as np
import pandas as pd
import yaml
from matplotlib import pyplot as plt
from scipy import ndimage, signal

from utils.event_map import E2N

plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render Chinese labels
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly


# Global configuration
class Config:
    # When True, functions wrapped by ``timing_decorator`` print their run time.
    time_verbose = False


def timing_decorator():
    """Return a decorator that measures the wall-clock time of each call.

    The elapsed time is printed only when ``Config.time_verbose`` is truthy at
    call time, so the flag can be toggled after the functions are decorated.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, unlike time.time(), so the measured
            # interval cannot be distorted by system clock adjustments.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_time = time.perf_counter() - start_time
            if Config.time_verbose:  # checked at run time, not at decoration time
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒")
            return result

        return wrapper

    return decorator


@timing_decorator()
def read_auto(file_path):
    """Load a signal file, choosing the reader from the file suffix.

    Parameters:
    - file_path: ``pathlib.Path`` (or path string) of the file to read.

    Returns:
    - ``.txt``: the file parsed with ``pandas.read_csv(header=None)`` and
      flattened to a 1-D numpy array.
    - ``.npy``: the array returned by ``numpy.load``.
    - ``.base64``: one integer per text line, as an integer numpy array.

    Raises:
    - ValueError: for any other suffix.
    """
    file_path = Path(file_path)  # also accept plain string paths
    suffix = file_path.suffix

    if suffix == '.txt':
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(str(file_path))
    if suffix == '.base64':
        with open(file_path) as f:
            lines = f.readlines()
        return np.array(lines, dtype=int)
    raise ValueError('这个文件类型不支持,需要自己写读取程序')


def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """
    Merge runs of state 1 that are separated by a short gap.

    Parameters:
    - state_sequence: numpy array, state sequence (0/1)
    - time_points: numpy array, time point of every state sample
    - max_gap_sec: float, largest gap (seconds) that is still merged

    Returns:
    - merged_sequence: numpy array, state sequence after merging. Inputs with
      at most one element are returned unchanged (same object).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    merged_sequence = state_sequence.copy()

    # Zero-pad both sides so runs touching either border still yield an edge.
    transitions = np.diff(np.concatenate([[0], merged_sequence, [0]]))
    state_starts = np.where(transitions == 1)[0]     # first index of each run
    state_ends = np.where(transitions == -1)[0] - 1  # last index of each run

    n_times = len(time_points)
    # Pair the end of each run with the start of the following run.
    for prev_end, next_start in zip(state_ends[:-1], state_starts[1:]):
        if prev_end >= n_times or next_start >= n_times:
            continue  # no timestamp available for this pair
        gap_duration = time_points[next_start] - time_points[prev_end]
        if gap_duration <= max_gap_sec:
            merged_sequence[prev_end:next_start] = 1

    return merged_sequence


def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """
    Remove runs of state 1 whose duration is below a threshold.

    Parameters:
    - state_sequence: numpy array, state sequence (0/1)
    - time_points: numpy array, time point of every state sample
    - min_duration_sec: float, minimum duration (seconds) for a run to be kept

    Returns:
    - filtered_sequence: numpy array, state sequence after filtering. Inputs
      with at most one element are returned unchanged (same object).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    filtered_sequence = state_sequence.copy()

    # Zero-pad both sides so runs touching either border still yield an edge.
    transitions = np.diff(np.concatenate([[0], filtered_sequence, [0]]))
    state_starts = np.where(transitions == 1)[0]     # first index of each run
    state_ends = np.where(transitions == -1)[0] - 1  # last index of each run

    n_times = len(time_points)
    # A run ending on the final time point is extended by one step, taken as
    # the spacing between the first two time points.
    tail_extension = time_points[1] - time_points[0] if n_times > 1 else 0

    for start, end in zip(state_starts, state_ends):
        if start >= n_times or end >= n_times:
            continue  # no timestamp available for this run
        duration = time_points[end] - time_points[start]
        if end == n_times - 1:
            duration += tail_extension
        if duration < min_duration_sec:
            filtered_sequence[start:end + 1] = 0

    return filtered_sequence


@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Evaluate ``func`` on sliding windows and expand the result per second.

    The signal is reflect-padded by half a window on each side, so window ``i``
    is centred on sample ``i * step_second * sampling_rate`` of the original
    signal. Each window value is repeated ``step_second`` times and the result
    is truncated to ``len(signal_data) // sampling_rate`` entries.

    Parameters:
    - func: callable that maps one window (1-D array) to a scalar
    - signal_data: 1-D array of samples
    - calc_mask: optional sequence with at least one entry per output value;
      truthy entries mark values to discard (set to NaN). ``None`` disables
      masking and interpolation.
    - sampling_rate: samples per second
    - window_second: window length in seconds
    - step_second: hop between windows in seconds; defaults to ``window_second``

    Returns:
    - values_nan: per-second values with masked entries set to NaN
    - values: copy of ``values_nan`` in which NaN entries are filled by linear
      interpolation (only when ``calc_mask`` is given). If no valid entry
      exists, ``values`` stays all-NaN.
    """
    if step_second is None:
        step_second = window_second

    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate
    origin_seconds = len(signal_data) // sampling_rate

    # Reflect padding: half a window on the left, the remainder on the right.
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)
    for i in range(num_segments):
        start = int(i * step_length)
        values_nan[i] = func(data[start:start + window_length])

    # One value per second of the original signal.
    values_nan = values_nan.repeat(step_second)[:origin_seconds]

    if calc_mask is None:
        return values_nan, values_nan.copy()

    # Discard the masked seconds; only the first len(values_nan) mask entries
    # are used.
    discard = np.asarray(calc_mask)[:len(values_nan)].astype(bool)
    values_nan[discard] = np.nan

    valid = ~np.isnan(values_nan)
    if not valid.any():
        # np.interp raises on an empty set of sample points; with nothing to
        # interpolate from, the filled series is left as NaN.
        return values_nan, values_nan.copy()

    positions = np.arange(len(values_nan))
    values = np.interp(positions, positions[valid], values_nan[valid])
    return values_nan, values


def load_dataset_conf(yaml_path):
    """Read a UTF-8 YAML file with ``yaml.safe_load`` and return the result."""
    with open(yaml_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    """Build a 0/1 mask of length ``signal_second`` from disabled intervals.

    Every row of ``disable_df`` sets ``mask[start:end]`` to 1 using its
    ``"start"`` and ``"end"`` columns (``end`` is exclusive).
    """
    disable_mask = np.zeros(signal_second, dtype=int)
    # Iterate the columns directly: iterrows() may upcast integer bounds to
    # float, which cannot be used as slice indices.
    for start, end in zip(disable_df["start"], disable_df["end"]):
        disable_mask[int(start):int(end)] = 1
    return disable_mask


def generate_event_mask(signal_second: int, event_df):
    """Build event-type and score masks of length ``signal_second``.

    Rows with a negative ``"correct_Start"`` are skipped. For every remaining
    row the span ``correct_Start .. correct_End`` (``correct_End`` inclusive)
    is filled with ``E2N[correct_EventsType]`` in the event mask and with
    ``score`` in the score mask. Later rows overwrite earlier ones where the
    spans overlap.

    Returns:
    - event_mask: integer numpy array of event codes
    - score_mask: integer numpy array of scores
    """
    event_mask = np.zeros(signal_second, dtype=int)
    score_mask = np.zeros(signal_second, dtype=int)

    # Drop rows whose start is negative (e.g. start == -1).
    event_df = event_df[event_df["correct_Start"] >= 0]

    rows = zip(
        event_df["correct_Start"],
        event_df["correct_End"],
        event_df["correct_EventsType"],
        event_df["score"],
    )
    for start, end, event_type, score in rows:
        # int() guards against float-typed columns, which cannot be used as
        # slice indices.
        span = slice(int(start), int(end) + 1)
        event_mask[span] = E2N[event_type]
        score_mask[span] = score

    return event_mask, score_mask


def event_mask_2_list(mask):
    """Convert a mask into a list of ``[start, end]`` index pairs.

    Every maximal run of non-zero entries yields one pair, where ``start`` is
    the first index of the run and ``end`` is one past its last index
    (half-open, so ``mask[start:end]`` is the run). Runs touching either
    border of the mask are included. Adjacent runs with different non-zero
    values are reported as a single run.
    """
    # Binarise first: boolean and multi-valued masks would otherwise produce
    # differences other than +1/-1 at the run borders.
    active = (np.asarray(mask) != 0).astype(np.int8)
    # Zero-pad both sides so runs at index 0 or at the end still yield an edge.
    edges = np.diff(active, prepend=0, append=0)
    run_starts = np.where(edges == 1)[0]
    run_ends = np.where(edges == -1)[0]
    return [[int(start), int(end)] for start, end in zip(run_starts, run_ends)]


def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list:
    """Collect up to ``limit`` values of ``arr`` while walking from ``index``.

    The walk moves by ``step`` per iteration and stops once ``limit`` values
    were collected or the index leaves ``[0, len(mask))``. A position
    contributes ``arr[index]`` only when ``mask[index]`` is not NaN; ``mask``
    defaults to ``arr`` itself.

    ``step`` must be non-zero: with a zero step the walk never leaves a NaN
    position.
    """
    values = []
    count = 0
    mask = mask if mask is not None else arr
    while count < limit and 0 <= index < len(mask):
        if not np.isnan(mask[index]):
            values.append(arr[index])
            count += 1
        index += step
    return values