250 lines
12 KiB
Python
250 lines
12 KiB
Python
"""
|
||
本脚本完成对呼研所数据的处理,包含以下功能:
|
||
1. 数据读取与预处理
|
||
从传入路径中,进行数据和标签的读取,并进行初步的预处理
|
||
预处理包括为数据进行滤波、去噪等操作
|
||
2. 数据清洗与异常值处理
|
||
3. 输出清洗后的统计信息
|
||
4. 数据保存
|
||
将处理后的数据保存到指定路径,便于后续使用
|
||
主要是保存切分后的数据位置和标签
|
||
5. 可视化
|
||
提供数据处理前后的可视化对比,帮助理解数据变化
|
||
绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等
|
||
|
||
todo: 使用mask 屏蔽无用区间
|
||
|
||
|
||
# 低幅值区间规则标定与剔除
|
||
# 高幅值连续体动规则标定与剔除
|
||
# 手动标定不可用区间剔除
|
||
"""
|
||
|
||
from pathlib import Path
|
||
|
||
import draw_tools
|
||
import utils
|
||
import numpy as np
|
||
import signal_method
|
||
import os
|
||
from matplotlib import pyplot as plt
|
||
# Set the DISPLAY environment variable for this process at import time.
# NOTE(review): "localhost:10.0" looks like an SSH X11-forwarded display — confirm.
os.environ["DISPLAY"] = "localhost:10.0"
|
||
|
||
def process_one_signal(samp_id):
    """Preprocess one sample's raw signal, detect unusable intervals and plot them.

    The function reads these module-level names, which the ``__main__`` block
    assigns before calling it: ``org_signal_root_path``, ``label_root_path``,
    ``conf`` and ``all_samp_disable_df``.

    Steps, in order:

    1. Locate ``OrgBCG_Sync_*.txt`` under ``org_signal_root_path / samp_id``
       and ``SA Label_corrected.csv`` under ``label_root_path / samp_id``
       (the first glob match is used in both cases).
    2. Read the signal; the sampling rate is parsed from the last
       ``_``-separated token of the signal file stem. The signal is truncated
       to a whole number of seconds.
    3. Apply a 50 Hz notch filter, then derive a respiration channel
       (low-pass, downsample, average filter, configured Butterworth filter,
       second downsample) and a BCG channel (configured Butterworth filter,
       downsample).
    4. Build the event mask from the label file and the manual disable mask
       from the rows of ``all_samp_disable_df`` whose ``id`` equals
       ``samp_id``.
    5. For each channel, run low-amplitude, movement and amplitude-change
       detection when the matching section exists in ``conf``; a missing
       section yields ``None`` for the corresponding mask.
    6. Hand everything to ``draw_tools.draw_signal_with_mask``.

    Args:
        samp_id: Sample identifier; used as a sub-directory name and to select
            rows of ``all_samp_disable_df``.

    Returns:
        None. The result is the figure produced by ``draw_tools``.

    Raises:
        FileNotFoundError: If the signal file or the label file is missing.
    """
    signal_candidates = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_candidates:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_candidates[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    # Path.glob() returns a generator, which is always truthy. Materialise it
    # first so that the emptiness check below can actually trigger.
    label_candidates = list((label_root_path / f"{samp_id}").glob("SA Label_corrected.csv"))
    if not label_candidates:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = label_candidates[0]
    print(f"Processing Label_corrected file: {label_path}")

    signal_data_raw = utils.read_signal_txt(signal_path)
    signal_length = len(signal_data_raw)
    print(f"signal_length: {signal_length}")
    # The sampling rate is the trailing "_<fs>" token of the file name.
    signal_fs = int(signal_path.stem.split("_")[-1])
    print(f"signal_fs: {signal_fs}")
    signal_second = signal_length // signal_fs
    print(f"signal_second: {signal_second}")

    # Truncate to a whole number of seconds.
    signal_data_raw = signal_data_raw[:signal_second * signal_fs]

    # Filtering: 50 Hz notch filter.
    # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
    print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0,
                                     sample_rate=signal_fs)

    # Respiration channel: low-pass -> downsample -> average filter -> configured filter.
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10,
                                    sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs,
                                               target_fs=resp_fs)
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"],
                                    order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)
    print("Begin plotting signal data...")

    # Disabled debug plot of the three stages (raw, averaged, Butterworth-filtered):
    # fig = plt.figure(figsize=(12, 8))
    # ax0 = fig.add_subplot(3, 1, 1)
    # ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
    # ax0.set_title('Raw Signal Data')
    # ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
    # ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange')
    # ax1.set_title('Resp Data after Average Filtering')
    # ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
    # ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green')
    # ax2.set_title('Resp Data after Butterworth Filtering')
    # plt.tight_layout()
    # plt.show()

    # BCG channel: configured filter on the notch-filtered signal.
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"],
                                 order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)

    # Downsample both channels to their working rates.
    old_resp_fs = resp_fs
    resp_fs = conf["resp"]["downsample_fs_2"]
    resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs,
                                             target_fs=resp_fs)
    bcg_fs = conf["bcg"]["downsample_fs"]
    bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs,
                                            target_fs=bcg_fs)

    label_data = utils.read_label_csv(path=label_path)
    # The second returned mask (score mask) is not used by this function.
    event_mask, _ = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)

    manual_disable_mask = utils.generate_disable_mask(
        signal_second=signal_second,
        disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # Resp: low-amplitude intervals.
    resp_low_amp_conf = conf.get("resp_low_amp", None)
    if resp_low_amp_conf is not None:
        resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_low_amp_conf
        )
        print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}")
    else:
        resp_low_amp_mask, resp_low_amp_position_list = None, None
        print("resp_low_amp_mask is None")

    # Resp: high-amplitude artifact (movement) intervals.
    resp_movement_conf = conf.get("resp_movement", None)
    if resp_movement_conf is not None:
        _, resp_movement_mask, _, resp_movement_position_list = signal_method.detect_movement(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        resp_movement_mask = None
        print("resp_movement_mask is None")

    if resp_movement_mask is not None:
        # Run the same detector on the time-reversed respiration signal.
        reverse_resp_data = resp_data[::-1]
        _, resp_movement_mask_reverse, _, resp_movement_position_list_reverse = signal_method.detect_movement(
            signal_data=reverse_resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(f"resp_movement_mask_reverse_shape: {resp_movement_mask_reverse.shape}, num_movement_reverse: {np.sum(resp_movement_mask_reverse == 1)}, count_movement_positions_reverse: {len(resp_movement_position_list_reverse)}")
        # Flip the reversed mask back so it lines up with the forward mask.
        resp_movement_mask_reverse = resp_movement_mask_reverse[::-1]
    else:
        resp_movement_mask_reverse = None
        print("resp_movement_mask_reverse is None")

    # Intersection: keep only samples flagged in both the forward and reverse pass.
    if resp_movement_mask is not None and resp_movement_mask_reverse is not None:
        combined_resp_movement_mask = np.logical_and(resp_movement_mask, resp_movement_mask_reverse).astype(int)
        resp_movement_mask = combined_resp_movement_mask
        print(f"combined_resp_movement_mask_shape: {combined_resp_movement_mask.shape}, num_combined_movement: {np.sum(combined_resp_movement_mask == 1)}")
    else:
        print("combined_resp_movement_mask is None")

    # Resp: abrupt amplitude-change intervals (requires the movement mask).
    if resp_movement_mask is not None:
        resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v2(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            sampling_rate=resp_fs)
        print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
    else:
        resp_amp_change_mask = None
        print("amp_change_mask is None")

    # BCG: low-amplitude intervals.
    bcg_low_amp_conf = conf.get("bcg_low_amp", None)
    if bcg_low_amp_conf is not None:
        bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_low_amp_conf
        )
        print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}")
    else:
        bcg_low_amp_mask, bcg_low_amp_position_list = None, None
        print("bcg_low_amp_mask is None")

    # BCG: high-amplitude artifact (movement) intervals.
    bcg_movement_conf = conf.get("bcg_movement", None)
    if bcg_movement_conf is not None:
        _, bcg_movement_mask, _, bcg_movement_position_list = signal_method.detect_movement(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_movement_conf
        )
        print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}")
    else:
        bcg_movement_mask = None
        print("bcg_movement_mask is None")

    # BCG: abrupt amplitude-change intervals (requires the movement mask).
    if bcg_movement_mask is not None:
        bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
            signal_data=bcg_data,
            movement_mask=bcg_movement_mask,
            sampling_rate=bcg_fs)
        print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}")
    else:
        bcg_amp_change_mask = None
        print("bcg_amp_change_mask is None")

    # If the signal was recorded at 1000 Hz, downsample it to 100 Hz before plotting.
    if signal_fs == 1000:
        signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs,
                                                   target_fs=100)
        signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw,
                                                       original_fs=signal_fs, target_fs=100)
        signal_fs = 100

    draw_tools.draw_signal_with_mask(samp_id=samp_id,
                                     signal_data=signal_data,
                                     signal_fs=signal_fs,
                                     resp_data=resp_data,
                                     resp_fs=resp_fs,
                                     bcg_data=bcg_data,
                                     bcg_fs=bcg_fs,
                                     signal_disable_mask=manual_disable_mask,
                                     resp_low_amp_mask=resp_low_amp_mask,
                                     resp_movement_mask=resp_movement_mask,
                                     resp_change_mask=resp_amp_change_mask,
                                     resp_sa_mask=event_mask,
                                     bcg_low_amp_mask=bcg_low_amp_mask,
                                     bcg_movement_mask=bcg_movement_mask,
                                     bcg_change_mask=bcg_amp_change_mask)
|
||
|
||
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point. The names assigned here are module-level globals;
    # process_one_signal() reads conf, org_signal_root_path, label_root_path
    # and all_samp_disable_df directly.
    yaml_path = Path("./dataset_config/HYS_config.yaml")
    disable_df_path = Path("./排除区间.xlsx")

    # Dataset configuration: which sample ids to use and where the data lives.
    conf = utils.load_dataset_conf(yaml_path)
    select_ids, root_path = conf["select_ids"], Path(conf["root_path"])

    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")

    # Per-sample sub-directories are looked up under these two roots.
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"

    # Table of manually excluded intervals for all samples.
    all_samp_disable_df = utils.read_disable_excel(disable_df_path)

    # Only the sample at index 5 of select_ids is processed.
    process_one_signal(select_ids[5])
|