Refactor imports in __init__.py, enhance resp_movement handling in HYS_process.py, and update HYS_config.yaml for movement revision parameters
This commit is contained in:
parent
998890377b
commit
2a2604a323
@ -28,7 +28,7 @@ import numpy as np
|
|||||||
import signal_method
|
import signal_method
|
||||||
import os
|
import os
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
os.environ['DISPLAY'] = "localhost:10.0"
|
os.environ['DISPLAY'] = "localhost:14.0"
|
||||||
|
|
||||||
def process_one_signal(samp_id):
|
def process_one_signal(samp_id):
|
||||||
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
|
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
|
||||||
@ -127,33 +127,23 @@ def process_one_signal(samp_id):
|
|||||||
)
|
)
|
||||||
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
|
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
|
||||||
else:
|
else:
|
||||||
resp_movement_mask = None
|
resp_movement_mask, resp_movement_position_list = None, None
|
||||||
print("resp_movement_mask is None")
|
print("resp_movement_mask is None")
|
||||||
|
|
||||||
if resp_movement_mask is not None:
|
resp_movement_revise_conf = conf.get("resp_movement_revise", None)
|
||||||
# 左右翻转resp_data
|
if resp_movement_mask is not None and resp_movement_revise_conf is not None:
|
||||||
reverse_resp_data = resp_data[::-1]
|
resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
|
||||||
_, resp_movement_mask_reverse, _, resp_movement_position_list_reverse = signal_method.detect_movement(
|
signal_data=resp_data,
|
||||||
signal_data=reverse_resp_data,
|
movement_mask=resp_movement_mask,
|
||||||
|
movement_list=resp_movement_position_list,
|
||||||
sampling_rate=resp_fs,
|
sampling_rate=resp_fs,
|
||||||
**resp_movement_conf
|
**resp_movement_revise_conf
|
||||||
)
|
)
|
||||||
print(f"resp_movement_mask_reverse_shape: {resp_movement_mask_reverse.shape}, num_movement_reverse: {np.sum(resp_movement_mask_reverse == 1)}, count_movement_positions_reverse: {len(resp_movement_position_list_reverse)}")
|
print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}")
|
||||||
# 将resp_movement_mask_reverse翻转回来
|
|
||||||
resp_movement_mask_reverse = resp_movement_mask_reverse[::-1]
|
|
||||||
else:
|
else:
|
||||||
resp_movement_mask_reverse = None
|
print("resp_movement_mask revise is skipped")
|
||||||
print("resp_movement_mask_reverse is None")
|
|
||||||
|
|
||||||
|
|
||||||
# 取交集
|
|
||||||
if resp_movement_mask is not None and resp_movement_mask_reverse is not None:
|
|
||||||
combined_resp_movement_mask = np.logical_and(resp_movement_mask, resp_movement_mask_reverse).astype(int)
|
|
||||||
resp_movement_mask = combined_resp_movement_mask
|
|
||||||
print(f"combined_resp_movement_mask_shape: {combined_resp_movement_mask.shape}, num_combined_movement: {np.sum(combined_resp_movement_mask == 1)}")
|
|
||||||
else:
|
|
||||||
print("combined_resp_movement_mask is None")
|
|
||||||
|
|
||||||
|
|
||||||
# 分析Resp的幅值突变区间
|
# 分析Resp的幅值突变区间
|
||||||
if resp_movement_mask is not None:
|
if resp_movement_mask is not None:
|
||||||
@ -246,4 +236,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
all_samp_disable_df = utils.read_disable_excel(disable_df_path)
|
all_samp_disable_df = utils.read_disable_excel(disable_df_path)
|
||||||
|
|
||||||
process_one_signal(select_ids[5])
|
process_one_signal(select_ids[0])
|
||||||
|
|||||||
@ -32,13 +32,21 @@ resp_low_amp:
|
|||||||
resp_movement:
|
resp_movement:
|
||||||
window_size_sec: 20
|
window_size_sec: 20
|
||||||
stride_sec: 1
|
stride_sec: 1
|
||||||
std_median_multiplier: 3.5
|
std_median_multiplier: 5
|
||||||
compare_intervals_sec:
|
compare_intervals_sec:
|
||||||
- 60
|
- 60
|
||||||
- 90
|
- 120
|
||||||
|
- 180
|
||||||
interval_multiplier: 3.5
|
interval_multiplier: 3.5
|
||||||
merge_gap_sec: 30
|
merge_gap_sec: 30
|
||||||
min_duration_sec: 2
|
min_duration_sec: 1
|
||||||
|
|
||||||
|
resp_movement_revise:
|
||||||
|
up_interval_multiplier: 3
|
||||||
|
down_interval_multiplier: 1.5
|
||||||
|
compare_intervals_sec: 30
|
||||||
|
merge_gap_sec: 10
|
||||||
|
min_duration_sec: 1
|
||||||
|
|
||||||
bcg:
|
bcg:
|
||||||
downsample_fs: 100
|
downsample_fs: 100
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
from draw_tools.draw_statics import draw_signal_with_mask
|
from .draw_statics import draw_signal_with_mask
|
||||||
@ -222,7 +222,10 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
|
|||||||
|
|
||||||
|
|
||||||
ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
|
ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
|
||||||
ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange')
|
ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='gray', alpha=0.5)
|
||||||
|
resp_data_no_movement = resp_data.copy()
|
||||||
|
resp_data_no_movement[resp_movement_mask.repeat(int(resp_fs)) == 1] = np.nan
|
||||||
|
ax1.plot(np.linspace(0, len(resp_data_no_movement) // resp_fs, len(resp_data_no_movement)), resp_data_no_movement, color='orange')
|
||||||
ax1.set_ylabel('Amplitude')
|
ax1.set_ylabel('Amplitude')
|
||||||
# ax1.set_xticklabels([])
|
# ax1.set_xticklabels([])
|
||||||
ax1_twin = ax1.twinx()
|
ax1_twin = ax1.twinx()
|
||||||
|
|||||||
@ -1 +1,3 @@
|
|||||||
from signal_method.rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2
|
from .rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2
|
||||||
|
from .rule_base_event import movement_revise
|
||||||
|
from .time_metrics import calc_mav
|
||||||
@ -1,6 +1,7 @@
|
|||||||
from utils.operation_tools import timing_decorator
|
from utils.operation_tools import timing_decorator
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils import merge_short_gaps, remove_short_durations, event_mask_2_list
|
from utils import merge_short_gaps, remove_short_durations, event_mask_2_list, collect_values
|
||||||
|
from signal_method.time_metrics import calc_mav
|
||||||
|
|
||||||
|
|
||||||
@timing_decorator()
|
@timing_decorator()
|
||||||
@ -90,16 +91,16 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No
|
|||||||
# else:
|
# else:
|
||||||
# valid_std = original_window_std
|
# valid_std = original_window_std
|
||||||
|
|
||||||
valid_std = original_window_std ##20250418新修改
|
valid_std = original_window_std ##20250418新修改
|
||||||
|
|
||||||
#---------------------- 方法一:基于STD的体动判定 ----------------------#
|
# ---------------------- 方法一:基于STD的体动判定 ----------------------#
|
||||||
# 计算所有有效窗口标准差的中位数
|
# 计算所有有效窗口标准差的中位数
|
||||||
median_std = np.median(valid_std)
|
median_std = np.median(valid_std)
|
||||||
|
|
||||||
# 当窗口标准差大于中位数的倍数,判定为体动状态
|
# 当窗口标准差大于中位数的倍数,判定为体动状态
|
||||||
std_movement = np.where(original_window_std > median_std * std_median_multiplier, 1, 0)
|
std_movement = np.where((original_window_std > (median_std * std_median_multiplier)), 1, 0)
|
||||||
|
|
||||||
#------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------#
|
# ------------------ 方法二:基于前后信号幅值变化的体动判定 ------------------#
|
||||||
amplitude_movement = np.zeros(num_original_windows, dtype=int)
|
amplitude_movement = np.zeros(num_original_windows, dtype=int)
|
||||||
|
|
||||||
# 定义基于时间粒度的比较间隔索引
|
# 定义基于时间粒度的比较间隔索引
|
||||||
@ -146,7 +147,6 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No
|
|||||||
raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
|
raw_movement_mask = raw_movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
|
||||||
movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
|
movement_mask = movement_mask.repeat(stride_sec)[:len(signal_data) // sampling_rate]
|
||||||
|
|
||||||
|
|
||||||
# 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除
|
# 比较剔除的体动,如果被剔除的体动所在区域有高于3std的幅值,则不剔除
|
||||||
removed_movement_mask = (raw_movement_mask - movement_mask) > 0
|
removed_movement_mask = (raw_movement_mask - movement_mask) > 0
|
||||||
removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0]
|
removed_movement_start = np.where(np.diff(np.concatenate([[0], removed_movement_mask])) == 1)[0]
|
||||||
@ -155,8 +155,8 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No
|
|||||||
for start, end in zip(removed_movement_start, removed_movement_end):
|
for start, end in zip(removed_movement_start, removed_movement_end):
|
||||||
# print(start ,end)
|
# print(start ,end)
|
||||||
# 计算剔除的体动区域的幅值
|
# 计算剔除的体动区域的幅值
|
||||||
if np.nanmax(signal_data[start*sampling_rate:(end+1)*sampling_rate]) > median_std * std_median_multiplier:
|
if np.nanmax(signal_data[start * sampling_rate:(end + 1) * sampling_rate]) > median_std * std_median_multiplier:
|
||||||
movement_mask[start:end+1] = 1
|
movement_mask[start:end + 1] = 1
|
||||||
|
|
||||||
# raw体动起止位置 [[start, end], [start, end], ...]
|
# raw体动起止位置 [[start, end], [start, end], ...]
|
||||||
raw_movement_position_list = event_mask_2_list(raw_movement_mask)
|
raw_movement_position_list = event_mask_2_list(raw_movement_mask)
|
||||||
@ -164,25 +164,70 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No
|
|||||||
# merge体动起止位置 [[start, end], [start, end], ...]
|
# merge体动起止位置 [[start, end], [start, end], ...]
|
||||||
movement_position_list = event_mask_2_list(movement_mask)
|
movement_position_list = event_mask_2_list(movement_mask)
|
||||||
|
|
||||||
return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list
|
return raw_movement_mask, movement_mask, raw_movement_position_list, movement_position_list
|
||||||
|
|
||||||
|
|
||||||
|
def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float,
|
||||||
|
down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec):
|
||||||
def movement_revise(signal_data, sampling_rate, movement_mask, std_median_multiplier=4.5):
|
|
||||||
"""
|
"""
|
||||||
基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置修正
|
基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置精细修正
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- signal_data: numpy array,输入的信号数据
|
- signal_data: numpy array,输入的信号数据
|
||||||
- sampling_rate: int,信号的采样率(Hz)
|
- sampling_rate: int,信号的采样率(Hz)
|
||||||
- movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠)
|
- movement_mask: numpy array,已有的体动掩码(1表示体动,0表示睡眠)
|
||||||
- std_median_multiplier: float,标准差中位数的乘数阈值,默认值为 4.5
|
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
- revised_movement_mask: numpy array,修正后的体动掩码
|
- revised_movement_mask: numpy array,修正后的体动掩码
|
||||||
"""
|
"""
|
||||||
pass
|
window_size = sampling_rate
|
||||||
|
stride_size = sampling_rate
|
||||||
|
|
||||||
|
time_points = np.arange(len(signal_data))
|
||||||
|
|
||||||
|
compare_size = int(compare_intervals_sec // (stride_size / sampling_rate))
|
||||||
|
|
||||||
|
_, mav = calc_mav(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate,
|
||||||
|
window_second=2, step_second=1,
|
||||||
|
inner_window_second=1)
|
||||||
|
|
||||||
|
# 往左右两边取compare_size个点的mav,取平均值
|
||||||
|
for start, end in movement_list:
|
||||||
|
left_values = collect_values(arr=mav, index=start - 1, step=-1, limit=compare_size, mask=movement_mask)
|
||||||
|
right_values = collect_values(arr=mav, index=end + 5, step=1, limit=compare_size, mask=movement_mask)
|
||||||
|
left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0
|
||||||
|
right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0
|
||||||
|
if left_value_metrics == 0:
|
||||||
|
value_metrics = right_value_metrics
|
||||||
|
elif right_value_metrics == 0:
|
||||||
|
value_metrics = left_value_metrics
|
||||||
|
else:
|
||||||
|
value_metrics = np.mean([left_value_metrics, right_value_metrics])
|
||||||
|
|
||||||
|
# 逐秒遍历mav,判断是否需要修正
|
||||||
|
# print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}")
|
||||||
|
for i in range(start, end + 5):
|
||||||
|
# print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}")
|
||||||
|
if mav[i] > (value_metrics * up_interval_multiplier):
|
||||||
|
movement_mask[i] = 1
|
||||||
|
# print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}")
|
||||||
|
elif mav[i] < (value_metrics * down_interval_multiplier):
|
||||||
|
movement_mask[i] = 0
|
||||||
|
# print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}")
|
||||||
|
# else:
|
||||||
|
# print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}")
|
||||||
|
#
|
||||||
|
# 如果需要合并间隔小的体动状态
|
||||||
|
if merge_gap_sec > 0:
|
||||||
|
movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec)
|
||||||
|
|
||||||
|
# 如果需要移除短时体动状态
|
||||||
|
if min_duration_sec > 0:
|
||||||
|
movement_mask = remove_short_durations(movement_mask, time_points, min_duration_sec)
|
||||||
|
|
||||||
|
movement_list = event_mask_2_list(movement_mask)
|
||||||
|
return movement_mask, movement_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -335,10 +380,10 @@ def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rat
|
|||||||
# 新的end - start确保为200的整数倍
|
# 新的end - start确保为200的整数倍
|
||||||
if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0:
|
if (left_end - left_start) % (mav_calc_window_sec * sampling_rate) != 0:
|
||||||
left_end = left_start + ((left_end - left_start) // (mav_calc_window_sec * sampling_rate)) * (
|
left_end = left_start + ((left_end - left_start) // (mav_calc_window_sec * sampling_rate)) * (
|
||||||
mav_calc_window_sec * sampling_rate)
|
mav_calc_window_sec * sampling_rate)
|
||||||
if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0:
|
if (right_end - right_start) % (mav_calc_window_sec * sampling_rate) != 0:
|
||||||
right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * (
|
right_end = right_start + ((right_end - right_start) // (mav_calc_window_sec * sampling_rate)) * (
|
||||||
mav_calc_window_sec * sampling_rate)
|
mav_calc_window_sec * sampling_rate)
|
||||||
|
|
||||||
# 计算每个片段的幅值指标
|
# 计算每个片段的幅值指标
|
||||||
left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate),
|
left_mav = np.mean(np.max(signal_data[left_start:left_end].reshape(-1, mav_calc_window_sec * sampling_rate),
|
||||||
@ -431,11 +476,12 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
|
|||||||
# 新的end - start确保为200的整数倍
|
# 新的end - start确保为200的整数倍
|
||||||
if (end - start) % (mav_calc_window_sec * sampling_rate) != 0:
|
if (end - start) % (mav_calc_window_sec * sampling_rate) != 0:
|
||||||
end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * (
|
end = start + ((end - start) // (mav_calc_window_sec * sampling_rate)) * (
|
||||||
mav_calc_window_sec * sampling_rate)
|
mav_calc_window_sec * sampling_rate)
|
||||||
|
|
||||||
# 计算每个片段的幅值指标
|
# 计算每个片段的幅值指标
|
||||||
mav = np.nanmean(
|
mav = np.nanmean(
|
||||||
np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0)) - np.nanmean(
|
np.nanmax(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate),
|
||||||
|
axis=0)) - np.nanmean(
|
||||||
np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
|
np.nanmin(signal_data_no_movement[start:end].reshape(-1, mav_calc_window_sec * sampling_rate), axis=0))
|
||||||
segment_average_amplitude.append(mav)
|
segment_average_amplitude.append(mav)
|
||||||
|
|
||||||
|
|||||||
@ -5,10 +5,14 @@ import numpy as np
|
|||||||
|
|
||||||
@timing_decorator()
|
@timing_decorator()
|
||||||
def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2):
|
def calc_mav(signal_data, movement_mask, low_amp_mask, sampling_rate=100, window_second=10, step_second=1, inner_window_second=2):
|
||||||
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
|
if movement_mask is not None:
|
||||||
assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"
|
assert len(movement_mask) * sampling_rate == len(signal_data), f"movement_mask 长度与 signal_data 长度不一致, {len(movement_mask) * sampling_rate} != {len(signal_data)}"
|
||||||
# print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}")
|
# assert len(movement_mask) == len(low_amp_mask), f"movement_mask 和 low_amp_mask 长度不一致, {len(movement_mask)} != {len(low_amp_mask)}"
|
||||||
processed_mask = movement_mask.copy()
|
# print(f"movement_mask_length: {len(movement_mask)}, signal_data_length: {len(signal_data)}")
|
||||||
|
processed_mask = movement_mask.copy()
|
||||||
|
else:
|
||||||
|
processed_mask = None
|
||||||
|
|
||||||
def mav_func(x):
|
def mav_func(x):
|
||||||
return np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2
|
return np.mean(np.nanmax(x.reshape(-1, inner_window_second*sampling_rate), axis=1) - np.nanmin(x.reshape(-1, inner_window_second*sampling_rate), axis=1)) / 2
|
||||||
mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
|
mav_nan, mav = calculate_by_slide_windows(mav_func, signal_data, processed_mask, sampling_rate=sampling_rate,
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel
|
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel
|
||||||
from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
|
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
|
||||||
from utils.operation_tools import merge_short_gaps, remove_short_durations
|
from .operation_tools import merge_short_gaps, remove_short_durations
|
||||||
from utils.event_map import E2N
|
from .operation_tools import collect_values
|
||||||
from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
|
from .event_map import E2N
|
||||||
|
from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
|
||||||
@ -125,8 +125,8 @@ def remove_short_durations(state_sequence, time_points, min_duration_sec):
|
|||||||
@timing_decorator()
|
@timing_decorator()
|
||||||
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
|
def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100, window_second=20, step_second=None):
|
||||||
# 处理标志位长度与 signal_data 对齐
|
# 处理标志位长度与 signal_data 对齐
|
||||||
if calc_mask is None:
|
# if calc_mask is None:
|
||||||
calc_mask = np.zeros(len(signal_data), dtype=bool)
|
# calc_mask = np.zeros(len(signal_data), dtype=bool)
|
||||||
|
|
||||||
if step_second is None:
|
if step_second is None:
|
||||||
step_second = window_second
|
step_second = window_second
|
||||||
@ -157,18 +157,21 @@ def calculate_by_slide_windows(func, signal_data, calc_mask, sampling_rate=100,
|
|||||||
|
|
||||||
values_nan = values_nan.repeat(step_second)[:origin_seconds]
|
values_nan = values_nan.repeat(step_second)[:origin_seconds]
|
||||||
|
|
||||||
for i in range(len(values_nan)):
|
if calc_mask is not None:
|
||||||
if calc_mask[i]:
|
for i in range(len(values_nan)):
|
||||||
values_nan[i] = np.nan
|
if calc_mask[i]:
|
||||||
|
values_nan[i] = np.nan
|
||||||
|
|
||||||
values = values_nan.copy()
|
values = values_nan.copy()
|
||||||
|
|
||||||
# 插值处理体动区域的 NaN 值
|
# 插值处理体动区域的 NaN 值
|
||||||
def interpolate_nans(x, t):
|
def interpolate_nans(x, t):
|
||||||
valid_mask = ~np.isnan(x)
|
valid_mask = ~np.isnan(x)
|
||||||
return np.interp(t, t[valid_mask], x[valid_mask])
|
return np.interp(t, t[valid_mask], x[valid_mask])
|
||||||
|
|
||||||
values = interpolate_nans(values, np.arange(len(values)))
|
values = interpolate_nans(values, np.arange(len(values)))
|
||||||
|
else:
|
||||||
|
values = values_nan.copy()
|
||||||
|
|
||||||
return values_nan, values
|
return values_nan, values
|
||||||
|
|
||||||
@ -208,7 +211,20 @@ def generate_event_mask(signal_second: int, event_df):
|
|||||||
|
|
||||||
|
|
||||||
def event_mask_2_list(mask):
|
def event_mask_2_list(mask):
|
||||||
mask_start = np.where(np.diff(mask, append=0) == 1)[0]
|
mask_start = np.where(np.diff(mask, append=0) == -1)[0]
|
||||||
mask_end = np.where(np.diff(mask, append=0) == -1)[0] + 1
|
mask_end = np.where(np.diff(mask, append=0) == 1)[0] + 1
|
||||||
event_list =[[start, end] for start, end in zip(mask_start, mask_end)]
|
event_list =[[start, end] for start, end in zip(mask_start, mask_end)]
|
||||||
return event_list
|
return event_list
|
||||||
|
|
||||||
|
|
||||||
|
def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None) -> list:
|
||||||
|
"""收集非 NaN 值,直到达到指定数量或边界"""
|
||||||
|
values = []
|
||||||
|
count = 0
|
||||||
|
mask = mask if mask is not None else arr
|
||||||
|
while count < limit and 0 <= index < len(mask):
|
||||||
|
if not np.isnan(mask[index]):
|
||||||
|
values.append(arr[index])
|
||||||
|
count += 1
|
||||||
|
index += step
|
||||||
|
return values
|
||||||
Loading…
Reference in New Issue
Block a user