DataPrepare/utils/operation_tools.py

209 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import yaml
from utils.event_map import E2N
# Render Chinese labels with the SimHei font, and draw minus signs with the ASCII
# hyphen instead of the Unicode minus glyph so negative numbers display correctly.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
from scipy import ndimage, signal
from functools import wraps
# Global configuration
class Config:
    """Module-wide runtime switches."""

    # Read by the wrapper built in ``timing_decorator`` on every call: when True,
    # each decorated call prints its elapsed time.
    time_verbose = False
def timing_decorator():
    """Build a decorator that measures how long each call to the wrapped function takes.

    The elapsed time (in seconds) is printed only when ``Config.time_verbose`` is truthy.
    The flag is read on every call, so it can be toggled after functions are decorated.

    Returns:
        A decorator. The wrapper it produces returns the wrapped function's result unchanged
        and keeps its metadata via ``functools.wraps``.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution. time.time() follows the wall
            # clock, which can jump (NTP, manual changes) and corrupt duration measurements.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_time = time.perf_counter() - start_time
            if Config.time_verbose:  # checked at run time so the global switch applies immediately
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f}")
            return result
        return wrapper
    return decorator
@timing_decorator()
def read_auto(file_path):
    """Load a signal from disk, choosing the reader by the file extension.

    Supported extensions:
        ``.txt``    -- comma-separated text read with ``pandas.read_csv`` (no header),
                       flattened to 1-D.
        ``.npy``    -- NumPy binary file loaded with ``np.load``.
        ``.base64`` -- text file with one integer per line; blank lines are ignored.
                       NOTE(review): despite the extension, the content is parsed as plain
                       integers, not base64-decoded -- confirm this is intended.

    Parameters:
        file_path: ``pathlib.Path`` or path string.

    Returns:
        numpy.ndarray with the file contents.

    Raises:
        ValueError: for unsupported extensions, or a non-integer line in a ``.base64`` file.
    """
    file_path = Path(file_path)  # accept plain strings as well as Path objects
    suffix = file_path.suffix
    if suffix == '.txt':
        # Use pandas read_csv to read the txt file.
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(file_path)
    if suffix == '.base64':
        with open(file_path, encoding='utf-8') as f:
            # Skip blank lines (e.g. an extra trailing newline), which would otherwise fail to
            # parse as integers. int() itself tolerates surrounding whitespace/newlines.
            values = [int(line) for line in f if line.strip()]
        return np.array(values, dtype=int)
    raise ValueError('这个文件类型不支持,需要自己写读取程序')
def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """Fill short stretches of 0 between consecutive runs of 1 in a binary state sequence.

    Parameters:
        state_sequence: numpy array of 0/1 states.
        time_points: numpy array with the time of each state sample.
        max_gap_sec: largest gap that is still merged. The gap is measured as
            ``time_points[first 1 of next run] - time_points[last 1 of previous run]``,
            i.e. between the bounding 1-samples, not over the 0-samples themselves.

    Returns:
        A new array with every qualifying gap set to 1. Inputs of length <= 1 are
        returned as-is (the same object, not a copy).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    result = state_sequence.copy()
    # Pad with 0 on both sides so every run of 1s yields exactly one +1 and one -1 edge.
    edges = np.diff(np.concatenate([[0], result, [0]]))
    run_first = np.where(edges == 1)[0]       # index of the first 1 in each run
    run_last = np.where(edges == -1)[0] - 1   # index of the last 1 in each run

    n_time = len(time_points)
    # Pair the end of each run with the start of the run that follows it.
    for prev_last, next_first in zip(run_last[:-1], run_first[1:]):
        # Skip pairs that fall outside the available time stamps.
        if prev_last >= n_time or next_first >= n_time:
            continue
        if time_points[next_first] - time_points[prev_last] <= max_gap_sec:
            result[prev_last:next_first] = 1
    return result
def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """Zero out runs of 1 in a binary state sequence that last less than a threshold.

    A run's duration is ``time_points[last] - time_points[first]``. When the run ends on
    the final time point, one step (``time_points[1] - time_points[0]``) is added.
    NOTE(review): only runs touching the end receive that extra step, so an isolated
    single sample elsewhere has duration 0 -- confirm this asymmetry is intended.

    Parameters:
        state_sequence: numpy array of 0/1 states.
        time_points: numpy array with the time of each state sample.
        min_duration_sec: runs whose duration is strictly below this value are set to 0.

    Returns:
        A new array with the short runs removed. Inputs of length <= 1 are returned
        as-is (the same object, not a copy).
    """
    if len(state_sequence) <= 1:
        return state_sequence

    kept = state_sequence.copy()
    # Zero padding guarantees a +1 edge at each run start and a -1 edge after each run end.
    edges = np.diff(np.concatenate([[0], kept, [0]]))
    run_first = np.where(edges == 1)[0]
    run_last = np.where(edges == -1)[0] - 1

    n_time = len(time_points)
    final_index = n_time - 1
    for first, last in zip(run_first, run_last):
        # Only runs fully covered by time stamps are evaluated.
        if first >= n_time or last >= n_time:
            continue
        duration = time_points[last] - time_points[first]
        if last == final_index:
            # Account for the length of the final window.
            duration += time_points[1] - time_points[0] if n_time > 1 else 0
        if duration < min_duration_sec:
            kept[first:last + 1] = 0
    return kept
