diff --git a/HYS_process.py b/HYS_process.py new file mode 100644 index 0000000..1f07192 --- /dev/null +++ b/HYS_process.py @@ -0,0 +1,21 @@ +""" +本脚本完成对呼研所数据的处理,包含以下功能: +1. 数据读取与预处理 + 从传入路径中,进行数据和标签的读取,并进行初步的预处理 + 预处理包括为数据进行滤波、去噪等操作 +2. 数据清洗与异常值处理 +3. 输出清晰后的统计信息 +4. 数据保存 + 将处理后的数据保存到指定路径,便于后续使用 + 主要是保存切分后的数据位置和标签 +5. 可视化 + 提供数据处理前后的可视化对比,帮助理解数据变化 + 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 + + + + +# 低幅值区间规则标定与剔除 +# 高幅值连续体动规则标定与剔除 +# 手动标定不可用区间提剔除 +""" \ No newline at end of file diff --git a/SHHS_process.py b/SHHS_process.py new file mode 100644 index 0000000..e69de29 diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py new file mode 100644 index 0000000..88f790e --- /dev/null +++ b/draw_tools/draw_statics.py @@ -0,0 +1,175 @@ +from matplotlib.axes import Axes +from matplotlib.gridspec import GridSpec +from matplotlib.colors import PowerNorm +import seaborn as sns +import numpy as np + + +def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, confusion_matrix_percent, + valid_signal_length, total_duration, time_labels, amp_labels, signal_type=''): + # 创建用于热图注释的文本矩阵 + text_matrix = np.empty((len(amp_labels), len(time_labels)), dtype=object) + percent_matrix = np.zeros((len(amp_labels), len(time_labels))) + + # 填充文本矩阵和百分比矩阵 + for i in range(len(amp_labels)): + for j in range(len(time_labels)): + val = confusion_matrix.iloc[i, j] + segment_count = segment_count_matrix[i, j] + percent = confusion_matrix_percent.iloc[i, j] + text_matrix[i, j] = f"[{int(segment_count)}]{int(val)}\n({percent:.2f}%)" + percent_matrix[i, j] = percent + + # 绘制热图,调整颜色条位置 + sns_heatmap = sns.heatmap(percent_matrix, annot=text_matrix, fmt='', + xticklabels=time_labels, yticklabels=amp_labels, + cmap='YlGnBu', ax=ax, vmin=0, vmax=100, + norm=PowerNorm(gamma=0.5, vmin=0, vmax=100), + cbar_kws={'label': '百分比 (%)', 'shrink': 0.6, 'pad': 0.15}, + # annot_kws={'fontsize': 12} + ) + + # 添加行统计(右侧) + row_sums = confusion_matrix['总计'] + row_percents = confusion_matrix_percent['总计'] + ax.text(len(time_labels) + 1, -0.5, "各幅值时长\n(有效区间百分比%)", ha='center', va='center') + for i, (val, perc) in enumerate(zip(row_sums, row_percents)): + ax.text(len(time_labels) + 0.5, i + 0.5, f"{int(val)}\n({perc:.2f}%)", + ha='center', va='center') + + # 添加列统计(底部) + col_sums = segment_count_matrix.sum(axis=0) + col_percents = confusion_matrix.sum(axis=0) / total_duration * 100 + ax.text(-1, len(amp_labels) + 0.5, "[各时长片段数]\n(有效区间百分比%)", ha='center', va='center', rotation=0) + for j, (val, perc) in enumerate(zip(col_sums, col_percents)): + ax.text(j + 0.5, len(amp_labels) + 0.5, f"[{int(val)}]\n({perc:.2f}%)", + ha='center', va='center', rotation=0) + + # 将x轴坐标移到顶部 + ax.xaxis.set_label_position('top') + ax.xaxis.tick_top() + + # 设置标题和标签 + ax.set_title('幅值-时长统计矩阵', pad=40) + ax.set_xlabel('持续时间区间 (秒)', labelpad=10) + ax.set_ylabel('幅值区间') + + # 设置坐标轴标签水平显示 + ax.set_xticklabels(time_labels, rotation=0) + ax.set_yticklabels(amp_labels, rotation=0) + + # 调整颜色条位置 + cbar = sns_heatmap.collections[0].colorbar + cbar.ax.yaxis.set_label_position('right') + + # 添加图例说明 + ax.text(-2, -1, "热图内:\n[片段数]时长\n(有效区间百分比%)", + ha='left', va='top', bbox=dict(facecolor='none', edgecolor='black', alpha=0.5)) + + # 总计 + # 总片段数 + total_segments = segment_count_matrix.sum() + # 有效总市场占比 + total_percent = valid_signal_length / total_duration * 100 + ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)", + ha='center', va='center') + +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches + +def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values, + movement_position_list, low_amp_position_list, signal_second_length, aml_list=None): + # 绘制信号线 + ax.plot(original_times, origin_signal, 'k-', linewidth=1, alpha=0.7) + + # 添加红色和蓝色的axvspan区域 + for start, end in movement_position_list: + if start < len(original_times) and end < len(original_times): + ax.axvspan(start, end, color='red', alpha=0.3) + + for start, end in low_amp_position_list: + if start < len(original_times) and end < len(original_times): + ax.axvspan(start, end, color='blue', alpha=0.2) + + # 如果存在AML列表,绘制水平线 + if aml_list is not None: + color_map = ['red', 'orange', 'green'] + for i, aml in enumerate(aml_list): + ax.hlines(aml, 0, signal_second_length, color=color_map[min(i, 2)], linestyle='dashed', linewidth=2, alpha=0.5, label=f'{aml} aml') + + + ax.plot(np.linspace(0, len(mav_values), len(mav_values)), mav_values, color='blue', linewidth=2, alpha=0.4, label='2sMAV') + + # 设置Y轴范围 + ax.set_ylim((-2000, 2000)) + + # 创建表示不同颜色区域的图例 + red_patch = mpatches.Patch(color='red', alpha=0.2, label='Movement Area') + blue_patch = mpatches.Patch(color='blue', alpha=0.2, label='Low Amplitude Area') + + # 添加新的图例项,并保留原来的图例项 + handles, labels = ax.get_legend_handles_labels() # 获取原有图例 + ax.legend(handles=[red_patch, blue_patch] + handles, + labels=['Movement Area', 'Low Amplitude Area'] + labels, + loc='upper right', + bbox_to_anchor=(1, 1.4), + framealpha=0.5) + + # 设置标题和标签 + ax.set_title(f'{signal_name} Signal') + ax.set_ylabel('Amplitude') + ax.set_xlabel('Time (s)') + + # 启用网格 + ax.grid(True, linestyle='--', alpha=0.7) + + +def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_signal, resp_no_movement_signal, + bcg_sampling_rate, resp_sampling_rate, bcg_movement_position_list, bcg_low_amp_position_list, + resp_movement_position_list, resp_low_amp_position_list, + bcg_mav_values, resp_mav_values, bcg_statistic_info, resp_statistic_info, + signal_info, show=False, save_path=None): + + # 创建图像 + fig = plt.figure(figsize=(18, 10)) + + gs = GridSpec(2, 2, height_ratios=[2, 2], width_ratios=[4, 2], hspace=0.5) + + signal_second_length = len(bcg_origin_signal) // bcg_sampling_rate + bcg_origin_times = np.linspace(0, signal_second_length, len(bcg_origin_signal)) + resp_origin_times = np.linspace(0, signal_second_length, len(resp_origin_signal)) + # 子图 1:原始信号 + ax1 = fig.add_subplot(gs[0]) + draw_ax_amp(ax=ax1, signal_name='BCG', original_times=bcg_origin_times, origin_signal=bcg_origin_signal, + no_movement_signal=bcg_no_movement_signal, mav_values=bcg_mav_values, + movement_position_list=bcg_movement_position_list, low_amp_position_list=bcg_low_amp_position_list, + signal_second_length=signal_second_length, aml_list=[200, 500, 1000]) + + ax2 = fig.add_subplot(gs[1]) + param_names = ['confusion_matrix', 'segment_count_matrix', 'confusion_matrix_percent', + 'valid_signal_length', 'total_duration', 'time_labels', 'amp_labels'] + params = dict(zip(param_names, bcg_statistic_info)) + params['ax'] = ax2 + draw_ax_confusion_matrix(**params) + + ax3 = fig.add_subplot(gs[2], sharex=ax1) + draw_ax_amp(ax=ax3, signal_name='RSEP', original_times=resp_origin_times, origin_signal=resp_origin_signal, + no_movement_signal=resp_no_movement_signal, mav_values=resp_mav_values, + movement_position_list=resp_movement_position_list, low_amp_position_list=resp_low_amp_position_list, + signal_second_length=signal_second_length, aml_list=[100, 300, 500]) + ax4 = fig.add_subplot(gs[3]) + params = dict(zip(param_names, resp_statistic_info)) + params['ax'] = ax4 + draw_ax_confusion_matrix(**params) + + # 全局标题 + fig.suptitle(f'{signal_info} Signal Metrics', fontsize=16, x=0.35, y=0.95) + + if save_path is not None: + # 保存图像 + plt.savefig(save_path, dpi=300) + + if show: + plt.show() + + plt.close() \ No newline at end of file diff --git a/signal_method/__init__.py b/signal_method/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py new file mode 100644 index 0000000..74e5929 --- /dev/null +++ b/signal_method/rule_base_event.py @@ -0,0 +1,207 @@ +from utils.operation_tools import timing_decorator +import numpy as np +from utils.operation_tools import merge_short_gaps, remove_short_durations + +@timing_decorator() +def detect_low_amplitude_signal(signal_data, sampling_rate, window_size_sec=1, stride_sec=None, + amplitude_threshold=50, merge_gap_sec=10, min_duration_sec=5): + """ + 检测信号中的低幅值状态,通过计算RMS值判断信号强度是否低于设定阈值。 + + 参数: + - signal_data: numpy array,输入的信号数据 + - sampling_rate: int,信号的采样率(Hz) + - window_size_sec: float,分析窗口的时长(秒),默认值为 1 秒 + - stride_sec: float,窗口滑动步长(秒),默认值为None(等于window_size_sec,无重叠) + - amplitude_threshold: float,RMS阈值,低于此值表示低幅值状态,默认值为 50 + - merge_gap_sec: float,要合并的状态之间的最大间隔(秒),默认值为 10 秒 + - min_duration_sec: float,要保留的状态的最小持续时间(秒),默认值为 5 秒 + + 返回: + - low_amplitude_mask: numpy array,低幅值状态的掩码(1表示低幅值,0表示正常幅值) + """ + # 计算窗口大小(样本数) + window_samples = int(window_size_sec * sampling_rate) + + # 如果未指定步长,设置为窗口大小(无重叠) + if stride_sec is None: + stride_sec = window_size_sec + + # 计算步长(样本数) + stride_samples = int(stride_sec * sampling_rate) + + # 确保步长至少为1 + stride_samples = max(1, stride_samples) + + # 处理信号边界,使用反射填充 + pad_size = window_samples // 2 + padded_signal = np.pad(signal_data, pad_size, mode='reflect') + + # 计算填充后的窗口数量 + num_windows = max(1, (len(padded_signal) - window_samples) // stride_samples + 1) + + # 初始化RMS值数组 + rms_values = np.zeros(num_windows) + + # 计算每个窗口的RMS值 + for i in range(num_windows): + start_idx = i * stride_samples + end_idx = min(start_idx + window_samples, len(signal_data)) + + # 处理窗口,包括可能不完整的最后一个窗口 + window_data = signal_data[start_idx:end_idx] + if len(window_data) > 0: + rms_values[i] = np.sqrt(np.mean(window_data ** 2)) + else: + rms_values[i] = 0 + + # 生成初始低幅值掩码:RMS低于阈值的窗口标记为1(低幅值),其他为0 + low_amplitude_mask = np.where(rms_values <= amplitude_threshold, 1, 0) + + # 计算原始信号对应的窗口索引范围 + orig_start_window = pad_size // stride_samples + if stride_sec == 1: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + else: + orig_end_window = orig_start_window + (len(signal_data) // stride_samples) + 1 + + # 只保留原始信号对应的窗口低幅值掩码 + low_amplitude_mask = low_amplitude_mask[orig_start_window:orig_end_window] + # print("low_amplitude_mask_length: ", len(low_amplitude_mask)) + num_original_windows = len(low_amplitude_mask) + + # 转换为时间轴上的状态序列 + # 计算每个窗口对应的时间点(秒) + time_points = np.arange(num_original_windows) * stride_sec + + # 如果需要合并间隔小的状态 + if merge_gap_sec > 0: + low_amplitude_mask = merge_short_gaps(low_amplitude_mask, time_points, merge_gap_sec) + + # 如果需要移除短时状态 + if min_duration_sec > 0: + low_amplitude_mask = remove_short_durations(low_amplitude_mask, time_points, min_duration_sec) + + low_amplitude_mask = low_amplitude_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate] + + # 低幅值状态起止位置 [[start, end], [start, end], ...] + low_amplitude_start = np.where(np.diff(np.concatenate([[0], low_amplitude_mask])) == 1)[0] + low_amplitude_end = np.where(np.diff(np.concatenate([low_amplitude_mask, [0]])) == -1)[0] + low_amplitude_position_list = [[start, end] for start, end in zip(low_amplitude_start, low_amplitude_end)] + + return low_amplitude_mask, low_amplitude_position_list + + +def get_typical_segment_for_continues_signal(signal_data, sampling_rate=100, window_size=30, step_size=1): + """ + 获取十个片段 + :param signal_data: 信号数据 + :param sampling_rate: 采样率 + :param window_size: 窗口大小(秒) + :param step_size: 步长(秒) + :return: 典型片段列表 + """ + pass + + +# 基于体动位置和幅值的睡姿识别 +# 主要是依靠体动mask,将整夜分割成多个有效片段,然后每个片段计算幅值指标,判断两个片段的幅值指标是否存在显著差异,如果存在显著差异,则认为存在睡姿变化 +# 考虑到每个片段长度为10s,所以每个片段的幅值指标计算时间长度为10s,然后计算每个片段的幅值指标 +# 仅对比相邻片段的幅值指标,如果存在显著差异,则认为存在睡姿变化,即每个体动相邻的30秒内存在睡姿变化,如果片段不足30秒,则按实际长度对比 + +@timing_decorator() +def position_based_sleep_recognition(signal_data, movement_mask, sampling_rate=100, window_size_sec=30, + interval_to_movement=10): + # 获取有效片段起止位置 + valid_mask = 1 - movement_mask + valid_starts = np.where(np.diff(np.concatenate([[0], valid_mask])) == 1)[0] + valid_ends = np.where(np.diff(np.concatenate([valid_mask, [0]])) == -1)[0] + + movement_start = np.where(np.diff(np.concatenate([[0], movement_mask])) == 1)[0] + movement_end = np.where(np.diff(np.concatenate([movement_mask, [0]])) == -1)[0] + + segment_left_average_amplitude = [] + segment_right_average_amplitude = [] + segment_left_average_energy = [] + segment_right_average_energy = [] + + # window_samples = int(window_size_sec * sampling_rate) + # pad_size = window_samples // 2 + # padded_signal = np.pad(signal_data, pad_size, mode='reflect') + + for start, end in zip(valid_starts, valid_ends): + start *= sampling_rate + end *= sampling_rate + # 避免过短的片段 + if end - start <= sampling_rate: # 小于1秒的片段不考虑 + continue + # 获取当前片段数据 + + + elif end - start < (window_size_sec * interval_to_movement + 1) * sampling_rate: + left_start = start + left_end = min(end, left_start + window_size_sec * sampling_rate) + right_start = max(start, end - window_size_sec * sampling_rate) + right_end = end + else: + left_start = start + interval_to_movement * sampling_rate + left_end = left_start + window_size_sec * sampling_rate + right_start = end - interval_to_movement * sampling_rate - window_size_sec * sampling_rate + right_end = end + + # 新的end - start确保为200的整数倍 + if (left_end - left_start) % (2 * sampling_rate) != 0: + left_end = left_start + ((left_end - left_start) // (2 * sampling_rate)) * (2 * sampling_rate) + if (right_end - right_start) % (2 * sampling_rate) != 0: + right_end = right_start + ((right_end - right_start) // (2 * sampling_rate)) * (2 * sampling_rate) + + # 计算每个片段的幅值指标 + left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, 2 * sampling_rate), axis=0)) - np.mean( + np.min(signal_data[left_start:left_end].reshape(-1, 2 * sampling_rate), axis=0)) + right_mav = np.mean( + np.max(signal_data[right_start:right_end].reshape(-1, 2 * sampling_rate), axis=0)) - np.mean( + np.min(signal_data[right_start:right_end].reshape(-1, 2 * sampling_rate), axis=0)) + segment_left_average_amplitude.append(left_mav) + segment_right_average_amplitude.append(right_mav) + + left_energy = np.sum(np.abs(signal_data[left_start:left_end] ** 2)) + right_energy = np.sum(np.abs(signal_data[right_start:right_end] ** 2)) + segment_left_average_energy.append(left_energy) + segment_right_average_energy.append(right_energy) + + position_changes = [] + position_change_times = [] + for i in range(1, len(segment_left_average_amplitude)): + # 计算幅值指标的变化率 + left_amplitude_change = abs(segment_left_average_amplitude[i] - segment_left_average_amplitude[i - 1]) / max( + segment_left_average_amplitude[i - 1], 1e-6) + right_amplitude_change = abs(segment_right_average_amplitude[i] - segment_right_average_amplitude[i - 1]) / max( + segment_right_average_amplitude[i - 1], 1e-6) + + # 计算能量指标的变化率 + left_energy_change = abs(segment_left_average_energy[i] - segment_left_average_energy[i - 1]) / max( + segment_left_average_energy[i - 1], 1e-6) + right_energy_change = abs(segment_right_average_energy[i] - segment_right_average_energy[i - 1]) / max( + segment_right_average_energy[i - 1], 1e-6) + + # 判断是否存在显著变化 (可根据实际情况调整阈值) + threshold_amplitude = 0.1 # 幅值变化阈值 + threshold_energy = 0.1 # 能量变化阈值 + + # 如果左右通道中的任一通道同时满足幅值和能量的变化阈值,则认为存在姿势变化 + left_significant_change = (left_amplitude_change > threshold_amplitude) and ( + left_energy_change > threshold_energy) + right_significant_change = (right_amplitude_change > threshold_amplitude) and ( + right_energy_change > threshold_energy) + + if left_significant_change or right_significant_change: + # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 + position_changes.append(1) + position_change_times.append((movement_start[i - 1], movement_end[i - 1])) + else: + position_changes.append(0) # 0表示不存在姿势变化 + + # print(i,movement_start[i], movement_end[i], round(left_amplitude_change, 2), round(right_amplitude_change, 2), round(left_energy_change, 2), round(right_energy_change, 2)) + + return position_changes, position_change_times + diff --git a/signal_method/time_metrics.py b/signal_method/time_metrics.py new file mode 100644 index 0000000..a6a2a94 --- /dev/null +++ b/signal_method/time_metrics.py @@ -0,0 +1,41 @@ +from utils.operation_tools import calculate_by_slide_windows +from utils.operation_tools import timing_decorator +import numpy as np + + +@timing_decorator() +def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2): + assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" + assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" + # print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}") + processed_mask = movement_mask.copy() + def mav_func(x): + return np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2 + mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate, + window_second=window_second, step_second=step_second) + + return mav_nan, mav + +@timing_decorator() +def calc_wavefactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): + assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" + assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" + + processed_mask = movement_mask.copy() + processed_mask = processed_mask.repeat(sampling_rate) + wavefactor_nan, wavefactor = calculate_by_slide_windows(lambda x: np.sqrt(np.mean(x ** 2)) / np.mean(np.abs(x)), + signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1) + + return wavefactor_nan, wavefactor + +@timing_decorator() +def calc_peakfactor(signal_data, movement_mask, low_amp_mask, sampling_rate=100): + assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}" + assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}" + + processed_mask = movement_mask.copy() + processed_mask = processed_mask.repeat(sampling_rate) + peakfactor_nan, peakfactor = calculate_by_slide_windows(lambda x: np.max(np.abs(x)) / np.sqrt(np.mean(x ** 2)), + signal_data, processed_mask, sampling_rate=sampling_rate, window_second=2, step_second=1) + + return peakfactor_nan, peakfactor \ No newline at end of file diff --git a/utils/HYS_FileReader.py b/utils/HYS_FileReader.py new file mode 100644 index 0000000..82e4584 --- /dev/null +++ b/utils/HYS_FileReader.py @@ -0,0 +1,54 @@ +from pathlib import Path +from typing import Union + +import numpy as np +import pandas as pd + +# 尝试导入 Polars +try: + import polars as pl + HAS_POLARS = True +except ImportError: + HAS_POLARS = False + + +def read_signal_txt(path: Union[str, Path]) -> np.ndarray: + """ + Read a txt file and return the first column as a numpy array. + + Args: + path (str | Path): Path to the txt file. + + Returns: + np.ndarray: The first column of the txt file as a numpy array. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + if HAS_POLARS: + df = pl.read_csv(path, has_header=False, infer_schema_length=0) + return df[:, 0].to_numpy() + else: + df = pd.read_csv(path, header=None, dtype=float) + return df.iloc[:, 0].to_numpy() + + +def read_laebl_csv(path: Union[str, Path]) -> pd.DataFrame: + """ + Read a CSV file and return it as a pandas DataFrame. + + Args: + path (str | Path): Path to the CSV file. + Returns: + pd.DataFrame: The content of the CSV file as a pandas DataFrame. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + # 直接用pandas读取 包含中文 故指定编码 + df = pd.read_csv(path, encoding="gbk") + df["Start"] = df["Start"].astype(int) + df["End"] = df["End"].astype(int) + return df \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/operation_tools.py b/utils/operation_tools.py new file mode 100644 index 0000000..bcc5062 --- /dev/null +++ b/utils/operation_tools.py @@ -0,0 +1,256 @@ +import time + +from pathlib import Path +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt + + +plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 +plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 + +from scipy import ndimage, signal +from functools import wraps + +# 全局配置 +class Config: + time_verbose = False + +def timing_decorator(): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + elapsed_time = end_time - start_time + if Config.time_verbose: # 运行时检查全局配置 + print(f"函数 '{func.__name__}' 执行耗时: {elapsed_time:.4f} 秒") + return result + return wrapper + return decorator + +@timing_decorator() +def read_auto(file_path): + # print('suffix: ', file_path.suffix) + if file_path.suffix == '.txt': + # 使用pandas read csv读取txt + return pd.read_csv(file_path, header=None).to_numpy().reshape(-1) + elif file_path.suffix == '.npy': + return np.load(file_path.__str__()) + elif file_path.suffix == '.base64': + with open(file_path) as f: + files = f.readlines() + + data = np.array(files, dtype=int) + return data + else: + raise ValueError('这个文件类型不支持,需要自己写读取程序') + +@timing_decorator() +def Butterworth(data, type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=1000): + + if type == "lowpass": # 低通滤波处理 + sos = signal.butter(order, low_cut / (sample_rate * 0.5), btype='lowpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + elif type == "bandpass": # 带通滤波处理 + low = low_cut / (sample_rate * 0.5) + high = high_cut / (sample_rate * 0.5) + sos = signal.butter(order, [low, high], btype='bandpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + elif type == "highpass": # 高通滤波处理 + sos = signal.butter(order, high_cut / (sample_rate * 0.5), btype='highpass', output='sos') + return signal.sosfiltfilt(sos, np.array(data)) + else: # 警告,滤波器类型必须有 + raise ValueError("Please choose a type of fliter") + +@timing_decorator() +def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): + """ + 高效整数倍降采样长信号(适合8小时以上),分段处理以优化内存和速度。 + + 参数: + original_signal : array-like, 原始信号数组 + original_fs : float, 原始采样率 (Hz) + target_fs : float, 目标采样率 (Hz) + chunk_size : int, 每段处理的样本数,默认100000 + + 返回: + downsampled_signal : array-like, 降采样后的信号 + """ + # 输入验证 + if not isinstance(original_signal, np.ndarray): + original_signal = np.array(original_signal) + if original_fs <= target_fs: + raise ValueError("目标采样率必须小于原始采样率") + if target_fs <= 0 or original_fs <= 0: + raise ValueError("采样率必须为正数") + + # 计算降采样因子(必须为整数) + downsample_factor = original_fs / target_fs + if not downsample_factor.is_integer(): + raise ValueError("降采样因子必须为整数倍") + downsample_factor = int(downsample_factor) + + # 计算总输出长度 + total_length = len(original_signal) + output_length = total_length // downsample_factor + + # 初始化输出数组 + downsampled_signal = np.zeros(output_length) + + # 分段处理 + for start in range(0, total_length, chunk_size): + end = min(start + chunk_size, total_length) + chunk = original_signal[start:end] + + # 使用decimate进行整数倍降采样 + chunk_downsampled = signal.decimate(chunk, downsample_factor, ftype='iir', zero_phase=True) + + # 计算输出位置 + out_start = start // downsample_factor + out_end = out_start + len(chunk_downsampled) + if out_end > output_length: + chunk_downsampled = chunk_downsampled[:output_length - out_start] + + downsampled_signal[out_start:out_end] = chunk_downsampled + + return downsampled_signal + +@timing_decorator() +def average_filter(raw_data, sample_rate, window_size=20): + kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) + filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') + convolve_filter_signal = raw_data - filtered + return convolve_filter_signal + + + + +def merge_short_gaps(state_sequence, time_points, max_gap_sec): + """ + 合并状态序列中间隔小于指定时长的段 + + 参数: + - state_sequence: numpy array,状态序列(0/1) + - time_points: numpy array,每个状态点对应的时间点 + - max_gap_sec: float,要合并的最大间隔(秒) + + 返回: + - merged_sequence: numpy array,合并后的状态序列 + """ + if len(state_sequence) <= 1: + return state_sequence + + merged_sequence = state_sequence.copy() + + # 找出状态转换点 + transitions = np.diff(np.concatenate([[0], merged_sequence, [0]])) + # 找出状态1的起始和结束位置 + state_starts = np.where(transitions == 1)[0] + state_ends = np.where(transitions == -1)[0] - 1 + + # 检查每对连续的状态1 + for i in range(len(state_starts) - 1): + if state_ends[i] < len(time_points) and state_starts[i + 1] < len(time_points): + # 计算间隔时长 + gap_duration = time_points[state_starts[i + 1]] - time_points[state_ends[i]] + # 如果间隔小于阈值,则合并 + if gap_duration <= max_gap_sec: + merged_sequence[state_ends[i]:state_starts[i + 1]] = 1 + + return merged_sequence + + +def remove_short_durations(state_sequence, time_points, min_duration_sec): + """ + 移除状态序列中持续时间短于指定阈值的段 + + 参数: + - state_sequence: numpy array,状态序列(0/1) + - time_points: numpy array,每个状态点对应的时间点 + - min_duration_sec: float,要保留的最小持续时间(秒) + + 返回: + - filtered_sequence: numpy array,过滤后的状态序列 + """ + if len(state_sequence) <= 1: + return state_sequence + + filtered_sequence = state_sequence.copy() + + # 找出状态转换点 + transitions = np.diff(np.concatenate([[0], filtered_sequence, [0]])) + # 找出状态1的起始和结束位置 + state_starts = np.where(transitions == 1)[0] + state_ends = np.where(transitions == -1)[0] - 1 + + # 检查每个状态1的持续时间 + for i in range(len(state_starts)): + if state_starts[i] < len(time_points) and state_ends[i] < len(time_points): + # 计算持续时间 + duration = time_points[state_ends[i]] - time_points[state_starts[i]] + if state_ends[i] == len(time_points) - 1: + # 如果是最后一个窗口,加上一个窗口的长度 + duration += time_points[1] - time_points[0] if len(time_points) > 1 else 0 + + # 如果持续时间短于阈值,则移除 + if duration < min_duration_sec: + filtered_sequence[state_starts[i]:state_ends[i] + 1] = 0 + + return filtered_sequence + + +@timing_decorator() +def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None): + # 处理标志位长度与 signal_data 对齐 + if calc_mask is None: + calc_mask = np.zeros(len(signal_data), dtype=bool) + + if step_second is None: + step_second = window_second + + step_length = step_second * sampling_rate + window_length = window_second * sampling_rate + + origin_seconds = len(signal_data) // sampling_rate + total_samples = len(signal_data) + + + # reflect padding + left_pad_size = int(window_length // 2) + right_pad_size = window_length - left_pad_size + data = np.pad(signal_data, (left_pad_size, right_pad_size), mode='reflect') + + num_segments = int(np.ceil(len(signal_data) / step_length)) + values_nan = np.full(num_segments, np.nan) + + # print(f"num_segments: {num_segments}, step_length: {step_length}, window_length: {window_length}") + for i in range(num_segments): + # 包含体动则仅计算不含体动部分 + start = int(i * step_length) + end = start + window_length + segment = data[start:end] + values_nan[i] = func(segment) + + + values_nan = values_nan.repeat(step_second)[:origin_seconds] + + for i in range(len(values_nan)): + if calc_mask[i]: + values_nan[i] = np.nan + + values = values_nan.copy() + + # 插值处理体动区域的 NaN 值 + def interpolate_nans(x, t): + valid_mask = ~np.isnan(x) + return np.interp(t, t[valid_mask], x[valid_mask]) + + values = interpolate_nans(values, np.arange(len(values))) + + return values_nan, values + + + + diff --git a/utils/statistics_metrics.py b/utils/statistics_metrics.py new file mode 100644 index 0000000..5c06b19 --- /dev/null +++ b/utils/statistics_metrics.py @@ -0,0 +1,105 @@ +from utils.operation_tools import timing_decorator +import numpy as np +import pandas as pd + +@timing_decorator() +def statistic_amplitude_metrics(data, aml_interval=None, time_interval=None): + """ + 计算不同幅值区间占比和时间,最后汇总成混淆矩阵 + + 参数: + data: 采样率为1秒的一维序列,其中体动所在的区域用np.nan填充 + aml_interval: 幅值区间的分界点列表,默认为[200, 500, 1000, 2000] + time_interval: 时间区间的分界点列表,单位为秒,默认为[60, 300, 1800, 3600] + + 返回: + confusion_matrix: 幅值-时长统计矩阵 + summary: 汇总统计信息 + """ + if aml_interval is None: + aml_interval = [200, 500, 1000, 2000] + + if time_interval is None: + time_interval = [60, 300, 1800, 3600] + # 检查输入 + if not isinstance(data, np.ndarray): + data = np.array(data) + + # 整个记录的时长(包括nan) + total_duration = len(data) + + # 创建幅值标签和时间标签 + amp_labels = [f"0-{aml_interval[0]}"] + for i in range(len(aml_interval) - 1): + amp_labels.append(f"{aml_interval[i]}-{aml_interval[i + 1]}") + amp_labels.append(f"{aml_interval[-1]}+") + + time_labels = [f"0-{time_interval[0]}"] + for i in range(len(time_interval) - 1): + time_labels.append(f"{time_interval[i]}-{time_interval[i + 1]}") + time_labels.append(f"{time_interval[-1]}+") + + # 初始化结果矩阵(时长)和片段数矩阵 + result_matrix = np.zeros((len(amp_labels), len(time_labels))) # 时长矩阵 + segment_count_matrix = np.zeros((len(amp_labels), len(time_labels))) # 片段数矩阵 + + # 有效信号总量(非NaN的数据点数量) + valid_signal_length = np.sum(~np.isnan(data)) + + # 添加信号开始和结束的边界条件 + signal_padded = np.concatenate(([np.nan], data, [np.nan])) + diff = np.diff(np.isnan(signal_padded).astype(int)) + + # 连续片段的起始位置(从 nan 变为非 nan) + segment_starts = np.where(diff == -1)[0] + # 连续片段的结束位置(从非 nan 变为 nan) + segment_ends = np.where(diff == 1)[0] + + # 计算每个片段的时长和平均幅值,并填充结果矩阵 + for start, end in zip(segment_starts, segment_ends): + segment = data[start:end] + duration = end - start # 时长(单位:秒) + mean_amplitude = np.nanmean(segment) # 片段平均幅值 + + # 确定幅值区间 + if mean_amplitude <= aml_interval[0]: + amp_idx = 0 + elif mean_amplitude > aml_interval[-1]: + amp_idx = len(aml_interval) + else: + amp_idx = np.searchsorted(aml_interval, mean_amplitude) + + # 确定时长区间 + if duration <= time_interval[0]: + time_idx = 0 + elif duration > time_interval[-1]: + time_idx = len(time_interval) + else: + time_idx = np.searchsorted(time_interval, duration) + + # 在对应位置累加该片段的时长和片段数 + result_matrix[amp_idx, time_idx] += duration + segment_count_matrix[amp_idx, time_idx] += 1 # 片段数加1 + + # 创建DataFrame以便于展示和后续处理 + confusion_matrix = pd.DataFrame(result_matrix, index=amp_labels, columns=time_labels) + + # 计算行和列的总和 + confusion_matrix['总计'] = confusion_matrix.sum(axis=1) + row_totals = confusion_matrix['总计'].copy() + + # 计算百分比(相对于有效记录时长) + confusion_matrix_percent = confusion_matrix.div(total_duration) * 100 + + # 汇总统计 + summary = { + 'total_duration': total_duration, + 'total_valid_signal': valid_signal_length, + 'amplitude_distribution': row_totals.to_dict(), + 'amplitude_percent': row_totals.div(total_duration) * 100, + 'time_distribution': confusion_matrix.sum(axis=0).to_dict(), + 'time_percent': confusion_matrix.sum(axis=0).div(total_duration) * 100 + } + + return summary, (confusion_matrix, segment_count_matrix, confusion_matrix_percent, valid_signal_length, + total_duration, time_labels, amp_labels)