Refactor signal processing configurations and improve mask generation logic

This commit is contained in:
marques 2025-10-30 15:46:08 +08:00
parent 9fdbc4a1cb
commit 965f88843a
5 changed files with 41 additions and 49 deletions

View File

@ -28,7 +28,7 @@ import numpy as np
import signal_method import signal_method
import os import os
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
os.environ['DISPLAY'] = "localhost:10.0" os.environ['DISPLAY'] = "localhost:11.0"
def process_one_signal(samp_id): def process_one_signal(samp_id):
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
@ -98,7 +98,7 @@ def process_one_signal(samp_id):
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
label_data = utils.read_label_csv(path=label_path) label_data = utils.read_label_csv(path=label_path)
label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
all_samp_disable_df["id"] == samp_id]) all_samp_disable_df["id"] == samp_id])
@ -110,11 +110,7 @@ def process_one_signal(samp_id):
resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=resp_data, signal_data=resp_data,
sampling_rate=resp_fs, sampling_rate=resp_fs,
window_size_sec=resp_low_amp_conf["window_size_sec"], **resp_low_amp_conf
stride_sec=resp_low_amp_conf["stride_sec"],
amplitude_threshold=resp_low_amp_conf["amplitude_threshold"],
merge_gap_sec=resp_low_amp_conf["merge_gap_sec"],
min_duration_sec=resp_low_amp_conf["min_duration_sec"]
) )
print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}")
else: else:
@ -127,13 +123,7 @@ def process_one_signal(samp_id):
raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
signal_data=resp_data, signal_data=resp_data,
sampling_rate=resp_fs, sampling_rate=resp_fs,
window_size_sec=resp_movement_conf["window_size_sec"], **resp_movement_conf
stride_sec=resp_movement_conf["stride_sec"],
std_median_multiplier=resp_movement_conf["std_median_multiplier"],
compare_intervals_sec=resp_movement_conf["compare_intervals_sec"],
interval_multiplier=resp_movement_conf["interval_multiplier"],
merge_gap_sec=resp_movement_conf["merge_gap_sec"],
min_duration_sec=resp_movement_conf["min_duration_sec"]
) )
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}")
else: else:
@ -159,11 +149,7 @@ def process_one_signal(samp_id):
bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=bcg_data, signal_data=bcg_data,
sampling_rate=bcg_fs, sampling_rate=bcg_fs,
window_size_sec=bcg_low_amp_conf["window_size_sec"], **bcg_low_amp_conf
stride_sec=bcg_low_amp_conf["stride_sec"],
amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"],
merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"],
min_duration_sec=bcg_low_amp_conf["min_duration_sec"]
) )
print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}")
else: else:
@ -175,13 +161,7 @@ def process_one_signal(samp_id):
raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
signal_data=bcg_data, signal_data=bcg_data,
sampling_rate=bcg_fs, sampling_rate=bcg_fs,
window_size_sec=bcg_movement_conf["window_size_sec"], **bcg_movement_conf
stride_sec=bcg_movement_conf["stride_sec"],
std_median_multiplier=bcg_movement_conf["std_median_multiplier"],
compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"],
interval_multiplier=bcg_movement_conf["interval_multiplier"],
merge_gap_sec=bcg_movement_conf["merge_gap_sec"],
min_duration_sec=bcg_movement_conf["min_duration_sec"]
) )
print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}")
else: else:
@ -215,7 +195,7 @@ def process_one_signal(samp_id):
resp_low_amp_mask=resp_low_amp_mask, resp_low_amp_mask=resp_low_amp_mask,
resp_movement_mask=resp_movement_mask, resp_movement_mask=resp_movement_mask,
resp_change_mask=resp_amp_change_mask, resp_change_mask=resp_amp_change_mask,
resp_sa_mask=None, resp_sa_mask=event_mask,
bcg_low_amp_mask=bcg_low_amp_mask, bcg_low_amp_mask=bcg_low_amp_mask,
bcg_movement_mask=bcg_movement_mask, bcg_movement_mask=bcg_movement_mask,
bcg_change_mask=bcg_amp_change_mask) bcg_change_mask=bcg_amp_change_mask)

