Add data processing and visualization modules for signal analysis
This commit is contained in:
parent
a3d4087810
commit
805f1dc7f8
21
HYS_process.py
Normal file
21
HYS_process.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
本脚本完成对呼研所数据的处理,包含以下功能:
|
||||
1. 数据读取与预处理
|
||||
从传入路径中,进行数据和标签的读取,并进行初步的预处理
|
||||
预处理包括为数据进行滤波、去噪等操作
|
||||
2. 数据清洗与异常值处理
|
||||
3. 输出清洗后的统计信息
|
||||
4. 数据保存
|
||||
将处理后的数据保存到指定路径,便于后续使用
|
||||
主要是保存切分后的数据位置和标签
|
||||
5. 可视化
|
||||
提供数据处理前后的可视化对比,帮助理解数据变化
|
||||
绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等
|
||||
|
||||
|
||||
|
||||
|
||||
# 低幅值区间规则标定与剔除
|
||||
# 高幅值连续体动规则标定与剔除
|
||||
# 手动标定不可用区间并剔除
|
||||
"""
|
0
SHHS_process.py
Normal file
0
SHHS_process.py
Normal file
0
draw_tools/__init__.py
Normal file
0
draw_tools/__init__.py
Normal file
175
draw_tools/draw_statics.py
Normal file
175
draw_tools/draw_statics.py
Normal file
@ -0,0 +1,175 @@
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.gridspec import GridSpec
|
||||
from matplotlib.colors import PowerNorm
|
||||
import seaborn as sns
|
||||
import numpy as np
|
||||
|
||||
|
||||
def draw_ax_confusion_matrix(ax: Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent,
                             valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''):
    """Render an amplitude-vs-duration statistics matrix as an annotated heatmap on *ax*.

    Args:
        ax: target matplotlib Axes.
        confusion_matrix: pandas DataFrame of per-cell durations, with a '总计' (total) column.
        segment_count_matrix: numpy 2-D array of per-cell segment counts.
        confusion_matrix_percent: pandas DataFrame of per-cell percentages, with a '总计' column.
        valid_signal_length: total valid (non-masked) signal duration in seconds.
        total_duration: total recording duration in seconds.
        time_labels: column labels (duration bins).
        amp_labels: row labels (amplitude bins).
        signal_type: unused here — NOTE(review): confirm whether it was meant for the title.
    """
    # Build the text annotations ("[count]duration\n(percent%)") and the numeric
    # percent matrix that drives the heatmap colors.
    text_matrix = np.empty((len(amp_labels), len(time_labels)), dtype=object)
    percent_matrix = np.zeros((len(amp_labels), len(time_labels)))

    # Fill both matrices cell by cell.
    for i in range(len(amp_labels)):
        for j in range(len(time_labels)):
            val = confusion_matrix.iloc[i, j]
            segment_count = segment_count_matrix[i, j]
            percent = confusion_matrix_percent.iloc[i, j]
            text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)"
            percent_matrix[i, j] = percent

    # Heatmap; PowerNorm(gamma=0.5) boosts contrast for small percentages.
    sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='',
                              xticklabels=time_labels, yticklabels=amp_labels,
                              cmap='YlGnBu', ax=ax, vmin=0, vmax=100,
                              norm=PowerNorm(gamma=0.5, vmin=0, vmax=100),
                              cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15},
                              # annot_kws={'fontsize': 12}
                              )

    # Row totals (right side of the heatmap).
    row_sums = confusion_matrix['总计']
    row_percents = confusion_matrix_percent['总计']
    ax.text(len(time_labels) + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center')
    for i, (val, perc) in enumerate(zip(row_sums, row_percents)):
        ax.text(len(time_labels) + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)",
                ha='center', va='center')

    # Column totals (below the heatmap): segment counts plus duration percentages.
    col_sums = segment_count_matrix.sum(axis=0)
    col_percents = confusion_matrix.sum(axis=0) / total_duration * 100
    ax.text(-1, len(amp_labels) + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0)
    for j, (val, perc) in enumerate(zip(col_sums, col_percents)):
        ax.text(j + 0.5, len(amp_labels) + 0.5, f"[{int(val)}]\n({perc:.2f}%)",
                ha='center', va='center', rotation=0)

    # Move the x axis (labels and ticks) to the top edge.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()

    # Title and axis labels.
    ax.set_title('幅值-时长统计矩阵', pad=40)
    ax.set_xlabel('持续时间区间 (秒)', labelpad=10)
    ax.set_ylabel('幅值区间')

    # Keep tick labels horizontal.
    ax.set_xticklabels(time_labels, rotation=0)
    ax.set_yticklabels(amp_labels, rotation=0)

    # Reposition the colorbar label.
    cbar = sns_heatmap.collections[0].colorbar
    cbar.ax.yaxis.set_label_position('right')

    # Legend explaining the in-cell annotation format.
    ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)",
            ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5))

    # Grand total cell (bottom-right corner): total segments, valid duration,
    # and valid duration as a share of the whole recording.
    total_segments = segment_count_matrix.sum()
    total_percent = valid_signal_length / total_duration * 100
    ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)",
            ha='center', va='center')
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
                movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
    """Plot one signal trace with shaded movement / low-amplitude intervals and a MAV overlay.

    Args:
        ax: target matplotlib Axes.
        signal_name: label used in the subplot title.
        original_times: x-axis values (seconds) for origin_signal.
        origin_signal: raw signal samples.
        no_movement_signal: unused here — NOTE(review): confirm whether it should be plotted.
        mav_values: per-second moving-average amplitude values overlaid in blue.
        movement_position_list: [[start, end], ...] intervals shaded red.
        low_amp_position_list: [[start, end], ...] intervals shaded blue.
        signal_second_length: total duration in seconds (extent of the hlines).
        aml_list: optional amplitude levels to draw as dashed horizontal reference lines.
    """
    # Raw signal trace.
    ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7)

    # Shade movement intervals in red.
    # NOTE(review): start/end are compared against len(original_times) (a sample
    # count) but drawn in x-axis (seconds) units — confirm the positions are in
    # seconds and that this bound check is the intended guard.
    for start, end in movement_position_list:
        if start < len(original_times) and end < len(original_times):
            ax.axvspan(start, end, color='red', alpha=0.3)

    # Shade low-amplitude intervals in blue.
    for start, end in low_amp_position_list:
        if start < len(original_times) and end < len(original_times):
            ax.axvspan(start, end, color='blue', alpha=0.2)

    # Optional dashed horizontal reference lines at the given amplitude levels.
    if aml_list is not None:
        color_map = ['red', 'orange', 'green']
        for i, aml in enumerate(aml_list):
            ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed', linewidth=2, alpha=0.5, label=f'{aml} aml')

    # 2-second MAV overlay (one value per second).
    ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values, color='blue', linewidth=2, alpha=0.4, label='2sMAV')

    # Fixed y-axis range.
    ax.set_ylim((-2000, 2000))

    # Proxy legend patches for the shaded regions.
    red_patch = mpatches.Patch(color='red', alpha=0.2, label='Movement Area')
    blue_patch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area')

    # Merge the proxy patches with any labels added above (hlines, MAV).
    handles, labels = ax.get_legend_handles_labels()  # existing legend entries
    ax.legend(handles=[red_patch, blue_patch] + handles,
              labels=['Movement Area', 'Low Amplitude Area'] + labels,
              loc='upper right',
              bbox_to_anchor=(1, 1.4),
              framealpha=0.5)

    # Title and axis labels.
    ax.set_title(f'{signal_name} Signal')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Time (s)')

    # Background grid.
    ax.grid(True, linestyle='--', alpha=0.7)
|
||||
|
||||
|
||||
def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal,
                        bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list,
                        resp_movement_position_list, resp_low_amp_position_list,
                        bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info,
                        signal_info, show=False, save_path=None):
    """Draw the full BCG + respiration overview figure: amplitude plots (left column)
    and amplitude-duration statistics heatmaps (right column).

    Args:
        bcg_origin_signal / resp_origin_signal: raw sample arrays.
        bcg_no_movement_signal / resp_no_movement_signal: passed through to draw_ax_amp.
        bcg_sampling_rate / resp_sampling_rate: sampling rates in Hz
            (resp_sampling_rate is unused — NOTE(review): confirm intentional).
        *_movement_position_list / *_low_amp_position_list: shaded interval lists.
        bcg_mav_values / resp_mav_values: per-second MAV overlays.
        bcg_statistic_info / resp_statistic_info: tuples in the order expected by
            draw_ax_confusion_matrix (see param_names below).
        signal_info: figure-title prefix.
        show: call plt.show() when True.
        save_path: save the figure (dpi=300) when not None.
    """
    # 2x2 grid: amplitude views on the left, statistics heatmaps on the right.
    fig = plt.figure(figsize=(18, 10))

    gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5)

    signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate
    # NOTE(review): both time axes reuse the BCG-derived duration; confirm the
    # respiration recording covers the same span.
    bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal))
    resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal))
    # Subplot 1: BCG amplitude view.
    ax1 = fig.add_subplot(gs[0])
    draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal,
                no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values,
                movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[200, 500, 1000])

    # Subplot 2: BCG statistics heatmap, built from the positional statistic tuple.
    ax2 = fig.add_subplot(gs[1])
    param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent',
                   'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels']
    params = dict(zip(param_names, bcg_statistic_info))
    params['ax'] = ax2
    draw_ax_confusion_matrix(**params)

    # Subplot 3: respiration amplitude view, x-axis shared with the BCG plot.
    # NOTE(review): 'RSEP' looks like a typo for 'RESP' in the rendered title.
    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    draw_ax_amp(ax=ax3, signal_name='RSEP', original_times=resp_origin_times, origin_signal=resp_origin_signal,
                no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values,
                movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list,
                signal_second_length=signal_second_length, aml_list=[100, 300, 500])
    # Subplot 4: respiration statistics heatmap.
    ax4 = fig.add_subplot(gs[3])
    params = dict(zip(param_names, resp_statistic_info))
    params['ax'] = ax4
    draw_ax_confusion_matrix(**params)

    # Figure-level title.
    fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95)

    if save_path is not None:
        # Persist the figure.
        plt.savefig(save_path, dpi=300)

    if show:
        plt.show()

    plt.close()
|
0
signal_method/__init__.py
Normal file
0
signal_method/__init__.py
Normal file
207
signal_method/rule_base_event.py
Normal file
207
signal_method/rule_base_event.py
Normal file
@ -0,0 +1,207 @@
|
||||
from utils.operation_tools import timing_decorator
|
||||
import numpy as np
|
||||
from utils.operation_tools import merge_short_gaps, remove_short_durations
|
||||
|
||||
@timing_decorator()
def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None,
                                amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5):
    """
    Detect low-amplitude stretches of a signal by thresholding windowed RMS values.

    Args:
        signal_data: numpy array, input signal samples.
        sampling_rate: int, sampling rate in Hz.
        window_size_sec: float, analysis window length in seconds (default 1).
        stride_sec: float, window stride in seconds; None means window_size_sec (no overlap).
        amplitude_threshold: float, RMS at or below this value marks a window low-amplitude (default 50).
        merge_gap_sec: float, gaps between low-amplitude runs no longer than this are merged (default 10 s).
        min_duration_sec: float, low-amplitude runs shorter than this are dropped (default 5 s).

    Returns:
        low_amplitude_mask: numpy array, one value per second (1 = low amplitude, 0 = normal).
        low_amplitude_position_list: list of [start, end] second indices, one per low-amplitude run.
    """
    # Window length in samples.
    window_samples = int(window_size_sec * sampling_rate)

    # Default stride = window size (non-overlapping windows).
    if stride_sec is None:
        stride_sec = window_size_sec

    # Stride in samples.
    stride_samples = int(stride_sec * sampling_rate)

    # Clamp to at least one sample.
    stride_samples = max(1, stride_samples)

    # Reflect-pad the boundaries so edge windows would be full length.
    pad_size = window_samples // 2
    padded_signal = np.pad(signal_data, pad_size, mode='reflect')

    # Window count derived from the PADDED length.
    num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1)

    # Per-window RMS values.
    rms_values = np.zeros(num_windows)

    # NOTE(review): windows are sliced from the UNPADDED signal_data although
    # num_windows came from the padded length, so trailing windows can be short
    # or empty (handled by the len check below) — confirm this is intended.
    for i in range(num_windows):
        start_idx = i * stride_samples
        end_idx = min(start_idx + window_samples, len(signal_data))

        # Window slice; the last one may be partial.
        window_data = signal_data[start_idx:end_idx]
        if len(window_data) > 0:
            rms_values[i] = np.sqrt(np.mean(window_data ** 2))
        else:
            rms_values[i] = 0

    # Initial mask: 1 where the window RMS is at or below the threshold.
    low_amplitude_mask = np.where(rms_values <= amplitude_threshold, 1, 0)

    # Window index range corresponding to the original (unpadded) signal.
    # NOTE(review): the special case keys off stride_sec == 1, not off
    # stride_sec == window_size_sec — confirm this branch boundary.
    orig_start_window = pad_size // stride_samples
    if stride_sec == 1:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples)
    else:
        orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1

    # Keep only the windows that map onto the original signal.
    low_amplitude_mask = low_amplitude_mask[orig_start_window:orig_end_window]
    # print("low_amplitude_mask_length: ", len(low_amplitude_mask))
    num_original_windows = len(low_amplitude_mask)

    # Timestamp (seconds) of each retained window, used by the gap/duration helpers.
    time_points = np.arange(num_original_windows) * stride_sec

    # Merge low-amplitude runs separated by short gaps.
    if merge_gap_sec > 0:
        low_amplitude_mask = merge_short_gaps(low_amplitude_mask, time_points, merge_gap_sec)

    # Drop low-amplitude runs that are too short.
    if min_duration_sec > 0:
        low_amplitude_mask = remove_short_durations(low_amplitude_mask, time_points, min_duration_sec)

    # Expand to a 1 Hz mask spanning the whole signal duration.
    # NOTE(review): ndarray.repeat needs an integer count — a fractional
    # stride_sec would fail here; confirm callers only pass whole seconds.
    low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]

    # Run boundaries as [[start, end], ...] second indices.
    low_amplitude_start = np.where(np.diff(np.concatenate([[0], low_amplitude_mask])) == 1)[0]
    low_amplitude_end = np.where(np.diff(np.concatenate([low_amplitude_mask, [0]])) == -1)[0]
    low_amplitude_position_list = [[start, end] for start, end in zip(low_amplitude_start, low_amplitude_end)]

    return low_amplitude_mask, low_amplitude_position_list
|
||||
|
||||
|
||||
def get_typical_segment_for_continues_signal(signal_data, sampling_rate=100, window_size=30, step_size=1):
    """
    Select typical segments from a continuous signal (planned: ten segments).

    Not implemented yet — currently returns None.

    :param signal_data: signal samples
    :param sampling_rate: sampling rate (Hz)
    :param window_size: window size (seconds)
    :param step_size: step size (seconds)
    :return: list of typical segments (once implemented)
    """
    pass
|
||||
|
||||
|
||||
# 基于体动位置和幅值的睡姿识别
|
||||
# 主要是依靠体动mask,将整夜分割成多个有效片段,然后每个片段计算幅值指标,判断两个片段的幅值指标是否存在显著差异,如果存在显著差异,则认为存在睡姿变化
|
||||
# 考虑到每个片段长度为10s,所以每个片段的幅值指标计算时间长度为10s,然后计算每个片段的幅值指标
|
||||
# 仅对比相邻片段的幅值指标,如果存在显著差异,则认为存在睡姿变化,即每个体动相邻的30秒内存在睡姿变化,如果片段不足30秒,则按实际长度对比
|
||||
|
||||
@timing_decorator()
def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=100, window_size_sec=30,
                                     interval_to_movement=10):
    """Detect sleep-posture changes by comparing amplitude/energy around body movements.

    The night is split into movement-free segments via movement_mask (per-second,
    1 = movement). For each segment a window near its left edge and one near its
    right edge are measured; a significant jump in both amplitude and energy
    between consecutive segments is reported as a posture change at the movement
    that separates them.

    Args:
        signal_data: numpy array of raw samples.
        movement_mask: per-second 0/1 movement mask.
        sampling_rate: samples per second of signal_data.
        window_size_sec: measurement window length in seconds.
        interval_to_movement: seconds to stay clear of the bounding movements.

    Returns:
        position_changes: list of 0/1 flags, one per compared segment pair.
        position_change_times: [(movement_start, movement_end), ...] (seconds) for
            each detected change.
    """
    # Movement-free segment boundaries (in seconds).
    valid_mask = 1 - movement_mask
    valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0]
    valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0]

    # Movement run boundaries (in seconds), used to timestamp detected changes.
    movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0]
    movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0]

    # Per-segment metrics: left/right window amplitude and energy.
    segment_left_average_amplitude = []
    segment_right_average_amplitude = []
    segment_left_average_energy = []
    segment_right_average_energy = []

    # window_samples = int(window_size_sec * sampling_rate)
    # pad_size = window_samples // 2
    # padded_signal = np.pad(signal_data, pad_size, mode='reflect')

    for start, end in zip(valid_starts, valid_ends):
        # Convert second indices to sample indices.
        start *= sampling_rate
        end *= sampling_rate
        # Skip segments of one second or less.
        if end - start <= sampling_rate:
            continue
        # Short segment: left/right windows clipped to the segment itself.
        # NOTE(review): the threshold multiplies window_size_sec by
        # interval_to_movement (30*10+1 s) — confirm a product (rather than a
        # sum like 2*(window+interval)) is the intended boundary.
        elif end - start < (window_size_sec * interval_to_movement + 1) * sampling_rate:
            left_start = start
            left_end = min(end, left_start + window_size_sec * sampling_rate)
            right_start = max(start, end - window_size_sec * sampling_rate)
            right_end = end
        # Long segment: keep interval_to_movement seconds away from the left
        # movement; NOTE(review): right_end = end sits directly against the next
        # movement even though right_start backs off — confirm intended.
        else:
            left_start = start + interval_to_movement * sampling_rate
            left_end = left_start + window_size_sec * sampling_rate
            right_start = end - interval_to_movement * sampling_rate - window_size_sec * sampling_rate
            right_end = end

        # Trim window lengths to whole 2-second blocks so reshape below works.
        if (left_end - left_start) % (2 * sampling_rate) != 0:
            left_end = left_start + ((left_end - left_start) // (2 * sampling_rate)) * (2 * sampling_rate)
        if (right_end - right_start) % (2 * sampling_rate) != 0:
            right_end = right_start + ((right_end - right_start) // (2 * sampling_rate)) * (2 * sampling_rate)

        # Amplitude metric per window: mean of per-position max minus mean of
        # per-position min across the stacked 2-second blocks (axis=0).
        left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, 2 * sampling_rate), axis=0)) - np.mean(
            np.min(signal_data[left_start:left_end].reshape(-1, 2 * sampling_rate), axis=0))
        right_mav = np.mean(
            np.max(signal_data[right_start:right_end].reshape(-1, 2 * sampling_rate), axis=0)) - np.mean(
            np.min(signal_data[right_start:right_end].reshape(-1, 2 * sampling_rate), axis=0))
        segment_left_average_amplitude.append(left_mav)
        segment_right_average_amplitude.append(right_mav)

        # Energy metric per window: sum of squared samples.
        left_energy = np.sum(np.abs(signal_data[left_start:left_end] ** 2))
        right_energy = np.sum(np.abs(signal_data[right_start:right_end] ** 2))
        segment_left_average_energy.append(left_energy)
        segment_right_average_energy.append(right_energy)

    position_changes = []
    position_change_times = []
    # Compare each segment with its predecessor.
    # NOTE(review): skipped short segments shrink these lists while
    # movement_start/movement_end keep every movement, so movement_*[i - 1] can
    # point at the wrong movement — confirm alignment.
    for i in range(1, len(segment_left_average_amplitude)):
        # Relative change in the amplitude metrics (guarded against divide-by-zero).
        left_amplitude_change = abs(segment_left_average_amplitude[i] - segment_left_average_amplitude[i - 1]) / max(
            segment_left_average_amplitude[i - 1], 1e-6)
        right_amplitude_change = abs(segment_right_average_amplitude[i] - segment_right_average_amplitude[i - 1]) / max(
            segment_right_average_amplitude[i - 1], 1e-6)

        # Relative change in the energy metrics.
        left_energy_change = abs(segment_left_average_energy[i] - segment_left_average_energy[i - 1]) / max(
            segment_left_average_energy[i - 1], 1e-6)
        right_energy_change = abs(segment_right_average_energy[i] - segment_right_average_energy[i - 1]) / max(
            segment_right_average_energy[i - 1], 1e-6)

        # Significance thresholds (tunable).
        threshold_amplitude = 0.1  # amplitude change threshold
        threshold_energy = 0.1  # energy change threshold

        # A change counts when either side exceeds BOTH thresholds at once.
        left_significant_change = (left_amplitude_change > threshold_amplitude) and (
                left_energy_change > threshold_energy)
        right_significant_change = (right_amplitude_change > threshold_amplitude) and (
                right_energy_change > threshold_energy)

        if left_significant_change or right_significant_change:
            # Record the change, timestamped by the separating movement run.
            position_changes.append(1)
            position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
        else:
            position_changes.append(0)  # 0 = no posture change

        # print(i,movement_start[i], movement_end[i], round(left_amplitude_change, 2), round(right_amplitude_change, 2), round(left_energy_change, 2), round(right_energy_change, 2))

    return position_changes, position_change_times
|
||||
|
41
signal_method/time_metrics.py
Normal file
41
signal_method/time_metrics.py
Normal file
@ -0,0 +1,41 @@
|
||||
from utils.operation_tools import calculate_by_slide_windows
|
||||
from utils.operation_tools import timing_decorator
|
||||
import numpy as np
|
||||
|
||||
|
||||
@timing_decorator()
def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2):
    """
    Compute a per-second moving-average amplitude (MAV) series over sliding windows.

    Args:
        signal_data: numpy array of raw samples.
        movement_mask: per-second 0/1 mask (1 = body movement); masked seconds become NaN.
        low_amp_mask: per-second 0/1 low-amplitude mask — only length-checked here;
            NOTE(review): confirm whether it was meant to be merged into the mask.
        sampling_rate: samples per second of signal_data.
        window_second: outer sliding-window length in seconds.
        step_second: sliding-window step in seconds.
        inner_window_second: inner block length (seconds) for the peak-to-peak estimate.

    Returns:
        (mav_nan, mav): per-second values with NaN at masked seconds, and the same
        series with NaNs linearly interpolated.
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"
    # print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}")
    processed_mask = movement_mask.copy()
    def mav_func(x):
        # Mean peak-to-peak amplitude over the inner 2 s blocks, halved.
        return np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2
    mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
                                              window_second=window_second, step_second=step_second)

    return mav_nan, mav
