DataPrepare/event_mask_process/HYS_process.py

247 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

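"""
Process the HYS (呼研所, respiratory research institute) recordings.

Functions provided by this script:
1. Data loading and preprocessing
   Read the signal and the labels from the given paths and apply initial
   preprocessing (filtering, denoising, etc.).
2. Data cleaning and outlier handling
3. Reporting statistics after cleaning
4. Saving
   Save the processed data to the configured path for later use; mainly the
   positions and labels of the segmented data.
5. Visualisation
   Provide before/after plots to help understand how the data changed.
   Draw several availability traces showing the usable intervals,
   body-movement intervals, low-amplitude intervals, etc.

# Rule-based marking and removal of low-amplitude intervals
# Rule-based marking and removal of high-amplitude continuous body movement
# Removal of manually marked unusable intervals
"""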
import sys
from pathlib import Path
# Append the parent of this file's directory to sys.path before the project
# imports below (presumably so draw_tools / utils / signal_method resolve
# from there — confirm against the repository layout).
sys.path.append(str(Path(__file__).resolve().parent.parent))
# Same directory as a Path object; the __main__ section joins config and
# spreadsheet file names onto it.
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
# NOTE(review): unconditionally overwrites any DISPLAY value already present
# in the environment at import time. Looks like it targets an SSH X11
# forwarding session for plotting — confirm this is still wanted.
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id, show=False):
    """Build, plot and save the event/quality masks for one recording.

    The function reads these module-level names, which the ``__main__``
    section of this file assigns before calling it: ``org_signal_root_path``,
    ``label_root_path``, ``save_path``, ``conf`` and ``all_samp_disable_df``.

    Steps, in order:
      1. Locate ``OrgBCG_Sync_*.txt`` and ``SA Label_corrected.csv`` in the
         per-sample sub-directories.
      2. Read the signal, split it into respiration and BCG components with
         ``signal_method.signal_filter_split`` and downsample both.
      3. Build the SA event/score masks and the manual disable mask.
      4. For respiration and BCG, run the low-amplitude, movement and
         amplitude-change detectors whose configuration sections exist.
      5. Draw the overview figure, copy the label file and write
         ``<samp_id>_Processed_Labels.csv`` into ``save_path / <samp_id>``.

    Args:
        samp_id: Sample identifier. Used as the sub-directory name under the
            signal/label/save roots and matched against the ``"id"`` column
            of ``all_samp_disable_df``.
        show: Forwarded to ``draw_tools.draw_signal_with_mask``.

    Returns:
        None. Results are written to disk.

    Raises:
        FileNotFoundError: If the signal file or the corrected label file for
            ``samp_id`` cannot be found.
    """
    signal_candidates = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_candidates:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_candidates[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    # Path.glob() returns a lazy generator, which is always truthy; it has to
    # be materialised first, otherwise a missing label file would surface as
    # an IndexError instead of the intended FileNotFoundError.
    label_candidates = list((label_root_path / f"{samp_id}").glob("SA Label_corrected.csv"))
    if not label_candidates:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = label_candidates[0]
    print(f"Processing Label_corrected file: {label_path}")

    # Output directory for the processed data and labels of this sample.
    save_samp_path = save_path / f"{samp_id}"
    save_samp_path.mkdir(parents=True, exist_ok=True)

    signal_data_raw, _, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True)

    signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(
        conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)

    # Downsample both components to the rates given in the configuration.
    old_resp_fs = resp_fs
    resp_fs = conf["resp"]["downsample_fs_2"]
    resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs)

    # NOTE(review): the BCG branch passes signal_fs as the source rate, not
    # the bcg_fs returned by signal_filter_split (which is overwritten here).
    # Kept as in the original; confirm the splitter leaves BCG at signal_fs.
    bcg_fs = conf["bcg"]["downsample_fs"]
    bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)

    label_data = utils.read_label_csv(path=label_path)
    event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)

    manual_disable_mask = utils.generate_disable_mask(
        signal_second=signal_second,
        disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # Respiration: low-amplitude intervals.
    resp_low_amp_conf = conf.get("resp_low_amp", None)
    if resp_low_amp_conf is not None:
        resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_low_amp_conf
        )
        print(
            f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, "
            f"num_low_amp: {np.sum(resp_low_amp_mask == 1)}, "
            f"count_low_amp_positions: {len(resp_low_amp_position_list)}")
    else:
        resp_low_amp_mask, resp_low_amp_position_list = None, None
        print("resp_low_amp_mask is None")

    # Respiration: high-amplitude artefact (movement) intervals.
    resp_movement_conf = conf.get("resp_movement", None)
    if resp_movement_conf is not None:
        # Only the second and fourth return values are used further on.
        _, resp_movement_mask, _, resp_movement_position_list = signal_method.detect_movement(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(
            f"resp_movement_mask_shape: {resp_movement_mask.shape}, "
            f"num_movement: {np.sum(resp_movement_mask == 1)}, "
            f"count_movement_positions: {len(resp_movement_position_list)}")
    else:
        resp_movement_mask, resp_movement_position_list = None, None
        print("resp_movement_mask is None")

    # Optional revision pass over the respiration movement mask.
    resp_movement_revise_conf = conf.get("resp_movement_revise", None)
    if resp_movement_mask is not None and resp_movement_revise_conf is not None:
        resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_movement_revise_conf,
            verbose=False
        )
        print(
            f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, "
            f"num_movement: {np.sum(resp_movement_mask == 1)}, "
            f"count_movement_positions: {len(resp_movement_position_list)}")
    else:
        print("resp_movement_mask revise is skipped")

    # Respiration: abrupt amplitude-change intervals (needs the movement mask).
    resp_amp_change_conf = conf.get("resp_amp_change", None)
    if resp_amp_change_conf is not None and resp_movement_mask is not None:
        resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_amp_change_conf,
            verbose=False)
        print(
            f"amp_change_mask_shape: {resp_amp_change_mask.shape}, "
            f"num_amp_change: {np.sum(resp_amp_change_mask == 1)}, "
            f"count_amp_change_positions: {len(resp_amp_change_list)}")
    else:
        resp_amp_change_mask = None
        print("amp_change_mask is None")

    # BCG: low-amplitude intervals.
    bcg_low_amp_conf = conf.get("bcg_low_amp", None)
    if bcg_low_amp_conf is not None:
        bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_low_amp_conf
        )
        print(
            f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, "
            f"num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, "
            f"count_low_amp_positions: {len(bcg_low_amp_position_list)}")
    else:
        bcg_low_amp_mask, bcg_low_amp_position_list = None, None
        print("bcg_low_amp_mask is None")

    # BCG: high-amplitude artefact (movement) intervals.
    bcg_movement_conf = conf.get("bcg_movement", None)
    if bcg_movement_conf is not None:
        _, bcg_movement_mask, _, bcg_movement_position_list = signal_method.detect_movement(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_movement_conf
        )
        print(
            f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, "
            f"num_movement: {np.sum(bcg_movement_mask == 1)}, "
            f"count_movement_positions: {len(bcg_movement_position_list)}")
    else:
        bcg_movement_mask = None
        print("bcg_movement_mask is None")

    # BCG: abrupt amplitude-change intervals (needs the movement mask).
    if bcg_movement_mask is not None:
        bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
            signal_data=bcg_data,
            movement_mask=bcg_movement_mask,
            sampling_rate=bcg_fs)
        print(
            f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, "
            f"num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, "
            f"count_amp_change_positions: {len(bcg_amp_change_list)}")
    else:
        bcg_amp_change_mask = None
        print("bcg_amp_change_mask is None")

    # Downsample the full-band signal when it was recorded at 1000 Hz.
    if signal_fs == 1000:
        signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
        # Not read again below; kept so the raw copy stays in step with signal_data.
        signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs,
                                                       target_fs=100)
        signal_fs = 100

    draw_tools.draw_signal_with_mask(samp_id=samp_id,
                                     signal_data=signal_data,
                                     signal_fs=signal_fs,
                                     resp_data=resp_data,
                                     resp_fs=resp_fs,
                                     bcg_data=bcg_data,
                                     bcg_fs=bcg_fs,
                                     signal_disable_mask=manual_disable_mask,
                                     resp_low_amp_mask=resp_low_amp_mask,
                                     resp_movement_mask=resp_movement_mask,
                                     resp_change_mask=resp_amp_change_mask,
                                     resp_sa_mask=event_mask,
                                     bcg_low_amp_mask=bcg_low_amp_mask,
                                     bcg_movement_mask=bcg_movement_mask,
                                     bcg_change_mask=bcg_amp_change_mask,
                                     show=show,
                                     save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")

    # Copy the event file next to the generated outputs.
    sa_label_save_name = f"{samp_id}_" + label_path.name
    shutil.copyfile(label_path, save_samp_path / sa_label_save_name)

    def _or_zeros(mask):
        # A detector that was skipped contributes an all-zero column.
        return mask if mask is not None else np.zeros(signal_second, dtype=int)

    # One column per mask: second index, SA label, SA score, manual disable
    # label, then the three respiration masks and the three BCG masks.
    save_dict = {
        "Second": np.arange(signal_second),
        "SA_Label": event_mask,
        "SA_Score": score_mask,
        "Disable_Label": manual_disable_mask,
        "Resp_LowAmp_Label": _or_zeros(resp_low_amp_mask),
        "Resp_Movement_Label": _or_zeros(resp_movement_mask),
        "Resp_AmpChange_Label": _or_zeros(resp_amp_change_mask),
        "BCG_LowAmp_Label": _or_zeros(bcg_low_amp_mask),
        "BCG_Movement_Label": _or_zeros(bcg_movement_mask),
        "BCG_AmpChange_Label": _or_zeros(bcg_amp_change_mask),
    }

    mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
    utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
if __name__ == '__main__':
    # Script entry point: load the dataset configuration and the manual
    # exclusion spreadsheet, then process every selected sample in turn.
    #
    # conf, save_path, org_signal_root_path, label_root_path and
    # all_samp_disable_df are read by process_one_signal as module globals,
    # so those names must stay exactly as they are.
    config_file = project_root_path / "dataset_config" / "HYS_config.yaml"
    exclusion_file = project_root_path / "排除区间.xlsx"

    conf = utils.load_dataset_conf(config_file)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    save_path = Path(conf["mask_save_path"])

    # Input roots derived from the dataset root.
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"

    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")

    all_samp_disable_df = utils.read_disable_excel(exclusion_file)

    # To inspect a single recording interactively, call for example:
    #     process_one_signal(select_ids[6], show=True)
    for samp_id in select_ids:
        print(f"Processing sample ID: {samp_id}")
        process_one_signal(samp_id, show=False)
        print(f"Finished processing sample ID: {samp_id}\n\n")