@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Apply ``func`` over sliding windows of a signal and expand the result to one value per second.

    The signal is reflect-padded by half a window on each side, so window ``i`` starts half a
    window before original sample ``i * step_second * sampling_rate`` (i.e. is centred on it).
    Each window's scalar result is repeated ``step_second`` times to form a per-second series.
    That series is truncated to the number of whole seconds in ``signal_data``.

    Seconds flagged in ``calc_mask`` are set to NaN afterwards. ``func`` still receives the
    full window: flagged samples are NOT excluded from the computation. The NaN seconds are
    then filled by linear interpolation from the remaining valid seconds.

    Parameters:
        func: callable mapping a 1-D window (ndarray) to a scalar.
        signal_data: 1-D array of samples.
        calc_mask: per-second flags (truthy -> NaN), at least as long as the number of whole
            seconds in ``signal_data``; ``None`` means nothing is masked.
        sampling_rate: samples per second.
        window_second: window length in seconds.
        step_second: hop between windows in seconds; defaults to ``window_second``.
            Must be an integer because values are repeated once per second.

    Returns:
        ``(values_nan, values)``: the per-second series with masked seconds as NaN, and the
        same series with NaNs linearly interpolated. If no valid value remains (everything
        masked/NaN, or less than one second of signal), ``values`` is an all-NaN copy.
        Previously ``np.interp`` raised ``ValueError`` in that case.
    """
    if step_second is None:
        step_second = window_second
    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate
    origin_seconds = len(signal_data) // sampling_rate

    # Reflect-pad so windows near the edges still have full length.
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    segment_values = np.full(num_segments, np.nan)
    for i in range(num_segments):
        start = int(i * step_length)
        segment_values[i] = func(data[start:start + window_length])

    # One value per second: each window's result covers ``step_second`` seconds.
    values_nan = segment_values.repeat(step_second)[:origin_seconds]
    if calc_mask is not None:
        # Vectorised form of "set NaN wherever calc_mask[i] is truthy". A mask shorter than
        # origin_seconds still raises IndexError, as the original element-wise loop did.
        mask = np.asarray(calc_mask)[:origin_seconds].astype(bool)
        values_nan[mask] = np.nan

    # Linearly interpolate the NaN seconds from the valid ones.
    valid = ~np.isnan(values_nan)
    if not valid.any():
        # np.interp raises on an empty set of sample points; there is nothing to interpolate from.
        return values_nan, values_nan.copy()
    seconds = np.arange(len(values_nan))
    values = np.interp(seconds, seconds[valid], values_nan[valid])
    return values_nan, values
def load_dataset_conf(yaml_path):
    """Parse a UTF-8 YAML dataset configuration file with ``yaml.safe_load``.

    Parameters:
        yaml_path: path of the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (typically a dict; ``None``
        for an empty file).
    """
    with open(yaml_path, mode='r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
    """Build a per-second mask that is 0 inside the listed intervals and 1 elsewhere.

    Parameters:
        signal_second: length of the mask.
        disable_df: DataFrame with ``start`` and ``end`` columns; each row zeroes
            ``mask[start:end]`` (``end`` exclusive, Python slice semantics). Integral float
            bounds are accepted and cast with ``int()``.

    Returns:
        int ndarray of shape ``(signal_second,)``.
    """
    disable_mask = np.ones(signal_second, dtype=int)
    # Iterate the two columns directly instead of DataFrame.iterrows(). iterrows() upcasts each
    # row to a common dtype, so integer bounds become float64 whenever another column is float,
    # and float slice bounds raise TypeError.
    for start, end in zip(disable_df["start"], disable_df["end"]):
        disable_mask[int(start):int(end)] = 0
    return disable_mask
def generate_event_mask(signal_second: int, event_df):
    """Rasterise annotated events into per-second label and score masks.

    Rows whose ``correct_Start`` is negative (e.g. -1) are skipped. Every other row fills the
    inclusive range ``[correct_Start, correct_End]``: ``event_mask`` gets
    ``E2N[correct_EventsType]`` and ``score_mask`` gets ``score``. Later rows overwrite earlier
    ones where ranges overlap.

    Parameters:
        signal_second: length of both masks.
        event_df: DataFrame with columns ``correct_Start``, ``correct_End``,
            ``correct_EventsType`` and ``score``.

    Returns:
        ``(event_mask, score_mask)``: int ndarrays of shape ``(signal_second,)``; seconds
        without an event are 0 in both.
        NOTE(review): ``score_mask`` is int, so fractional scores are truncated -- confirm intended.
    """
    event_mask = np.zeros(signal_second, dtype=int)
    score_mask = np.zeros(signal_second, dtype=int)
    # Drop rows flagged with a negative start (e.g. start = -1).
    event_df = event_df[event_df["correct_Start"] >= 0]
    # Iterate the needed columns directly instead of iterrows(). iterrows() upcasts each row to a
    # common dtype and can turn integer bounds into floats, which are invalid slice indices.
    rows = zip(event_df["correct_Start"], event_df["correct_End"],
               event_df["correct_EventsType"], event_df["score"])
    for start, end, event_type, score in rows:
        start = int(start)
        end = int(end) + 1  # correct_End is inclusive
        event_mask[start:end] = E2N[event_type]
        score_mask[start:end] = score
    return event_mask, score_mask