diff --git a/HYS_process.py b/HYS_process.py index 9422d9c..85fb7a3 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -21,12 +21,14 @@ todo: 使用mask 屏蔽无用区间 """ from pathlib import Path -from typing import Union + +import draw_tools import utils import numpy as np import signal_method - - +import os +from matplotlib import pyplot as plt +os.environ['DISPLAY'] = "localhost:10.0" def process_one_signal(samp_id): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) @@ -41,7 +43,6 @@ def process_one_signal(samp_id): label_path = list(label_path)[0] print(f"Processing Label_corrected file: {label_path}") - signal_data = utils.read_signal_txt(signal_path) signal_length = len(signal_data) print(f"signal_length: {signal_length}") @@ -50,43 +51,174 @@ def process_one_signal(samp_id): signal_second = signal_length // signal_fs print(f"signal_second: {signal_second}") + # 根据采样率进行截断 + signal_data = signal_data[:signal_second * signal_fs] + # 滤波 # 50Hz陷波滤波器 # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) - resp_data = utils.butterworth(data=signal_data, _type=conf["resp"]["filter_type"], low_cut=conf["resp"]["low_cut"], - high_cut=conf["resp"]["high_cut"], order=conf["resp"]["order"], sample_rate=signal_fs) + print("Applying 50Hz notch filter...") + signal_data = utils.notch_filter(data=signal_data, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) - bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg"]["filter_type"], low_cut=conf["bcg"]["low_cut"], - high_cut=conf["bcg"]["high_cut"], order=conf["bcg"]["order"], sample_rate=signal_fs) + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + print("Begin plotting signal data...") + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + # 降采样 + old_resp_fs = resp_fs + resp_fs = conf["resp"]["downsample_fs_2"] + resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) + bcg_fs = conf["bcg"]["downsample_fs"] + bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) + label_data = utils.read_label_csv(path=label_path) label_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) - manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id]) + manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ + all_samp_disable_df["id"] == samp_id]) print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") # 分析Resp的低幅值区间 - resp_low_amp_conf = getattr(conf, "resp_low_amp", None) + resp_low_amp_conf = conf.get("resp_low_amp", None) if resp_low_amp_conf is not None: - resp_low_amp_mask = signal_method.detect_low_amplitude_signal( + resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( signal_data=resp_data, - sampling_rate=signal_fs, + sampling_rate=resp_fs, window_size_sec=resp_low_amp_conf["window_size_sec"], stride_sec=resp_low_amp_conf["stride_sec"], amplitude_threshold=resp_low_amp_conf["amplitude_threshold"], merge_gap_sec=resp_low_amp_conf["merge_gap_sec"], min_duration_sec=resp_low_amp_conf["min_duration_sec"] ) + print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}") else: - resp_low_amp_mask = None + resp_low_amp_mask, resp_low_amp_position_list = None, None + print("resp_low_amp_mask is None") # 分析Resp的高幅值伪迹区间 - resp_move + resp_movement_conf = conf.get("resp_movement", None) + if resp_movement_conf is not None: + raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( + signal_data=resp_data, + sampling_rate=resp_fs, + window_size_sec=resp_movement_conf["window_size_sec"], + stride_sec=resp_movement_conf["stride_sec"], + std_median_multiplier=resp_movement_conf["std_median_multiplier"], + compare_intervals_sec=resp_movement_conf["compare_intervals_sec"], + interval_multiplier=resp_movement_conf["interval_multiplier"], + merge_gap_sec=resp_movement_conf["merge_gap_sec"], + min_duration_sec=resp_movement_conf["min_duration_sec"] + ) + print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}") + else: + resp_movement_mask = None + print("resp_movement_mask is None") + # 分析Resp的幅值突变区间 + if resp_movement_mask is not None: + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2( + signal_data=resp_data, + movement_mask=resp_movement_mask, + sampling_rate=resp_fs) + print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}") + else: + resp_amp_change_mask = None + print("amp_change_mask is None") + # 分析Bcg的低幅值区间 + bcg_low_amp_conf = conf.get("bcg_low_amp", None) + if bcg_low_amp_conf is not None: + bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=bcg_data, + sampling_rate=bcg_fs, + window_size_sec=bcg_low_amp_conf["window_size_sec"], + stride_sec=bcg_low_amp_conf["stride_sec"], + amplitude_threshold=bcg_low_amp_conf["amplitude_threshold"], + merge_gap_sec=bcg_low_amp_conf["merge_gap_sec"], + min_duration_sec=bcg_low_amp_conf["min_duration_sec"] + ) + print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}") + else: + bcg_low_amp_mask, bcg_low_amp_position_list = None, None + print("bcg_low_amp_mask is None") + # 分析Bcg的高幅值伪迹区间 + bcg_movement_conf = conf.get("bcg_movement", None) + if bcg_movement_conf is not None: + raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( + signal_data=bcg_data, + sampling_rate=bcg_fs, + window_size_sec=bcg_movement_conf["window_size_sec"], + stride_sec=bcg_movement_conf["stride_sec"], + std_median_multiplier=bcg_movement_conf["std_median_multiplier"], + compare_intervals_sec=bcg_movement_conf["compare_intervals_sec"], + interval_multiplier=bcg_movement_conf["interval_multiplier"], + merge_gap_sec=bcg_movement_conf["merge_gap_sec"], + min_duration_sec=bcg_movement_conf["min_duration_sec"] + ) + print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}") + else: + bcg_movement_mask = None + print("bcg_movement_mask is None") + # 分析Bcg的幅值突变区间 + if bcg_movement_mask is not None: + bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( + signal_data=bcg_data, + movement_mask=bcg_movement_mask, + sampling_rate=bcg_fs) + print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}") + else: + bcg_amp_change_mask = None + print("bcg_amp_change_mask is None") + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + signal_fs = 100 + + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=None, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask) diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index dfff364..c30c3c5 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -12,19 +12,37 @@ select_ids: root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +resp: + downsample_fs_1: 100 + downsample_fs_2: 10 + resp_filter: filter_type: bandpass low_cut: 0.01 high_cut: 0.7 - order: 10 + order: 2 resp_low_amp: - windows_size_sec: 1 - stride_sec: None - amplitude_threshold: 50 + window_size_sec: 1 + stride_sec: + amplitude_threshold: 20 merge_gap_sec: 10 min_duration_sec: 5 +resp_movement: + window_size_sec: 2 + stride_sec: + std_median_multiplier: 4.5 + compare_intervals_sec: + - 30 + - 60 + interval_multiplier: 2.5 + merge_gap_sec: 10 + min_duration_sec: 5 + +bcg: + downsample_fs: 100 + bcg_filter: filter_type: bandpass low_cut: 1 diff --git a/draw_tools/__init__.py b/draw_tools/__init__.py index e69de29..5d4efe2 100644 --- a/draw_tools/__init__.py +++ b/draw_tools/__init__.py @@ -0,0 +1 @@ +from draw_tools.draw_statics import draw_signal_with_mask \ No newline at end of file diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index 88f790e..b4e401a 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -1,6 +1,8 @@ from matplotlib.axes import Axes from matplotlib.gridspec import GridSpec from matplotlib.colors import PowerNorm +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches import seaborn as sns import numpy as np @@ -74,8 +76,7 @@ def draw_ax_confusion_matrix(ax:Axes, confusion_matrix, segment_count_matrix, co ax.text(len(time_labels) + 0.5, len(amp_labels) + 0.5, f"[{int(total_segments)}]{valid_signal_length}\n({total_percent:.2f}%)", ha='center', va='center') -import matplotlib.pyplot as plt -import matplotlib.patches as mpatches + def draw_ax_amp(ax, signal_name, original_times, origin_signal, no_movement_signal, mav_values, movement_position_list, low_amp_position_list, signal_second_length, aml_list=None): @@ -172,4 +173,127 @@ def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_s if show: plt.show() - plt.close() \ No newline at end of file + plt.close() + + +def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs, + signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask, + resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask + ): + # 第一行绘制去工频噪声的原始信号,右侧为不可用区间标记,左侧为信号幅值纵坐标 + # 第二行绘制呼吸分量,右侧低幅值、高幅值、幅值变换标记、SA标签,左侧为呼吸幅值纵坐标 + # 第三行绘制心冲击分量,右侧为低幅值、高幅值、幅值变换标记、,左侧为心冲击幅值纵坐标 + # mask为None,则生成全Nan掩码 + def _none_to_nan_mask(mask, ref): + if mask is None: + return np.full_like(ref, np.nan) + else: + # 将mask中的0替换为nan,1替换为1 + mask = np.where(mask == 0, np.nan, 1) + return mask + + signal_disable_mask = _none_to_nan_mask(signal_disable_mask, signal_data) + resp_low_amp_mask = _none_to_nan_mask(resp_low_amp_mask, resp_data) + resp_movement_mask = _none_to_nan_mask(resp_movement_mask, resp_data) + resp_change_mask = _none_to_nan_mask(resp_change_mask, resp_data) + resp_sa_mask = _none_to_nan_mask(resp_sa_mask, resp_data) + bcg_low_amp_mask = _none_to_nan_mask(bcg_low_amp_mask, bcg_data) + bcg_movement_mask = _none_to_nan_mask(bcg_movement_mask, bcg_data) + bcg_change_mask = _none_to_nan_mask(bcg_change_mask, bcg_data) + + + fig = plt.figure(figsize=(18, 10)) + ax0 = fig.add_subplot(3, 1, 1) + ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + ax0.set_title(f'Sample {samp_id} - Raw Signal Data') + ax0.set_ylabel('Amplitude') + # ax0.set_xticklabels([]) + + ax0_twin = ax0.twinx() + ax0_twin.plot(np.linspace(0, len(signal_disable_mask), len(signal_disable_mask)), signal_disable_mask, + color='red', alpha=0.5) + ax0_twin.autoscale(enable=False, axis='y', tight=True) + ax0_twin.set_ylim((-2, 2)) + ax0_twin.set_ylabel('Disable Mask') + ax0_twin.set_yticks([0, 1]) + ax0_twin.set_yticklabels(['Enabled', 'Disabled']) + ax0_twin.grid(False) + ax0_twin.legend(['Disable Mask'], loc='upper right') + + + ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + ax1.plot(np.linspace(0, len(resp_data) // resp_fs, len(resp_data)), resp_data, color='orange') + ax1.set_ylabel('Amplitude') + ax1.set_xticklabels([]) + ax1_twin = ax1.twinx() + ax1_twin.plot(np.linspace(0, len(resp_low_amp_mask), len(resp_low_amp_mask)), resp_low_amp_mask*-1, + color='blue', alpha=0.5, label='Low Amplitude Mask') + ax1_twin.plot(np.linspace(0, len(resp_movement_mask), len(resp_movement_mask)), resp_movement_mask*-2, + color='red', alpha=0.5, label='Movement Mask') + ax1_twin.plot(np.linspace(0, len(resp_change_mask), len(resp_change_mask)), resp_change_mask*-3, + color='green', alpha=0.5, label='Amplitude Change Mask') + ax1_twin.plot(np.linspace(0, len(resp_sa_mask), len(resp_sa_mask)), resp_sa_mask, + color='purple', alpha=0.5, label='SA Mask') + ax1_twin.autoscale(enable=False, axis='y', tight=True) + ax1_twin.set_ylim((-4, 5)) + # ax1_twin.set_ylabel('Resp Masks') + # ax1_twin.set_yticks([0, 1]) + # ax1_twin.set_yticklabels(['No', 'Yes']) + ax1_twin.grid(False) + + ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right') + ax1.set_title(f'Sample {samp_id} - Respiration Component') + + ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green') + ax2.set_ylabel('Amplitude') + ax2.set_xlabel('Time (s)') + ax2_twin = ax2.twinx() + ax2_twin.plot(np.linspace(0, len(bcg_low_amp_mask), len(bcg_low_amp_mask)), bcg_low_amp_mask*-1, + color='blue', alpha=0.5, label='Low Amplitude Mask') + ax2_twin.plot(np.linspace(0, len(bcg_movement_mask), len(bcg_movement_mask)), bcg_movement_mask*-2, + color='red', alpha=0.5, label='Movement Mask') + ax2_twin.plot(np.linspace(0, len(bcg_change_mask), len(bcg_change_mask)), bcg_change_mask*-3, + color='green', alpha=0.5, label='Amplitude Change Mask') + ax2_twin.autoscale(enable=False, axis='y', tight=True) + ax2_twin.set_ylim((-4, 2)) + ax2_twin.set_ylabel('BCG Masks') + # ax2_twin.set_yticks([0, 1]) + # ax2_twin.set_yticklabels(['No', 'Yes']) + ax2_twin.grid(False) + ax2_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask'], loc='upper right') + # ax2.set_title(f'Sample {samp_id} - BCG Component') + + ax0_twin._lim_lock = False + ax1_twin._lim_lock = False + ax2_twin._lim_lock = False + + def on_lims_change(event_ax): + if getattr(event_ax, '_lim_lock', False): + return + try: + event_ax._lim_lock = True + + if event_ax == ax0_twin: + # 重新锁定 ax1 的 Y 轴范围 + ax0_twin.set_ylim(-2, 2) + elif event_ax == ax1_twin: + ax1_twin.set_ylim(-3, 5) + elif event_ax == ax2_twin: + ax2_twin.set_ylim(-4, 2) + + finally: + event_ax._lim_lock = False + + + ax0_twin.callbacks.connect('ylim_changed', on_lims_change) + ax1_twin.callbacks.connect('ylim_changed', on_lims_change) + ax2_twin.callbacks.connect('ylim_changed', on_lims_change) + + + plt.tight_layout() + plt.show() + + + + diff --git a/signal_method/__init__.py b/signal_method/__init__.py index 46eac36..a1d61f2 100644 --- a/signal_method/__init__.py +++ b/signal_method/__init__.py @@ -1 +1 @@ -from signal_method.rule_base_event import detect_low_amplitude_signal \ No newline at end of file +from signal_method.rule_base_event import detect_low_amplitude_signal, detect_movement, position_based_sleep_recognition_v2 \ No newline at end of file diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 8de49da..a8e2480 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -379,7 +379,7 @@ def position_based_sleep_recognition_v1(signal_data, movement_mask, sampling_rat return position_changes, position_change_times -def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100, window_size_sec=30): +def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rate=100): """ :param signal_data: @@ -445,4 +445,4 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat else: position_changes.append(0) # 0表示不存在姿势变化 - return position_changes, position_change_times + return np.array(position_changes), position_change_times diff --git a/utils/__init__.py b/utils/__init__.py index faaebaf..ae2ee06 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,4 +1,4 @@ from utils.HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel from utils.operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask from utils.event_map import E2N -from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter \ No newline at end of file +from utils.signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/signal_process.py b/utils/signal_process.py index dea0ea1..c657d3f 100644 --- a/utils/signal_process.py +++ b/utils/signal_process.py @@ -21,6 +21,22 @@ def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10,sample_rate=100 raise ValueError("Please choose a type of fliter") +def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000): + if _type == "lowpass": # 低通滤波处理 + b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag') + return signal.filtfilt(b, a, np.array(data)) + elif _type == "bandpass": # 带通滤波处理 + low = low_cut / (sample_rate * 0.5) + high = high_cut / (sample_rate * 0.5) + b, a = signal.bessel(order, [low, high], btype='bandpass', analog=False, norm='mag') + return signal.filtfilt(b, a, np.array(data)) + elif _type == "highpass": # 高通滤波处理 + b, a = signal.bessel(order, high_cut / (sample_rate * 0.5), btype='highpass', analog=False, norm='mag') + return signal.filtfilt(b, a, np.array(data)) + else: # 警告,滤波器类型必须有 + raise ValueError("Please choose a type of fliter") + + @timing_decorator() def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=100000): """ @@ -75,8 +91,8 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1 return downsampled_signal @timing_decorator() -def average_filter(raw_data, sample_rate, window_size=20): - kernel = np.ones(window_size * sample_rate) / (window_size * sample_rate) +def average_filter(raw_data, sample_rate, window_size_sec=20): + kernel = np.ones(window_size_sec * sample_rate) / (window_size_sec * sample_rate) filtered = ndimage.convolve1d(raw_data, kernel, mode='reflect') convolve_filter_signal = raw_data - filtered return convolve_filter_signal