From 40fdda649791f3c945f1fe46325746d6551f762e Mon Sep 17 00:00:00 2001 From: marques Date: Mon, 10 Nov 2025 18:38:26 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=AD=E5=A4=A7=E4=BA=94=E9=99=A2=E4=B8=BA?= =?UTF-8?q?=E5=8D=A0=E4=BD=8D=EF=BC=8C=E5=91=BC=E7=A0=94=E6=89=80=E5=B7=B2?= =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E6=AD=A3=E5=B8=B8=E5=AF=BC=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- HYS_process.py | 56 ++++++- ZD5Y_process.py | 277 +++++++++++++++++++++++++++++++ dataset_config/HYS_config.yaml | 11 +- dataset_config/ZD5Y_config.yaml | 88 ++++++++++ draw_tools/draw_statics.py | 10 +- signal_method/rule_base_event.py | 131 ++++++++++----- utils/__init__.py | 1 + utils/operation_tools.py | 7 +- 8 files changed, 525 insertions(+), 56 deletions(-) create mode 100644 ZD5Y_process.py create mode 100644 dataset_config/ZD5Y_config.yaml diff --git a/HYS_process.py b/HYS_process.py index 1eecdf5..50f3c6c 100644 --- a/HYS_process.py +++ b/HYS_process.py @@ -21,7 +21,7 @@ todo: 使用mask 屏蔽无用区间 """ from pathlib import Path - +import shutil import draw_tools import utils import numpy as np @@ -30,7 +30,7 @@ import os from matplotlib import pyplot as plt os.environ['DISPLAY'] = "localhost:10.0" -def process_one_signal(samp_id): +def process_one_signal(samp_id, show=False): signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) if not signal_path: raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") @@ -43,6 +43,10 @@ def process_one_signal(samp_id): label_path = list(label_path)[0] print(f"Processing Label_corrected file: {label_path}") + # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + save_samp_path.mkdir(parents=True, exist_ok=True) + signal_data_raw = utils.read_signal_txt(signal_path) signal_length = len(signal_data_raw) print(f"signal_length: {signal_length}") @@ -137,7 +141,8 @@ def process_one_signal(samp_id): movement_mask=resp_movement_mask, movement_list=resp_movement_position_list, sampling_rate=resp_fs, - **resp_movement_revise_conf + **resp_movement_revise_conf, + verbose=False ) print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") else: @@ -152,7 +157,8 @@ def process_one_signal(samp_id): movement_mask=resp_movement_mask, movement_list=resp_movement_position_list, sampling_rate=resp_fs, - **resp_amp_change_conf) + **resp_amp_change_conf, + verbose=True) print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") else: resp_amp_change_mask = None @@ -202,7 +208,7 @@ def process_one_signal(samp_id): signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) signal_fs = 100 - draw_tools.draw_signal_with_mask(samp_id=samp_id, + draw_tools.draw_signal_with_mask(samp_id=samp_id, signal_data=signal_data, signal_fs=signal_fs, resp_data=resp_data, @@ -216,7 +222,35 @@ def process_one_signal(samp_id): resp_sa_mask=event_mask, bcg_low_amp_mask=bcg_low_amp_mask, bcg_movement_mask=bcg_movement_mask, - bcg_change_mask=bcg_amp_change_mask) + bcg_change_mask=bcg_amp_change_mask, + show=show, + save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") + + + + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + + # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 + save_dict = { + "Second": np.arange(signal_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": manual_disable_mask, + "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + @@ -229,13 +263,21 @@ if __name__ == '__main__': conf = utils.load_dataset_conf(yaml_path) select_ids = conf["select_ids"] root_path = Path(conf["root_path"]) + save_path = Path(conf["save_path"]) print(f"select_ids: {select_ids}") print(f"root_path: {root_path}") + print(f"save_path: {save_path}") org_signal_root_path = root_path / "OrgBCG_Aligned" label_root_path = root_path / "Label" all_samp_disable_df = utils.read_disable_excel(disable_df_path) - process_one_signal(select_ids[0]) + process_one_signal(select_ids[9], show=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # process_one_signal(samp_id, show=False) + # print(f"Finished processing sample ID: {samp_id}\n\n") + diff --git a/ZD5Y_process.py b/ZD5Y_process.py new file mode 100644 index 0000000..23dc4c8 --- /dev/null +++ b/ZD5Y_process.py @@ -0,0 +1,277 @@ +""" +本脚本完成对呼研所数据的处理,包含以下功能: +1. 数据读取与预处理 + 从传入路径中,进行数据和标签的读取,并进行初步的预处理 + 预处理包括为数据进行滤波、去噪等操作 +2. 数据清洗与异常值处理 +3. 输出清晰后的统计信息 +4. 数据保存 + 将处理后的数据保存到指定路径,便于后续使用 + 主要是保存切分后的数据位置和标签 +5. 可视化 + 提供数据处理前后的可视化对比,帮助理解数据变化 + 绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等 + +todo: 使用mask 屏蔽无用区间 + + +# 低幅值区间规则标定与剔除 +# 高幅值连续体动规则标定与剔除 +# 手动标定不可用区间提剔除 +""" + +from pathlib import Path +import shutil +import draw_tools +import utils +import numpy as np +import signal_method +import os +from matplotlib import pyplot as plt +os.environ['DISPLAY'] = "localhost:10.0" + +def process_one_signal(samp_id, show=False): + signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt")) + if not signal_path: + raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}") + signal_path = signal_path[0] + print(f"Processing OrgBCG_Sync signal file: {signal_path}") + + label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv") + if not label_path: + raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}") + label_path = list(label_path)[0] + print(f"Processing Label_corrected file: {label_path}") + + signal_data_raw = utils.read_signal_txt(signal_path) + signal_length = len(signal_data_raw) + print(f"signal_length: {signal_length}") + signal_fs = int(signal_path.stem.split("_")[-1]) + print(f"signal_fs: {signal_fs}") + signal_second = signal_length // signal_fs + print(f"signal_second: {signal_second}") + + # 根据采样率进行截断 + signal_data_raw = signal_data_raw[:signal_second * signal_fs] + + # 滤波 + # 50Hz陷波滤波器 + # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs) + print("Applying 50Hz notch filter...") + signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs) + + resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs) + resp_fs = conf["resp"]["downsample_fs_1"] + resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs) + resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20) + resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"], + low_cut=conf["resp_filter"]["low_cut"], + high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"], + sample_rate=resp_fs) + print("Begin plotting signal data...") + + + # fig = plt.figure(figsize=(12, 8)) + # # 绘制三个图raw_data、resp_data_1、resp_data_2 + # ax0 = fig.add_subplot(3, 1, 1) + # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue') + # ax0.set_title('Raw Signal Data') + # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0) + # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange') + # ax1.set_title('Resp Data after Average Filtering') + # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0) + # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green') + # ax2.set_title('Resp Data after Butterworth Filtering') + # plt.tight_layout() + # plt.show() + + bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"], + low_cut=conf["bcg_filter"]["low_cut"], + high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"], + sample_rate=signal_fs) + + # 降采样 + old_resp_fs = resp_fs + resp_fs = conf["resp"]["downsample_fs_2"] + resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) + bcg_fs = conf["bcg"]["downsample_fs"] + bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) + + label_data = utils.read_label_csv(path=label_path) + event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data) + + manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[ + all_samp_disable_df["id"] == samp_id]) + print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}") + + # 分析Resp的低幅值区间 + resp_low_amp_conf = conf.get("resp_low_amp", None) + if resp_low_amp_conf is not None: + resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=resp_data, + sampling_rate=resp_fs, + **resp_low_amp_conf + ) + print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}") + else: + resp_low_amp_mask, resp_low_amp_position_list = None, None + print("resp_low_amp_mask is None") + + # 分析Resp的高幅值伪迹区间 + resp_movement_conf = conf.get("resp_movement", None) + if resp_movement_conf is not None: + raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement( + signal_data=resp_data, + sampling_rate=resp_fs, + **resp_movement_conf + ) + print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + else: + resp_movement_mask, resp_movement_position_list = None, None + print("resp_movement_mask is None") + + resp_movement_revise_conf = conf.get("resp_movement_revise", None) + if resp_movement_mask is not None and resp_movement_revise_conf is not None: + resp_movement_mask, resp_movement_position_list = signal_method.movement_revise( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, + sampling_rate=resp_fs, + **resp_movement_revise_conf, + verbose=False + ) + print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}") + else: + print("resp_movement_mask revise is skipped") + + + # 分析Resp的幅值突变区间 + resp_amp_change_conf = conf.get("resp_amp_change", None) + if resp_amp_change_conf is not None and resp_movement_mask is not None: + resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3( + signal_data=resp_data, + movement_mask=resp_movement_mask, + movement_list=resp_movement_position_list, + sampling_rate=resp_fs, + **resp_amp_change_conf) + print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}") + else: + resp_amp_change_mask = None + print("amp_change_mask is None") + + + + # 分析Bcg的低幅值区间 + bcg_low_amp_conf = conf.get("bcg_low_amp", None) + if bcg_low_amp_conf is not None: + bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal( + signal_data=bcg_data, + sampling_rate=bcg_fs, + **bcg_low_amp_conf + ) + print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}") + else: + bcg_low_amp_mask, bcg_low_amp_position_list = None, None + print("bcg_low_amp_mask is None") + # 分析Bcg的高幅值伪迹区间 + bcg_movement_conf = conf.get("bcg_movement", None) + if bcg_movement_conf is not None: + raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement( + signal_data=bcg_data, + sampling_rate=bcg_fs, + **bcg_movement_conf + ) + print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}") + else: + bcg_movement_mask = None + print("bcg_movement_mask is None") + # 分析Bcg的幅值突变区间 + if bcg_movement_mask is not None: + bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2( + signal_data=bcg_data, + movement_mask=bcg_movement_mask, + sampling_rate=bcg_fs) + print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}") + else: + bcg_amp_change_mask = None + print("bcg_amp_change_mask is None") + + + # 如果signal_data采样率过,进行降采样 + if signal_fs == 1000: + signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100) + signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100) + signal_fs = 100 + if show: + draw_tools.draw_signal_with_mask(samp_id=samp_id, + signal_data=signal_data, + signal_fs=signal_fs, + resp_data=resp_data, + resp_fs=resp_fs, + bcg_data=bcg_data, + bcg_fs=bcg_fs, + signal_disable_mask=manual_disable_mask, + resp_low_amp_mask=resp_low_amp_mask, + resp_movement_mask=resp_movement_mask, + resp_change_mask=resp_amp_change_mask, + resp_sa_mask=event_mask, + bcg_low_amp_mask=bcg_low_amp_mask, + bcg_movement_mask=bcg_movement_mask, + bcg_change_mask=bcg_amp_change_mask) + + + # 保存处理后的数据和标签 + save_samp_path = save_path / f"{samp_id}" + save_samp_path.mkdir(parents=True, exist_ok=True) + + # 复制事件文件 到保存路径 + sa_label_save_name = f"{samp_id}" + label_path.name + shutil.copyfile(label_path, save_samp_path / sa_label_save_name) + + # 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签 + save_dict = { + "Second": np.arange(signal_second), + "SA_Label": event_mask, + "SA_Score": score_mask, + "Disable_Label": manual_disable_mask, + "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int), + "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int) + } + + mask_label_save_name = f"{samp_id}_Processed_Labels.csv" + utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict) + + + + + + + +if __name__ == '__main__': + yaml_path = Path("./dataset_config/ZD5Y_config.yaml") + disable_df_path = Path("./排除区间.xlsx") + + conf = utils.load_dataset_conf(yaml_path) + select_ids = conf["select_ids"] + root_path = Path(conf["root_path"]) + save_path = Path(conf["save_path"]) + + print(f"select_ids: {select_ids}") + print(f"root_path: {root_path}") + print(f"save_path: {save_path}") + + org_signal_root_path = root_path / "OrgBCG_Aligned" + label_root_path = root_path / "Label" + + all_samp_disable_df = utils.read_disable_excel(disable_df_path) + + process_one_signal(select_ids[1], show=True) + + # for samp_id in select_ids: + # print(f"Processing sample ID: {samp_id}") + # process_one_signal(samp_id, show=False) + # print(f"Finished processing sample ID: {samp_id}\n\n") \ No newline at end of file diff --git a/dataset_config/HYS_config.yaml b/dataset_config/HYS_config.yaml index 101c1b0..e6d02dc 100644 --- a/dataset_config/HYS_config.yaml +++ b/dataset_config/HYS_config.yaml @@ -11,6 +11,7 @@ select_ids: - 960 root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS +save_path: /mnt/disk_code/marques/dataprepare/output/HYS resp: downsample_fs_1: 100 @@ -32,25 +33,25 @@ resp_low_amp: resp_movement: window_size_sec: 20 stride_sec: 1 - std_median_multiplier: 5 + std_median_multiplier: 4 compare_intervals_sec: - 60 - 120 # - 180 - interval_multiplier: 3.5 + interval_multiplier: 3 merge_gap_sec: 30 min_duration_sec: 1 resp_movement_revise: up_interval_multiplier: 3 - down_interval_multiplier: 1.5 + down_interval_multiplier: 2 compare_intervals_sec: 30 merge_gap_sec: 10 min_duration_sec: 1 resp_amp_change: - mav_calc_window_sec: 5 - threshold_amplitude: 0.1 + mav_calc_window_sec: 1 + threshold_amplitude: 0.25 threshold_energy: 0.4 diff --git a/dataset_config/ZD5Y_config.yaml b/dataset_config/ZD5Y_config.yaml new file mode 100644 index 0000000..ff479fd --- /dev/null +++ b/dataset_config/ZD5Y_config.yaml @@ -0,0 +1,88 @@ +select_ids: + - 3103 + - 3105 + - 3106 + - 3107 + - 3108 + - 3110 + - 3203 + - 3204 + - 3205 + - 3207 + - 3208 + - 3209 + - 3212 + - 3301 + - 3303 + - 3307 + - 3403 + - 3504 + +root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/ZD5Y +save_path: /mnt/disk_code/marques/dataprepare/output/ZD5Y + +resp: + downsample_fs_1: 100 + downsample_fs_2: 10 + +resp_filter: + filter_type: bandpass + low_cut: 0.01 + high_cut: 0.7 + order: 3 + +resp_low_amp: + window_size_sec: 30 + stride_sec: + amplitude_threshold: 3 + merge_gap_sec: 60 + min_duration_sec: 60 + +resp_movement: + window_size_sec: 20 + stride_sec: 1 + std_median_multiplier: 5 + compare_intervals_sec: + - 60 + - 120 +# - 180 + interval_multiplier: 3.5 + merge_gap_sec: 30 + min_duration_sec: 1 + +resp_movement_revise: + up_interval_multiplier: 3 + down_interval_multiplier: 2 + compare_intervals_sec: 30 + merge_gap_sec: 10 + min_duration_sec: 1 + +resp_amp_change: + mav_calc_window_sec: 1 + threshold_amplitude: 0.25 + threshold_energy: 0.4 + + +bcg: + downsample_fs: 100 + +bcg_filter: + filter_type: bandpass + low_cut: 1 + high_cut: 10 + order: 10 + +bcg_low_amp: + window_size_sec: 1 + stride_sec: + amplitude_threshold: 8 + merge_gap_sec: 20 + min_duration_sec: 3 + + +bcg_movement: + window_size_sec: 2 + stride_sec: + merge_gap_sec: 20 + min_duration_sec: 4 + diff --git a/draw_tools/draw_statics.py b/draw_tools/draw_statics.py index f44cde0..f94679d 100644 --- a/draw_tools/draw_statics.py +++ b/draw_tools/draw_statics.py @@ -178,7 +178,7 @@ def draw_signal_metrics(bcg_origin_signal, resp_origin_signal, bcg_no_movement_s def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, resp_fs, bcg_fs, signal_disable_mask, resp_low_amp_mask, resp_movement_mask, resp_change_mask, - resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask + resp_sa_mask, bcg_low_amp_mask, bcg_movement_mask, bcg_change_mask, show=False, save_path=None ): # 第一行绘制去工频噪声的原始信号,右侧为不可用区间标记,左侧为信号幅值纵坐标 # 第二行绘制呼吸分量,右侧低幅值、高幅值、幅值变换标记、SA标签,左侧为呼吸幅值纵坐标 @@ -292,10 +292,12 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs, ax0_twin.callbacks.connect('ylim_changed', on_lims_change) ax1_twin.callbacks.connect('ylim_changed', on_lims_change) ax2_twin.callbacks.connect('ylim_changed', on_lims_change) - - plt.tight_layout() - plt.show() + + if save_path is not None: + plt.savefig(save_path, dpi=300) + if show: + plt.show() diff --git a/signal_method/rule_base_event.py b/signal_method/rule_base_event.py index 107d55b..d13e6cb 100644 --- a/signal_method/rule_base_event.py +++ b/signal_method/rule_base_event.py @@ -169,7 +169,7 @@ def detect_movement(signal_data, sampling_rate, window_size_sec=2, stride_sec=No def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up_interval_multiplier: float, - down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec): + down_interval_multiplier: float, compare_intervals_sec, merge_gap_sec, min_duration_sec, verbose=False): """ 基于标准差对已有体动掩码进行修正。 用于大尺度的体动检测后的位置精细修正 @@ -189,13 +189,13 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up compare_size = int(compare_intervals_sec // (stride_size / sampling_rate)) _, mav = calc_mav_by_slide_windows(signal_data, movement_mask=None, low_amp_mask=None, sampling_rate=sampling_rate, - window_second=2, step_second=1, - inner_window_second=1) + window_second=4, step_second=1, + inner_window_second=4) # 往左右两边取compare_size个点的mav,取平均值 for start, end in movement_list: - left_points = start - 5 - right_points = end + 10 + left_points = start - 20 + right_points = end + 20 left_values = collect_values(arr=mav, index=left_points, step=-1, limit=compare_size, mask=movement_mask) right_values = collect_values(arr=mav, index=right_points, step=1, limit=compare_size, mask=movement_mask) @@ -203,28 +203,58 @@ def movement_revise(signal_data, sampling_rate, movement_mask, movement_list, up left_value_metrics = np.median(left_values) if len(left_values) > 0 else 0 right_value_metrics = np.median(right_values) if len(right_values) > 0 else 0 - if left_value_metrics == 0: - value_metrics = right_value_metrics - elif right_value_metrics == 0: - value_metrics = left_value_metrics - else: - value_metrics = np.mean([left_value_metrics, right_value_metrics]) + # if left_value_metrics == 0: + # value_metrics = right_value_metrics + # elif right_value_metrics == 0: + # value_metrics = left_value_metrics + # else: + # value_metrics = np.mean([left_value_metrics, right_value_metrics]) + + if left_value_metrics == 0: + left_value_metrics = right_value_metrics + elif right_value_metrics == 0: + right_value_metrics = left_value_metrics + + if verbose: + print(f"Revising movement from index {start} to {end}, left_metric: {left_value_metrics:.2f}, right_metric: {right_value_metrics:.2f}") - # 逐秒遍历mav,判断是否需要修正 - # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") for i in range(left_points, right_points): if i < 0 or i >= len(mav): continue - # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") - if mav[i] > (value_metrics * up_interval_multiplier): + if i < start: + value_metrics = left_value_metrics + elif i > end: + value_metrics = right_value_metrics + else: + value_metrics = (left_value_metrics + right_value_metrics) / 2 + + if mav[i] > (value_metrics * up_interval_multiplier) and movement_mask[i] == 0: movement_mask[i] = 1 - # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}") - elif mav[i] < (value_metrics * down_interval_multiplier): + if verbose: + print(f"Normal revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * up_interval_multiplier:.2f}") + elif mav[i] < (value_metrics * down_interval_multiplier) and movement_mask[i] == 1: movement_mask[i] = 0 - # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}") - # else: - # print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}") - # + if verbose: + print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_metrics * down_interval_multiplier:.2f}") + else: + if verbose: + print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_metrics * up_interval_multiplier:.2f}, down_threshold: {value_metrics * down_interval_multiplier:.2f}") + # + # 逐秒遍历mav,判断是否需要修正 + # print(f"Revising movement from index {start} to {end}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") + # for i in range(left_points, right_points): + # if i < 0 or i >= len(mav): + # continue + # # print(f"Index {i}, mav: {mav[i]:.2f}, left_mean: {left_value_mean:.2f}, right_mean: {right_value_mean:.2f}, mean: {value_mean:.2f}") + # if mav[i] > (value_metrics * up_interval_multiplier): + # movement_mask[i] = 1 + # # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * up_interval_multiplier:.2f}") + # elif mav[i] < (value_metrics * down_interval_multiplier): + # movement_mask[i] = 0 + # # print(f"Movement revised at index {i}, mav: {mav[i]:.2f}, threshold: {value_mean * down_interval_multiplier:.2f}") + # # else: + # # print(f"No revision at index {i}, mav: {mav[i]:.2f}, up_threshold: {value_mean * up_interval_multiplier:.2f}, down_threshold: {value_mean * down_interval_multiplier:.2f}") + # # # 如果需要合并间隔小的体动状态 if merge_gap_sec > 0: movement_mask = merge_short_gaps(movement_mask, time_points, merge_gap_sec) @@ -520,7 +550,7 @@ def position_based_sleep_recognition_v2(signal_data, movement_mask, sampling_rat def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_list, sampling_rate, mav_calc_window_sec, - threshold_amplitude, threshold_energy): + threshold_amplitude, threshold_energy, verbose=False): """ :param threshold_energy: @@ -569,9 +599,18 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis def calc_mav_by_quantiles(data_segment): # 先计算所有的mav值 + if len(data_segment) % (mav_calc_window_sec * sampling_rate) != 0: + data_segment = data_segment[:-(len(data_segment) % (mav_calc_window_sec * sampling_rate))] + mav_values = np.nanmax(data_segment.reshape(-1, mav_calc_window_sec * sampling_rate), axis=0) - np.nanmin( data_segment.reshape(-1, mav_calc_window_sec * sampling_rate)) # 计算分位数 + q20 = np.nanpercentile(mav_values, 20) + q80 = np.nanpercentile(mav_values, 80) + + mav_values = mav_values[(mav_values >= q20) & (mav_values <= q80)] + mav = np.nanmean(mav_values) + return mav position_changes = np.zeros(len(signal_data) // sampling_rate, dtype=int) position_change_list = [] @@ -579,14 +618,17 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis pre_valid_start = valid_list[0][0] * sampling_rate pre_valid_end = valid_list[0][1] * sampling_rate - print(f"Total movement segments to analyze: {len(movement_list)}") - print(f"Total valid segments available: {len(valid_list)}") + if verbose: + print(f"Total movement segments to analyze: {len(movement_list)}") + print(f"Total valid segments available: {len(valid_list)}") for i in range(len(movement_list)): - print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") + if verbose: + print(f"Analyzing movement segment {i + 1}/{len(movement_list)}") if i + 1 >= len(valid_list): - print("No more valid segments to compare. Ending analysis.") + if verbose: + print("No more valid segments to compare. Ending analysis.") break next_valid_start = valid_list[i + 1][0] * sampling_rate @@ -597,25 +639,33 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis # 避免过短的片段 if movement_end - movement_start <= sampling_rate: # 小于1秒的片段不考虑 - print( - f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") + if verbose: + print( + f"Skipping movement segment {i + 1} due to insufficient length. movement start: {movement_start}, movement end: {movement_end}") continue # 计算前后片段的幅值和能量 - left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) - right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) - left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) - right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + # left_mav = clac_mav(signal_data_no_movement[pre_valid_start:pre_valid_end]) + # right_mav = clac_mav(signal_data_no_movement[next_valid_start:next_valid_end]) + # left_energy = clac_energy(signal_data_no_movement[pre_valid_start:pre_valid_end]) + # right_energy = clac_energy(signal_data_no_movement[next_valid_start:next_valid_end]) + + left_mav = calc_mav_by_quantiles(signal_data_no_movement[pre_valid_start:pre_valid_end]) + right_mav = calc_mav_by_quantiles(signal_data_no_movement[next_valid_start:next_valid_end]) + # 计算幅值指标的变化率 amplitude_change = abs(right_mav - left_mav) / max(left_mav, 1e-6) - # 计算能量指标的变化率 - energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) + # # 计算能量指标的变化率 + # energy_change = abs(right_energy - left_energy) / max(left_energy, 1e-6) - significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + # significant_change = (amplitude_change > threshold_amplitude) and (energy_change > threshold_energy) + significant_change = (amplitude_change > threshold_amplitude) if significant_change: - print( - f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # print( + # f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + if verbose: + print(f"Significant position change detected between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}") # 记录姿势变化发生的时间点 用当前分割的体动的起始位置和结束位置表示 position_changes[movement_start:movement_end] = 1 position_change_list.append(movement_list[i]) @@ -624,8 +674,11 @@ def position_based_sleep_recognition_v3(signal_data, movement_mask, movement_lis pre_valid_end = next_valid_end else: - print( - f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + # print( + # f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}, left_energy={left_energy:.2f}, right_energy={right_energy:.2f}, energy_change={energy_change:.2f}") + if verbose: + print(f"No significant position change between segments {movement_start} and {movement_end}: left:{pre_valid_start}to{pre_valid_end} left_mav={left_mav:.2f}, right_mav={right_mav:.2f}, amplitude_change={amplitude_change:.2f}") + # 仅更新前片段 pre_valid_start = pre_valid_start pre_valid_end = next_valid_end diff --git a/utils/__init__.py b/utils/__init__.py index c89b90c..68e7772 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -2,5 +2,6 @@ from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list from .operation_tools import merge_short_gaps, remove_short_durations from .operation_tools import collect_values +from .operation_tools import save_process_label from .event_map import E2N from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel \ No newline at end of file diff --git a/utils/operation_tools.py b/utils/operation_tools.py index 5739097..23205be 100644 --- a/utils/operation_tools.py +++ b/utils/operation_tools.py @@ -237,4 +237,9 @@ def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None values.append(arr[index]) count += 1 index += step - return values \ No newline at end of file + return values + + +def save_process_label(save_path: Path, save_dict: dict): + save_df = pd.DataFrame(save_dict) + save_df.to_csv(save_path, index=False)