Refactor signal processing configurations and improve mask generation logic

This commit is contained in:
marques 2025-10-30 15:46:08 +08:00
parent 9fdbc4a1cb
commit 965f88843a
5 changed files with 41 additions and 49 deletions

View File

@ -28,7 +28,7 @@ import numpy as np
import signal_method
import os
from matplotlib import pyplot as plt
os.environ['DISPLAY'] = "localhost:10.0"
os.environ['DISPLAY'] = "localhost:11.0"
def process_one_signal(samp_id):
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
@ -98,7 +98,7 @@ def process_one_signal(samp_id):
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
label_data = utils.read_label_csv(path=label_path)
label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
all_samp_disable_df["id"] == samp_id])
@ -110,11 +110,7 @@ def process_one_signal(samp_id):
resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=resp_data,
sampling_rate=resp_fs,
window_size_sec=resp_low_amp_conf["window_size_sec"],
stride_sec=resp_low_amp_conf["stride_sec"],
amplitude_threshold=resp_low_amp_conf["amplitude_threshold"],
merge_gap_sec=resp_low_amp_conf["merge_gap_sec"],
min_duration_sec=resp_low_amp_conf["min_duration_sec"]
**resp_low_amp_conf
)
print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}")
else:
@ -127,13 +123,7 @@ def process_one_signal(samp_id):
raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
signal_data=resp_data,
sampling_rate=resp_fs,
window_size_sec=resp_movement_conf["window_size_sec"],
stride_sec=resp_movement_conf["stride_sec"],
std_median_multiplier=resp_movement_conf["std_median_multiplier"],
compare_intervals_sec=resp_movement_conf["compare_intervals_sec"],
interval_multiplier=resp_movement_conf["interval_multiplier"],
merge_gap_sec=resp_movement_conf["merge_gap_sec"],
min_duration_sec=resp_movement_conf["min_duration_sec"]
**resp_movement_conf
)
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}")
else:
@ -159,11 +149,7 @@ def process_one_signal(samp_id):
bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=bcg_data,
sampling_rate=bcg_fs,
window_size_sec=bcg_low_amp_conf["window_size_sec"],
stride_sec=bcg_low_amp_conf["stride_sec"],
amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"],
merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"],
min_duration_sec=bcg_low_amp_conf["min_duration_sec"]
**bcg_low_amp_conf
)
print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}")
else:
@ -175,13 +161,7 @@ def process_one_signal(samp_id):
raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
signal_data=bcg_data,
sampling_rate=bcg_fs,
window_size_sec=bcg_movement_conf["window_size_sec"],
stride_sec=bcg_movement_conf["stride_sec"],
std_median_multiplier=bcg_movement_conf["std_median_multiplier"],
compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"],
interval_multiplier=bcg_movement_conf["interval_multiplier"],
merge_gap_sec=bcg_movement_conf["merge_gap_sec"],
min_duration_sec=bcg_movement_conf["min_duration_sec"]
**bcg_movement_conf
)
print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}")
else:
@ -215,7 +195,7 @@ def process_one_signal(samp_id):
resp_low_amp_mask=resp_low_amp_mask,
resp_movement_mask=resp_movement_mask,
resp_change_mask=resp_amp_change_mask,
resp_sa_mask=None,
resp_sa_mask=event_mask,
bcg_low_amp_mask=bcg_low_amp_mask,
bcg_movement_mask=bcg_movement_mask,
bcg_change_mask=bcg_amp_change_mask)

View File

@ -20,25 +20,25 @@ resp_filter:
filter_type: bandpass
low_cut: 0.01
high_cut: 0.7
order: 2
order: 3
resp_low_amp:
window_size_sec: 1
window_size_sec: 30
stride_sec:
amplitude_threshold: 20
merge_gap_sec: 10
min_duration_sec: 5
amplitude_threshold: 5
merge_gap_sec: 180
min_duration_sec: 30
resp_movement:
window_size_sec: 2
stride_sec:
std_median_multiplier: 4.5
window_size_sec: 30
stride_sec: 5
std_median_multiplier: 5
compare_intervals_sec:
- 30
- 60
interval_multiplier: 2.5
merge_gap_sec: 10
min_duration_sec: 5
- 90
interval_multiplier: 3.5
merge_gap_sec: 45
min_duration_sec: 10
bcg:
downsample_fs: 100
@ -49,3 +49,17 @@ bcg_filter:
high_cut: 10
order: 10
bcg_low_amp:
window_size_sec: 1
stride_sec:
amplitude_threshold: 10
merge_gap_sec: 20
min_duration_sec: 3
bcg_movement:
window_size_sec: 2
stride_sec:
merge_gap_sec: 20
min_duration_sec: 4

View File

@ -188,8 +188,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
if mask is None:
return np.full_like(ref, np.nan)
else:
# 将mask中的0替换为nan1替换为1
mask = np.where(mask == 0, np.nan, 1)
# 将mask中的0替换为nan其他的保持
mask = np.where(mask == 0, np.nan, mask)
return mask
signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
@ -224,7 +224,7 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange')
ax1.set_ylabel('Amplitude')
ax1.set_xticklabels([])
# ax1.set_xticklabels([])
ax1_twin = ax1.twinx()
ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1,
color='blue', alpha=0.5, label='Low Amplitude Mask')

View File

@ -421,7 +421,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
energy = np.sum(np.abs(signal_data[start:end] ** 2))
segment_average_energy.append(energy)
position_changes = []
position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int)
position_change_times = []
# 判断是否存在显著变化 (可根据实际情况调整阈值)
threshold_amplitude = 0.1 # 幅值变化阈值
@ -440,9 +440,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
if significant_change:
# 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示
position_changes.append(1)
position_changes[movement_start[i - 1]:movement_end[i - 1]] = 1
position_change_times.append((movement_start[i - 1], movement_end[i - 1]))
else:
position_changes.append(0) # 0表示不存在姿势变化
return np.array(position_changes), position_change_times
return position_changes, position_change_times

View File

@ -184,12 +184,12 @@ def load_dataset_conf(yaml_path):
def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
disable_mask = np.ones(signal_second, dtype=int)
disable_mask = np.zeros(signal_second, dtype=int)
for _, row in disable_df.iterrows():
start = row["start"]
end = row["end"]
disable_mask[start:end] = 0
disable_mask[start:end] = 1
return disable_mask