# DataPrepare/HYS_process.py

"""
本脚本完成对呼研所数据的处理,包含以下功能:
1. 数据读取与预处理
从传入路径中,进行数据和标签的读取,并进行初步的预处理
预处理包括为数据进行滤波、去噪等操作
2. 数据清洗与异常值处理
3. 输出清晰后的统计信息
4. 数据保存
将处理后的数据保存到指定路径,便于后续使用
主要是保存切分后的数据位置和标签
5. 可视化
提供数据处理前后的可视化对比,帮助理解数据变化
绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等
todo: 使用mask 屏蔽无用区间
# 低幅值区间规则标定与剔除
# 高幅值连续体动规则标定与剔除
# 手动标定不可用区间提剔除
"""
from pathlib import Path
import draw_tools
import utils
import numpy as np
import signal_method
import os
from matplotlib import pyplot as plt
# Point X11 clients at a fixed display for every run of this module.
# NOTE(review): "localhost:11.0" looks like an SSH X-forwarding endpoint tied
# to one session/machine, and this unconditionally overrides any DISPLAY the
# caller already has set - confirm this is intended before running elsewhere.
os.environ['DISPLAY'] = "localhost:11.0"
def process_one_signal(samp_id):
    """Load, filter, mask and plot the recording of one sample.

    Reads the ``OrgBCG_Sync_*.txt`` signal and the ``SA Label_corrected.csv``
    label file of ``samp_id``, derives a respiration channel and a BCG channel
    using the filter and downsampling settings in ``conf``, computes the
    event, manual-disable, low-amplitude, movement and amplitude-change masks,
    and hands everything to ``draw_tools.draw_signal_with_mask``.

    The function reads the module-level names ``conf``,
    ``org_signal_root_path``, ``label_root_path`` and ``all_samp_disable_df``;
    they must be assigned before it is called.

    Args:
        samp_id: Sample identifier. It names the sub-directory searched under
            both root paths and selects the rows of ``all_samp_disable_df``
            whose ``"id"`` column equals it.

    Raises:
        FileNotFoundError: If no signal file or no label file is found for
            ``samp_id``.
    """
    # ------------------------------------------------------------------
    # Locate input files
    # ------------------------------------------------------------------
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    # Path.glob() returns a generator, which is always truthy; materialize it
    # so that the emptiness check below can actually detect a missing file.
    label_path = list((label_root_path / f"{samp_id}").glob("SA Label_corrected.csv"))
    if not label_path:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = label_path[0]
    print(f"Processing Label_corrected file: {label_path}")

    # ------------------------------------------------------------------
    # Read the raw signal
    # ------------------------------------------------------------------
    signal_data = utils.read_signal_txt(signal_path)
    signal_length = len(signal_data)
    print(f"signal_length: {signal_length}")
    # The sampling rate is the last "_"-separated token of the file stem.
    signal_fs = int(signal_path.stem.split("_")[-1])
    print(f"signal_fs: {signal_fs}")
    signal_second = signal_length // signal_fs
    print(f"signal_second: {signal_second}")

    # Truncate to a whole number of seconds.
    signal_data = signal_data[:signal_second * signal_fs]

    # ------------------------------------------------------------------
    # Filtering
    # ------------------------------------------------------------------
    # 50 Hz notch filter.
    # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
    print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)

    # Respiration channel: low-pass -> first downsample -> average filter ->
    # configured Butterworth filter.
    # NOTE(review): the low-pass before downsampling is presumably for
    # anti-aliasing, and the 20 s average filter presumably removes baseline
    # drift - confirm against the utils implementations.
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)
    print("Begin plotting signal data...")

    # Optional debug figure (disabled): raw_data, resp_data_1 and resp_data_2.
    # fig = plt.figure(figsize=(12, 8))
    # ax0 = fig.add_subplot(3, 1, 1)
    # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
    # ax0.set_title('Raw Signal Data')
    # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
    # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange')
    # ax1.set_title('Resp Data after Average Filtering')
    # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
    # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green')
    # ax2.set_title('Resp Data after Butterworth Filtering')
    # plt.tight_layout()
    # plt.show()

    # BCG channel: configured Butterworth filter on the notch-filtered signal.
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)

    # ------------------------------------------------------------------
    # Downsampling
    # ------------------------------------------------------------------
    old_resp_fs = resp_fs
    resp_fs = conf["resp"]["downsample_fs_2"]
    resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs)
    bcg_fs = conf["bcg"]["downsample_fs"]
    bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)

    # ------------------------------------------------------------------
    # Label-based and manually annotated masks
    # ------------------------------------------------------------------
    label_data = utils.read_label_csv(path=label_path)
    # score_mask is returned by the helper but not used in this function.
    event_mask, _ = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
    manual_disable_mask = utils.generate_disable_mask(
        signal_second=signal_second,
        disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # ------------------------------------------------------------------
    # Respiration: rule-based masks (each step is skipped when its config
    # section is absent)
    # ------------------------------------------------------------------
    # Low-amplitude intervals.
    resp_low_amp_conf = conf.get("resp_low_amp", None)
    if resp_low_amp_conf is not None:
        resp_low_amp_mask, _ = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_low_amp_conf
        )
        print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}")
    else:
        resp_low_amp_mask = None
        print("resp_low_amp_mask is None")

    # High-amplitude artifact (movement) intervals.
    resp_movement_conf = conf.get("resp_movement", None)
    if resp_movement_conf is not None:
        _, resp_movement_mask, _, _ = signal_method.detect_movement(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}")
    else:
        resp_movement_mask = None
        print("resp_movement_mask is None")

    # Abrupt amplitude-change intervals; requires the movement mask.
    if resp_movement_mask is not None:
        resp_amp_change_mask, _ = signal_method.position_based_sleep_recognition_v2(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            sampling_rate=resp_fs)
        print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}")
    else:
        resp_amp_change_mask = None
        print("amp_change_mask is None")

    # ------------------------------------------------------------------
    # BCG: rule-based masks (same three steps as for respiration)
    # ------------------------------------------------------------------
    # Low-amplitude intervals.
    bcg_low_amp_conf = conf.get("bcg_low_amp", None)
    if bcg_low_amp_conf is not None:
        bcg_low_amp_mask, _ = signal_method.detect_low_amplitude_signal(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_low_amp_conf
        )
        print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}")
    else:
        bcg_low_amp_mask = None
        print("bcg_low_amp_mask is None")

    # High-amplitude artifact (movement) intervals.
    bcg_movement_conf = conf.get("bcg_movement", None)
    if bcg_movement_conf is not None:
        _, bcg_movement_mask, _, _ = signal_method.detect_movement(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_movement_conf
        )
        print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}")
    else:
        bcg_movement_mask = None
        print("bcg_movement_mask is None")

    # Abrupt amplitude-change intervals; requires the movement mask.
    if bcg_movement_mask is not None:
        bcg_amp_change_mask, _ = signal_method.position_based_sleep_recognition_v2(
            signal_data=bcg_data,
            movement_mask=bcg_movement_mask,
            sampling_rate=bcg_fs)
        print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}")
    else:
        bcg_amp_change_mask = None
        print("bcg_amp_change_mask is None")

    # ------------------------------------------------------------------
    # Plot
    # ------------------------------------------------------------------
    # A 1000 Hz raw signal is downsampled to 100 Hz before being plotted.
    if signal_fs == 1000:
        signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
        signal_fs = 100

    draw_tools.draw_signal_with_mask(samp_id=samp_id,
                                     signal_data=signal_data,
                                     signal_fs=signal_fs,
                                     resp_data=resp_data,
                                     resp_fs=resp_fs,
                                     bcg_data=bcg_data,
                                     bcg_fs=bcg_fs,
                                     signal_disable_mask=manual_disable_mask,
                                     resp_low_amp_mask=resp_low_amp_mask,
                                     resp_movement_mask=resp_movement_mask,
                                     resp_change_mask=resp_amp_change_mask,
                                     resp_sa_mask=event_mask,
                                     bcg_low_amp_mask=bcg_low_amp_mask,
                                     bcg_movement_mask=bcg_movement_mask,
                                     bcg_change_mask=bcg_amp_change_mask)
if __name__ == '__main__':
    # Input locations, resolved relative to the current working directory.
    yaml_path = Path("./dataset_config/HYS_config.yaml")
    disable_df_path = Path("./排除区间.xlsx")

    # Dataset configuration; process_one_signal() reads ``conf`` as a
    # module-level name.
    conf = utils.load_dataset_conf(yaml_path)
    select_ids, root_path = conf["select_ids"], Path(conf["root_path"])
    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")

    # Root directories and the manual exclusion table, likewise read by
    # process_one_signal() as module-level names.
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"
    all_samp_disable_df = utils.read_disable_excel(disable_df_path)

    # Only the first selected sample is processed.
    first_id = select_ids[0]
    process_one_signal(first_id)