Add signal drawing functionality and enhance signal processing methods

This commit is contained in:
marques 2025-10-29 10:53:14 +08:00
parent 40aad46d6f
commit 9fdbc4a1cb
8 changed files with 318 additions and 27 deletions

View File

@ -21,12 +21,14 @@ todo: 使用mask 屏蔽无用区间
"""
from pathlib import Path
from typing import Union
import draw_tools
import utils
import numpy as np
import signal_method
import os
from matplotlib import pyplot as plt
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id):
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
@ -41,7 +43,6 @@ def process_one_signal(samp_id):
label_path = list(label_path)[0]
print(f"Processing Label_corrected file: {label_path}")
signal_data = utils.read_signal_txt(signal_path)
signal_length = len(signal_data)
print(f"signal_length: {signal_length}")
@ -50,43 +51,174 @@ def process_one_signal(samp_id):
signal_second = signal_length // signal_fs
print(f"signal_second: {signal_second}")
# 根据采样率进行截断
signal_data = signal_data[:signal_second * signal_fs]
# 滤波
# 50Hz陷波滤波器
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
resp_data = utils.butterworth(data=signal_data, _type=conf["resp"]["filter_type"], low_cut=conf["resp"]["low_cut"],
high_cut=conf["resp"]["high_cut"], order=conf["resp"]["order"], sample_rate=signal_fs)
print("Applying 50Hz notch filter...")
signal_data = utils.notch_filter(data=signal_data, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg"]["filter_type"], low_cut=conf["bcg"]["low_cut"],
high_cut=conf["bcg"]["high_cut"], order=conf["bcg"]["order"], sample_rate=signal_fs)
resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
resp_fs = conf["resp"]["downsample_fs_1"]
resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
low_cut=conf["resp_filter"]["low_cut"],
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
sample_rate=resp_fs)
print("Begin plotting signal data...")
# fig = plt.figure(figsize=(12, 8))
# # 绘制三个图raw_data、resp_data_1、resp_data_2
# ax0 = fig.add_subplot(3, 1, 1)
# ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
# ax0.set_title('Raw Signal Data')
# ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
# ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange')
# ax1.set_title('Resp Data after Average Filtering')
# ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
# ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green')
# ax2.set_title('Resp Data after Butterworth Filtering')
# plt.tight_layout()
# plt.show()
bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
low_cut=conf["bcg_filter"]["low_cut"],
high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
sample_rate=signal_fs)
# 降采样
old_resp_fs = resp_fs
resp_fs = conf["resp"]["downsample_fs_2"]
resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs)
bcg_fs = conf["bcg"]["downsample_fs"]
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
label_data = utils.read_label_csv(path=label_path)
label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id])
manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
all_samp_disable_df["id"] == samp_id])
print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")
# 分析Resp的低幅值区间
resp_low_amp_conf = getattr(conf, "resp_low_amp", None)
resp_low_amp_conf = conf.get("resp_low_amp", None)
if resp_low_amp_conf is not None:
resp_low_amp_mask = signal_method.detect_low_amplitude_signal(
resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=resp_data,
sampling_rate=signal_fs,
sampling_rate=resp_fs,
window_size_sec=resp_low_amp_conf["window_size_sec"],
stride_sec=resp_low_amp_conf["stride_sec"],
amplitude_threshold=resp_low_amp_conf["amplitude_threshold"],
merge_gap_sec=resp_low_amp_conf["merge_gap_sec"],
min_duration_sec=resp_low_amp_conf["min_duration_sec"]
)
print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}")
else:
resp_low_amp_mask = None
resp_low_amp_mask, resp_low_amp_position_list = None, None
print("resp_low_amp_mask is None")
# 分析Resp的高幅值伪迹区间
resp_move
resp_movement_conf = conf.get("resp_movement", None)
if resp_movement_conf is not None:
raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
signal_data=resp_data,
sampling_rate=resp_fs,
window_size_sec=resp_movement_conf["window_size_sec"],
stride_sec=resp_movement_conf["stride_sec"],
std_median_multiplier=resp_movement_conf["std_median_multiplier"],
compare_intervals_sec=resp_movement_conf["compare_intervals_sec"],
interval_multiplier=resp_movement_conf["interval_multiplier"],
merge_gap_sec=resp_movement_conf["merge_gap_sec"],
min_duration_sec=resp_movement_conf["min_duration_sec"]
)
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}")
else:
resp_movement_mask = None
print("resp_movement_mask is None")
# 分析Resp的幅值突变区间
if resp_movement_mask is not None:
resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2(
signal_data=resp_data,
movement_mask=resp_movement_mask,
sampling_rate=resp_fs)
print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}")
else:
resp_amp_change_mask = None
print("amp_change_mask is None")
# 分析Bcg的低幅值区间
bcg_low_amp_conf = conf.get("bcg_low_amp", None)
if bcg_low_amp_conf is not None:
bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=bcg_data,
sampling_rate=bcg_fs,
window_size_sec=bcg_low_amp_conf["window_size_sec"],
stride_sec=bcg_low_amp_conf["stride_sec"],
amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"],
merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"],
min_duration_sec=bcg_low_amp_conf["min_duration_sec"]
)
print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}")
else:
bcg_low_amp_mask, bcg_low_amp_position_list = None, None
print("bcg_low_amp_mask is None")
# 分析Bcg的高幅值伪迹区间
bcg_movement_conf = conf.get("bcg_movement", None)
if bcg_movement_conf is not None:
raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
signal_data=bcg_data,
sampling_rate=bcg_fs,
window_size_sec=bcg_movement_conf["window_size_sec"],
stride_sec=bcg_movement_conf["stride_sec"],
std_median_multiplier=bcg_movement_conf["std_median_multiplier"],
compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"],
interval_multiplier=bcg_movement_conf["interval_multiplier"],
merge_gap_sec=bcg_movement_conf["merge_gap_sec"],
min_duration_sec=bcg_movement_conf["min_duration_sec"]
)
print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}")
else:
bcg_movement_mask = None
print("bcg_movement_mask is None")
# 分析Bcg的幅值突变区间
if bcg_movement_mask is not None:
bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
signal_data=bcg_data,
movement_mask=bcg_movement_mask,
sampling_rate=bcg_fs)
print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}")
else:
bcg_amp_change_mask = None
print("bcg_amp_change_mask is None")
# 如果signal_data采样率过进行降采样
if signal_fs == 1000:
signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
signal_fs = 100
draw_tools.draw_signal_with_mask(samp_id=samp_id,
signal_data=signal_data,
signal_fs=signal_fs,
resp_data=resp_data,
resp_fs=resp_fs,
bcg_data=bcg_data,
bcg_fs=bcg_fs,
signal_disable_mask=manual_disable_mask,
resp_low_amp_mask=resp_low_amp_mask,
resp_movement_mask=resp_movement_mask,
resp_change_mask=resp_amp_change_mask,
resp_sa_mask=None,
bcg_low_amp_mask=bcg_low_amp_mask,
bcg_movement_mask=bcg_movement_mask,
bcg_change_mask=bcg_amp_change_mask)

View File

@ -12,19 +12,37 @@ select_ids:
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
resp:
downsample_fs_1: 100
downsample_fs_2: 10
resp_filter:
filter_type: bandpass
low_cut: 0.01
high_cut: 0.7
order: 10
order: 2
resp_low_amp:
windows_size_sec: 1
stride_sec: None
amplitude_threshold: 50
window_size_sec: 1
stride_sec:
amplitude_threshold: 20
merge_gap_sec: 10
min_duration_sec: 5
resp_movement:
window_size_sec: 2
stride_sec:
std_median_multiplier: 4.5
compare_intervals_sec:
- 30
- 60
interval_multiplier: 2.5
merge_gap_sec: 10
min_duration_sec: 5
bcg:
downsample_fs: 100
bcg_filter:
filter_type: bandpass
low_cut: 1

View File

@ -0,0 +1 @@
from draw_tools.draw_statics import draw_signal_with_mask

View File

@ -1,6 +1,8 @@
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
@ -74,8 +76,7 @@ def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, co
ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)",
ha='center', va='center')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values,
movement_position_list, low_amp_position_list, signal_second_length, aml_list=None):
@ -173,3 +174,126 @@ def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_s
plt.show()
plt.close()
def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs,
signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask,
resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask
):
# 第一行绘制去工频噪声的原始信号,右侧为不可用区间标记,左侧为信号幅值纵坐标
# 第二行绘制呼吸分量右侧低幅值、高幅值、幅值变换标记、SA标签左侧为呼吸幅值纵坐标
# 第三行绘制心冲击分量,右侧为低幅值、高幅值、幅值变换标记、,左侧为心冲击幅值纵坐标
# mask为None则生成全Nan掩码
def _none_to_nan_mask(mask, ref):
if mask is None:
return np.full_like(ref, np.nan)
else:
# 将mask中的0替换为nan1替换为1
mask = np.where(mask == 0, np.nan, 1)
return mask
signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data)
resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data)
resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data)
resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data)
resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data)
bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data)
bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data)
bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data)
fig = plt.figure(figsize=(18, 10))
ax0 = fig.add_subplot(3, 1, 1)
ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
ax0.set_title(f'Sample {samp_id} - Raw Signal Data')
ax0.set_ylabel('Amplitude')
# ax0.set_xticklabels([])
ax0_twin = ax0.twinx()
ax0_twin.plot(np.linspace(0, len(signal_disable_mask), len(signal_disable_mask)), signal_disable_mask,
color='red', alpha=0.5)
ax0_twin.autoscale(enable=False, axis='y', tight=True)
ax0_twin.set_ylim((-2, 2))
ax0_twin.set_ylabel('Disable Mask')
ax0_twin.set_yticks([0, 1])
ax0_twin.set_yticklabels(['Enabled', 'Disabled'])
ax0_twin.grid(False)
ax0_twin.legend(['Disable Mask'], loc='upper right')
ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange')
ax1.set_ylabel('Amplitude')
ax1.set_xticklabels([])
ax1_twin = ax1.twinx()
ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1,
color='blue', alpha=0.5, label='Low Amplitude Mask')
ax1_twin.plot(np.linspace(0, len(resp_movement_mask), len(resp_movement_mask)), resp_movement_mask*-2,
color='red', alpha=0.5, label='Movement Mask')
ax1_twin.plot(np.linspace(0, len(resp_change_mask), len(resp_change_mask)), resp_change_mask*-3,
color='green', alpha=0.5, label='Amplitude Change Mask')
ax1_twin.plot(np.linspace(0, len(resp_sa_mask), len(resp_sa_mask)), resp_sa_mask,
color='purple', alpha=0.5, label='SA Mask')
ax1_twin.autoscale(enable=False, axis='y', tight=True)
ax1_twin.set_ylim((-4, 5))
# ax1_twin.set_ylabel('Resp Masks')
# ax1_twin.set_yticks([0, 1])
# ax1_twin.set_yticklabels(['No', 'Yes'])
ax1_twin.grid(False)
ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
ax1.set_title(f'Sample {samp_id} - Respiration Component')
ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
ax2.set_ylabel('Amplitude')
ax2.set_xlabel('Time (s)')
ax2_twin = ax2.twinx()
ax2_twin.plot(np.linspace(0, len(bcg_low_amp_mask), len(bcg_low_amp_mask)), bcg_low_amp_mask*-1,
color='blue', alpha=0.5, label='Low Amplitude Mask')
ax2_twin.plot(np.linspace(0, len(bcg_movement_mask), len(bcg_movement_mask)), bcg_movement_mask*-2,
color='red', alpha=0.5, label='Movement Mask')
ax2_twin.plot(np.linspace(0, len(bcg_change_mask), len(bcg_change_mask)), bcg_change_mask*-3,
color='green', alpha=0.5, label='Amplitude Change Mask')
ax2_twin.autoscale(enable=False, axis='y', tight=True)
ax2_twin.set_ylim((-4, 2))
ax2_twin.set_ylabel('BCG Masks')
# ax2_twin.set_yticks([0, 1])
# ax2_twin.set_yticklabels(['No', 'Yes'])
ax2_twin.grid(False)
ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right')
# ax2.set_title(f'Sample {samp_id} - BCG Component')
ax0_twin._lim_lock = False
ax1_twin._lim_lock = False
ax2_twin._lim_lock = False
def on_lims_change(event_ax):
if getattr(event_ax, '_lim_lock', False):
return
try:
event_ax._lim_lock = True
if event_ax == ax0_twin:
# 重新锁定 ax1 的 Y 轴范围
ax0_twin.set_ylim(-2, 2)
elif event_ax == ax1_twin:
ax1_twin.set_ylim(-3, 5)
elif event_ax == ax2_twin:
ax2_twin.set_ylim(-4, 2)
finally:
event_ax._lim_lock = False
ax0_twin.callbacks.connect('ylim_changed', on_lims_change)
ax1_twin.callbacks.connect('ylim_changed', on_lims_change)
ax2_twin.callbacks.connect('ylim_changed', on_lims_change)
plt.tight_layout()
plt.show()

View File

@ -1 +1 @@
from signal_method.rule_base_event import detect_low_amplitude_signal
from signal_method.rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2

View File

@ -379,7 +379,7 @@ def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rat
return position_changes, position_change_times
def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100, window_size_sec=30):
def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100):
"""
:param signal_data:
@ -445,4 +445,4 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat
else:
position_changes.append(0) # 0表示不存在姿势变化
return position_changes, position_change_times
return np.array(position_changes), position_change_times

View File

@ -1,4 +1,4 @@
from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel
from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask
from utils.event_map import E2N
from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter
from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel

View File

@ -21,6 +21,22 @@ def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=100
raise ValueError("Please choose a type of fliter")
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
if _type == "lowpass": # 低通滤波处理
b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
elif _type == "bandpass": # 带通滤波处理
low = low_cut / (sample_rate * 0.5)
high = high_cut / (sample_rate * 0.5)
b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
elif _type == "highpass": # 高通滤波处理
b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag')
return signal.filtfilt(b, a, np.array(data))
else: # 警告,滤波器类型必须有
raise ValueError("Please choose a type of fliter")
@timing_decorator()
def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000):
"""
@ -75,8 +91,8 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1
return downsampled_signal
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size=20):
kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate)
def average_filter(raw_data, sample_rate, window_size_sec=20):
kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate)
filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect')
convolve_filter_signal = raw_data - filtered
return convolve_filter_signal