View File

@ -20,25 +20,25 @@ resp_filter:
filter_type: bandpass filter_type: bandpass
low_cut: 0.01 low_cut: 0.01
high_cut: 0.7 high_cut: 0.7
order: 2 order: 3
resp_low_amp: resp_low_amp:
window_size_sec: 1 window_size_sec: 30
stride_sec: stride_sec:
amplitude_threshold: 20 amplitude_threshold: 5
merge_gap_sec: 10 merge_gap_sec: 180
min_duration_sec: 5 min_duration_sec: 30
resp_movement: resp_movement:
window_size_sec: 2 window_size_sec: 30
stride_sec: stride_sec: 5
std_median_multiplier: 4.5 std_median_multiplier: 5
compare_intervals_sec: compare_intervals_sec:
- 30
- 60 - 60
interval_multiplier: 2.5 - 90
merge_gap_sec: 10 interval_multiplier: 3.5
min_duration_sec: 5 merge_gap_sec: 45
min_duration_sec: 10
bcg: bcg:
downsample_fs: 100 downsample_fs: 100
@ -49,3 +49,17 @@ bcg_filter:
high_cut: 10 high_cut: 10
order: 10 order: 10
bcg_low_amp:
window_size_sec: 1
stride_sec:
amplitude_threshold: 10
merge_gap_sec: 20
min_duration_sec: 3
bcg_movement:
window_size_sec: 2
stride_sec:
merge_gap_sec: 20
min_duration_sec: 4

View File

@ -188,8 +188,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
if mask is None: if mask is None:
return np.full_like(ref, np.nan) return np.full_like(ref, np.nan)
else: else:
# 将mask中的0替换为nan1替换为1 # 将mask中的0替换为nan其他的保持
mask = np.where(mask == 0, np.nan, 1) mask = np.where(mask == 0, np.nan, mask)
return mask return mask
signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data) signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
@ -224,7 +224,7 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange')
ax1.set_ylabel('Amplitude') ax1.set_ylabel('Amplitude')
ax1.set_xticklabels([]) # ax1.set_xticklabels([])
ax1_twin = ax1.twinx() ax1_twin = ax1.twinx()
ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1, ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1,
color='blue', alpha=0.5, label='Low Amplitude Mask') color='blue', alpha=0.5, label='Low Amplitude Mask')

View File

@ -421,7 +421,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
energy = np.sum(np.abs(signal_data[start:end] ** 2)) energy = np.sum(np.abs(signal_data[start:end] ** 2))
segment_average_energy.append(energy) segment_average_energy.append(energy)
position_changes = [] position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
position_change_times = [] position_change_times = []
# 判断是否存在显著变化 (可根据实际情况调整阈值) # 判断是否存在显著变化 (可根据实际情况调整阈值)
threshold_amplitude = 0.1 # 幅值变化阈值 threshold_amplitude = 0.1 # 幅值变化阈值
@ -440,9 +440,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
if significant_change: if significant_change:
# 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示
position_changes.append(1) position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1
position_change_times.append((movement_start[i - 1], movement_end[i - 1])) position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
else:
position_changes.append(0) # 0表示不存在姿势变化
return np.array(position_changes), position_change_times return position_changes, position_change_times

View File

@ -184,12 +184,12 @@ def load_dataset_conf(yaml_path):
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray: def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
disable_mask = np.ones(signal_second, dtype=int) disable_mask = np.zeros(signal_second, dtype=int)
for _, row in disable_df.iterrows(): for _, row in disable_df.iterrows():
start = row["start"] start = row["start"]
end = row["end"] end = row["end"]
disable_mask[start:end] = 0 disable_mask[start:end] = 1
return disable_mask return disable_mask