|
||||
|
||||
@timing_decorator()
def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    """
    Compute the waveform factor (RMS / mean absolute value) over 2 s sliding windows.

    Args:
        signal_data: numpy array of raw samples.
        movement_mask: per-second 0/1 mask (1 = body movement); masked seconds become NaN.
        low_amp_mask: per-second 0/1 low-amplitude mask — only length-checked here.
        sampling_rate: samples per second of signal_data.

    Returns:
        (wavefactor_nan, wavefactor): per-second values with NaN at masked seconds,
        and the same series with NaNs linearly interpolated.
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    # Pass the per-second mask through unchanged: calculate_by_slide_windows
    # indexes calc_mask once per second (same convention as calc_mav). The
    # previous .repeat(sampling_rate) expanded it to per-sample length, so the
    # mask applied to second i was actually the mask of sample i — misaligned.
    processed_mask = movement_mask.copy()
    wavefactor_nan, wavefactor = calculate_by_slide_windows(lambda x: np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x)),
                                                           signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1)

    return wavefactor_nan, wavefactor
|
||||
|
||||
@timing_decorator()
def calc_peakfactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100):
    """
    Compute the peak factor (max absolute value / RMS) over 2 s sliding windows.

    Args:
        signal_data: numpy array of raw samples.
        movement_mask: per-second 0/1 mask (1 = body movement); masked seconds become NaN.
        low_amp_mask: per-second 0/1 low-amplitude mask — only length-checked here.
        sampling_rate: samples per second of signal_data.

    Returns:
        (peakfactor_nan, peakfactor): per-second values with NaN at masked seconds,
        and the same series with NaNs linearly interpolated.
    """
    assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
    assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"

    # Pass the per-second mask through unchanged: calculate_by_slide_windows
    # indexes calc_mask once per second (same convention as calc_mav). The
    # previous .repeat(sampling_rate) expanded it to per-sample length, which
    # misaligned the mask with the per-second values it gates.
    processed_mask = movement_mask.copy()
    peakfactor_nan, peakfactor = calculate_by_slide_windows(lambda x: np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2)),
                                                           signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1)

    return peakfactor_nan, peakfactor
|
54
utils/HYS_FileReader.py
Normal file
54
utils/HYS_FileReader.py
Normal file
@ -0,0 +1,54 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Optional Polars support: fall back to pandas when it is not installed.
try:
    import polars as pl
    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False
|
||||
|
||||
|
||||
def read_signal_txt(path: Union[str, Path]) -> np.ndarray:
    """
    Read a txt file and return the first column as a numpy array of floats.

    Args:
        path (str | Path): Path to the txt file.

    Returns:
        np.ndarray: The first column of the txt file as a float numpy array.

    Raises:
        FileNotFoundError: If the path does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    if HAS_POLARS:
        df = pl.read_csv(path, has_header=False, infer_schema_length=0)
        # infer_schema_length=0 makes Polars read every column as strings, so
        # cast explicitly; the pandas branch below already returns floats and
        # both backends must agree on the dtype.
        return df[:, 0].to_numpy().astype(float)
    else:
        df = pd.read_csv(path, header=None, dtype=float)
        return df.iloc[:, 0].to_numpy()
|
||||
|
||||
|
||||
def read_laebl_csv(path: Union[str, Path]) -> pd.DataFrame:
    """
    Read a label CSV file and return it as a pandas DataFrame.

    NOTE(review): 'laebl' is a typo for 'label'; renaming would break callers,
    so it is only flagged here.

    Args:
        path (str | Path): Path to the CSV file.

    Returns:
        pd.DataFrame: The CSV content with the 'Start' and 'End' columns
        coerced to int.

    Raises:
        FileNotFoundError: If the path does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    # The file contains Chinese text, hence the explicit GBK encoding.
    df = pd.read_csv(path, encoding="gbk")
    df["Start"] = df["Start"].astype(int)
    df["End"] = df["End"].astype(int)
    return df
|
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
256
utils/operation_tools.py
Normal file
256
utils/operation_tools.py
Normal file
@ -0,0 +1,256 @@
|
||||
import time
|
||||
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK labels correctly in plots
plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign displayable with the CJK font
|
||||
|
||||
from scipy import ndimage, signal
|
||||
from functools import wraps
|
||||
|
||||
# Global runtime configuration flags.
class Config:
    # When True, timing_decorator prints each wrapped call's elapsed time.
    time_verbose = False
|
||||
|
||||
def timing_decorator():
    """Decorator factory: time each call of the wrapped function and print the
    elapsed seconds when the global Config.time_verbose flag is enabled."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            started = time.time()
            result = func(*args, **kwargs)
            elapsed_time = time.time() - started
            # Read the global switch at call time so it can be toggled at runtime.
            if Config.time_verbose:
                print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒")
            return result
        return wrapper
    return decorator
|
||||
|
||||
@timing_decorator()
def read_auto(file_path):
    """
    Load a 1-D signal array from a file, dispatching on the file extension.

    Args:
        file_path: path to a .txt, .npy or .base64 file (str or pathlib.Path).

    Returns:
        numpy array with the file contents flattened to one dimension.

    Raises:
        ValueError: for unsupported file extensions.
    """
    # Accept plain strings as well as Path objects (the original required a
    # Path because it read .suffix directly off the argument).
    file_path = Path(file_path)
    # print('suffix: ', file_path.suffix)
    if file_path.suffix == '.txt':
        # Parse the txt as a single-column CSV and flatten.
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    elif file_path.suffix == '.npy':
        return np.load(str(file_path))
    elif file_path.suffix == '.base64':
        # One integer per line (despite the .base64 extension).
        with open(file_path) as f:
            files = f.readlines()

        data = np.array(files, dtype=int)
        return data
    else:
        raise ValueError('这个文件类型不支持,需要自己写读取程序')
|
||||
|
||||
@timing_decorator()
def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=1000):
    """
    Zero-phase Butterworth filtering using second-order sections.

    Args:
        data: array-like 1-D signal.
        type: 'lowpass', 'bandpass' or 'highpass'.
        low_cut: cutoff in Hz for 'lowpass'; lower band edge for 'bandpass'.
        high_cut: cutoff in Hz for 'highpass'; upper band edge for 'bandpass'.
        order: filter order passed to signal.butter.
        sample_rate: sampling rate in Hz.

    Returns:
        numpy array: the forward-backward (phase-free) filtered signal.

    Raises:
        ValueError: if `type` is not one of the supported filter types.
    """
    # Cutoffs are normalized by the Nyquist frequency.
    nyquist = sample_rate * 0.5
    if type == "lowpass":  # low-pass filtering
        sos = signal.butter(order, low_cut / nyquist, btype='lowpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif type == "bandpass":  # band-pass filtering
        low = low_cut / nyquist
        high = high_cut / nyquist
        sos = signal.butter(order, [low, high], btype='bandpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    elif type == "highpass":  # high-pass filtering
        sos = signal.butter(order, high_cut / nyquist, btype='highpass', output='sos')
        return signal.sosfiltfilt(sos, np.array(data))
    else:
        # Fixed typo in the original message ("fliter" -> "filter").
        raise ValueError("Please choose a type of filter")
|
||||
|
||||
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
    """
    Integer-factor downsampling of long signals (8+ hours), processed in chunks
    to bound memory use.

    Args:
        original_signal: array-like, input samples.
        original_fs: float, original sampling rate (Hz).
        target_fs: float, target sampling rate (Hz); must divide original_fs.
        chunk_size: int, samples processed per chunk (default 100000).
            NOTE(review): chunk_size should be a multiple of the downsample
            factor — decimate returns ceil(len(chunk)/factor) samples, so a
            non-multiple chunk misaligns out_start for subsequent chunks.

    Returns:
        downsampled_signal: numpy array of length len(original_signal) // factor.

    Raises:
        ValueError: if target_fs >= original_fs, a rate is non-positive, or the
        ratio is not an integer.
    """
    # Input validation.
    if not isinstance(original_signal, np.ndarray):
        original_signal = np.array(original_signal)
    if original_fs <= target_fs:
        raise ValueError("目标采样率必须小于原始采样率")
    if target_fs <= 0 or original_fs <= 0:
        raise ValueError("采样率必须为正数")

    # Downsample factor must be an exact integer.
    downsample_factor = original_fs / target_fs
    if not downsample_factor.is_integer():
        raise ValueError("降采样因子必须为整数倍")
    downsample_factor = int(downsample_factor)

    # Total output length.
    total_length = len(original_signal)
    output_length = total_length // downsample_factor

    # Pre-allocated output buffer.
    downsampled_signal = np.zeros(output_length)

    # Chunked processing.
    # NOTE(review): signal.decimate filters each chunk independently, so the
    # IIR filter state resets at every chunk boundary — expect small edge
    # artifacts there; confirm this is acceptable for the use case.
    for start in range(0, total_length, chunk_size):
        end = min(start + chunk_size, total_length)
        chunk = original_signal[start:end]

        # Integer-factor decimation (anti-alias IIR filter, zero phase).
        chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True)

        # Destination slice in the output buffer.
        out_start = start // downsample_factor
        out_end = out_start + len(chunk_downsampled)
        if out_end > output_length:
            # Trim the final chunk to the pre-computed output length.
            chunk_downsampled = chunk_downsampled[:output_length - out_start]

        downsampled_signal[out_start:out_end] = chunk_downsampled

    return downsampled_signal
|
||||
|
||||
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
    """Remove the slow-moving baseline of a signal by subtracting a moving average.

    Args:
        raw_data: input samples.
        sample_rate: sampling rate in Hz.
        window_size: averaging window length in seconds (default 20).

    Returns:
        The baseline-corrected signal (raw_data minus its moving average).
    """
    taps = window_size * sample_rate
    # Uniform averaging kernel spanning window_size seconds.
    kernel = np.full(taps, 1.0 / taps)
    # Reflect padding keeps the edges from drifting toward zero.
    baseline = ndimage.convolve1d(raw_data, kernel, mode='reflect')
    return raw_data - baseline
|
||||
|
||||
|
||||
|
||||
|
||||
def merge_short_gaps(state_sequence, time_points, max_gap_sec):
    """
    Merge runs of 1s in a 0/1 state sequence whose separating gap is short.

    Args:
        state_sequence: numpy array of 0/1 states.
        time_points: numpy array with the timestamp of each state sample.
        max_gap_sec: gaps no longer than this (in time_points units) are filled with 1.

    Returns:
        numpy array: a copy of the sequence with short gaps merged.
    """
    if len(state_sequence) <= 1:
        return state_sequence

    merged = state_sequence.copy()

    # Edge detection over a zero-padded copy: +1 marks a run start,
    # -1 marks one past a run end.
    edges = np.diff(np.concatenate([[0], merged, [0]]))
    run_starts = np.where(edges == 1)[0]
    run_last = np.where(edges == -1)[0] - 1

    # Walk consecutive run pairs and fill sufficiently short gaps with 1s.
    for prev_last, next_start in zip(run_last[:-1], run_starts[1:]):
        if prev_last >= len(time_points) or next_start >= len(time_points):
            continue
        if time_points[next_start] - time_points[prev_last] <= max_gap_sec:
            merged[prev_last:next_start] = 1

    return merged
|
||||
|
||||
|
||||
def remove_short_durations(state_sequence, time_points, min_duration_sec):
    """
    Zero out runs of 1s in a 0/1 state sequence that are shorter than a threshold.

    Args:
        state_sequence: numpy array of 0/1 states.
        time_points: numpy array with the timestamp of each state sample.
        min_duration_sec: runs lasting less than this (in time_points units) are removed.

    Returns:
        numpy array: a copy of the sequence with short runs cleared.
    """
    if len(state_sequence) <= 1:
        return state_sequence

    filtered = state_sequence.copy()

    # Edge detection over a zero-padded copy: +1 marks a run start,
    # -1 marks one past a run end.
    edges = np.diff(np.concatenate([[0], filtered, [0]]))
    run_starts = np.where(edges == 1)[0]
    run_last = np.where(edges == -1)[0] - 1

    # One sampling step, used to extend a run that reaches the final sample.
    step = time_points[1] - time_points[0] if len(time_points) > 1 else 0

    for first, last in zip(run_starts, run_last):
        if first >= len(time_points) or last >= len(time_points):
            continue
        duration = time_points[last] - time_points[first]
        # A run ending at the final sample extends one full step past it.
        if last == len(time_points) - 1:
            duration += step
        if duration < min_duration_sec:
            filtered[first:last + 1] = 0

    return filtered
|
||||
|
||||
|
||||
@timing_decorator()
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
    """Apply *func* over centered sliding windows and return one value per second.

    Args:
        func: callable taking a 1-D sample window and returning a scalar.
        signal_data: numpy array of raw samples.
        calc_mask: per-second truthy mask; masked seconds are set to NaN in the
            output. NOTE(review): the None default builds a PER-SAMPLE zero mask
            but the loop below indexes it per second — works (all False) yet is
            inconsistent with the per-second convention; confirm.
        sampling_rate: samples per second.
        window_second: window length in seconds.
        step_second: step in seconds; None means window_second.
            NOTE(review): the .repeat(step_second) below requires an integer.

    Returns:
        (values_nan, values): per-second series with NaN at masked seconds, and
        the same series with NaNs linearly interpolated.
        NOTE(review): the interpolation raises if every second is masked.
    """
    # Default mask: nothing masked.
    if calc_mask is None:
        calc_mask = np.zeros(len(signal_data), dtype=bool)

    # Default step = window length (non-overlapping windows).
    if step_second is None:
        step_second = window_second

    step_length = step_second * sampling_rate
    window_length = window_second * sampling_rate

    origin_seconds = len(signal_data) // sampling_rate
    total_samples = len(signal_data)

    # Reflect padding so each window is full length and roughly centered on its
    # step position.
    left_pad_size = int(window_length // 2)
    right_pad_size = window_length - left_pad_size
    data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect')

    num_segments = int(np.ceil(len(signal_data) / step_length))
    values_nan = np.full(num_segments, np.nan)

    # print(f"num_segments: {num_segments}, step_length: {step_length}, window_length: {window_length}")
    # One func evaluation per step over the padded data.
    for i in range(num_segments):
        start = int(i * step_length)
        end = start + window_length
        segment = data[start:end]
        values_nan[i] = func(segment)

    # Expand per-step values to a per-second series.
    values_nan = values_nan.repeat(step_second)[:origin_seconds]

    # Blank out masked (e.g. body-movement) seconds.
    for i in range(len(values_nan)):
        if calc_mask[i]:
            values_nan[i] = np.nan

    values = values_nan.copy()

    # Linearly interpolate across the NaN (masked) seconds.
    def interpolate_nans(x, t):
        valid_mask = ~np.isnan(x)
        return np.interp(t, t[valid_mask], x[valid_mask])

    values = interpolate_nans(values, np.arange(len(values)))

    return values_nan, values
|
||||
|
||||
|
||||
|
||||
|
105
utils/statistics_metrics.py
Normal file
105
utils/statistics_metrics.py
Normal file
@ -0,0 +1,105 @@
|
||||
from utils.operation_tools import timing_decorator
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@timing_decorator()
def statistic_amplitude_metrics(data, aml_interval=None, time_interval=None):
    """
    Bin contiguous valid (non-NaN) stretches of a 1 Hz series by mean amplitude
    and by duration, accumulating durations and segment counts per bin.

    Args:
        data: 1-D series sampled at 1 value/second; movement regions are NaN.
        aml_interval: amplitude bin edges, default [200, 500, 1000, 2000].
        time_interval: duration bin edges in seconds, default [60, 300, 1800, 3600].

    Returns:
        summary: dict with totals and per-bin distributions/percentages.
        (confusion_matrix, segment_count_matrix, confusion_matrix_percent,
         valid_signal_length, total_duration, time_labels, amp_labels):
            positional tuple consumed by draw_ax_confusion_matrix.
    """
    if aml_interval is None:
        aml_interval = [200, 500, 1000, 2000]

    if time_interval is None:
        time_interval = [60, 300, 1800, 3600]
    # Input normalization.
    if not isinstance(data, np.ndarray):
        data = np.array(data)

    # Whole-recording duration in seconds (NaN seconds included).
    total_duration = len(data)

    # Amplitude bin labels: "0-a0", "a0-a1", ..., "aN+".
    amp_labels = [f"0-{aml_interval[0]}"]
    for i in range(len(aml_interval) - 1):
        amp_labels.append(f"{aml_interval[i]}-{aml_interval[i + 1]}")
    amp_labels.append(f"{aml_interval[-1]}+")

    # Duration bin labels, same scheme.
    time_labels = [f"0-{time_interval[0]}"]
    for i in range(len(time_interval) - 1):
        time_labels.append(f"{time_interval[i]}-{time_interval[i + 1]}")
    time_labels.append(f"{time_interval[-1]}+")

    # Accumulators: summed durations and segment counts per (amplitude, duration) bin.
    result_matrix = np.zeros((len(amp_labels), len(time_labels)))  # durations
    segment_count_matrix = np.zeros((len(amp_labels), len(time_labels)))  # counts

    # Valid signal total (non-NaN seconds).
    valid_signal_length = np.sum(~np.isnan(data))

    # NaN-pad both ends so boundary segments are detected by the diff below.
    signal_padded = np.concatenate(([np.nan], data, [np.nan]))
    diff = np.diff(np.isnan(signal_padded).astype(int))

    # Segment starts: NaN -> value transitions.
    segment_starts = np.where(diff == -1)[0]
    # Segment ends: value -> NaN transitions.
    segment_ends = np.where(diff == 1)[0]

    # Classify each contiguous valid segment by duration and mean amplitude.
    for start, end in zip(segment_starts, segment_ends):
        segment = data[start:end]
        duration = end - start  # seconds
        mean_amplitude = np.nanmean(segment)  # segment mean amplitude

        # Amplitude bin index.
        if mean_amplitude <= aml_interval[0]:
            amp_idx = 0
        elif mean_amplitude > aml_interval[-1]:
            amp_idx = len(aml_interval)
        else:
            amp_idx = np.searchsorted(aml_interval, mean_amplitude)

        # Duration bin index.
        if duration <= time_interval[0]:
            time_idx = 0
        elif duration > time_interval[-1]:
            time_idx = len(time_interval)
        else:
            time_idx = np.searchsorted(time_interval, duration)

        # Accumulate into the matching cell.
        result_matrix[amp_idx, time_idx] += duration
        segment_count_matrix[amp_idx, time_idx] += 1  # one more segment

    # DataFrame form for presentation / downstream use.
    confusion_matrix = pd.DataFrame(result_matrix, index=amp_labels, columns=time_labels)

    # Row totals; keep a copy before percentages are derived.
    confusion_matrix['总计'] = confusion_matrix.sum(axis=1)
    row_totals = confusion_matrix['总计'].copy()

    # Percentages relative to the TOTAL recording duration (NaN seconds
    # included), not just the valid portion.
    confusion_matrix_percent = confusion_matrix.div(total_duration) * 100

    # Summary statistics.
    summary = {
        'total_duration': total_duration,
        'total_valid_signal': valid_signal_length,
        'amplitude_distribution': row_totals.to_dict(),
        'amplitude_percent': row_totals.div(total_duration) * 100,
        'time_distribution': confusion_matrix.sum(axis=0).to_dict(),
        'time_percent': confusion_matrix.sum(axis=0).div(total_duration) * 100
    }

    return summary, (confusion_matrix, segment_count_matrix, confusion_matrix_percent, valid_signal_length,
                     total_duration, time_labels, amp_labels)
|
Loading…
Reference in New Issue
Block a user