284 lines
13 KiB
Python
284 lines
13 KiB
Python
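"""
Processing script for the HYS (呼研所) dataset. It provides the following:

1. Data loading and preprocessing
   Read the signal and its labels from the configured paths and apply the
   initial preprocessing (filtering, denoising, etc.).
2. Data cleaning and outlier handling
3. Printing statistics of the cleaned data
4. Saving
   Save the processed data to the configured path for later use; this is
   mainly the positions of the segmented data and the labels.
5. Visualization
   Provide a before/after comparison of the processing to help understand
   how the data changed, and draw several availability trend plots showing
   the usable intervals, body-movement intervals, low-amplitude intervals, etc.

todo: use masks to screen out unusable intervals

# Rule-based marking and removal of low-amplitude intervals
# Rule-based marking and removal of high-amplitude continuous body movement
# Removal of manually marked unusable intervals
"""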
|
||
|
||
from pathlib import Path
|
||
import shutil
|
||
import draw_tools
|
||
import utils
|
||
import numpy as np
|
||
import signal_method
|
||
import os
|
||
from matplotlib import pyplot as plt
|
||
# Force the DISPLAY environment variable for this process so GUI windows
# (e.g. the matplotlib figures shown when show=True) are sent to display
# "localhost:10.0".
# NOTE(review): presumably an SSH X11-forwarded display; the value is
# machine/session specific — confirm before running elsewhere.
os.environ['DISPLAY'] = "localhost:10.0"
|
||
|
||
def process_one_signal(samp_id, show=False):
    """Preprocess one sample: filter, detect artifact masks, plot and save labels.

    The function reads the ``OrgBCG_Sync_*.txt`` signal and the
    ``SA Label_corrected.csv`` label file of ``samp_id``, derives a respiration
    and a BCG channel from the notch-filtered signal, builds the per-sample
    masks (SA events, manually disabled intervals, low amplitude, movement,
    amplitude change), draws them with ``draw_tools.draw_signal_with_mask`` and
    writes ``{samp_id}_Processed_Labels.csv`` plus a copy of the label file
    into ``save_path / samp_id``.

    It relies on module-level names that are assigned in the ``__main__``
    section of this script: ``conf``, ``org_signal_root_path``,
    ``label_root_path``, ``save_path`` and ``all_samp_disable_df``.

    Args:
        samp_id: Sample identifier; used as the sub-directory name under the
            signal, label and save roots and to select rows of
            ``all_samp_disable_df`` (column ``"id"``).
        show: Forwarded unchanged to ``draw_tools.draw_signal_with_mask``.

    Raises:
        FileNotFoundError: If no signal file or no label file is found for
            ``samp_id``.
    """
    signal_candidates = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_candidates:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_candidates[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    # Path.glob returns a generator, which is always truthy; it must be
    # materialised before the emptiness check, otherwise a missing label file
    # surfaces as an IndexError instead of the intended FileNotFoundError.
    label_candidates = list((label_root_path / f"{samp_id}").glob("SA Label_corrected.csv"))
    if not label_candidates:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = label_candidates[0]
    print(f"Processing Label_corrected file: {label_path}")

    # Output directory for the processed data and labels of this sample.
    save_samp_path = save_path / f"{samp_id}"
    save_samp_path.mkdir(parents=True, exist_ok=True)

    signal_data_raw = utils.read_signal_txt(signal_path)
    signal_length = len(signal_data_raw)
    print(f"signal_length: {signal_length}")
    # The sampling rate is the last "_"-separated token of the file stem.
    signal_fs = int(signal_path.stem.split("_")[-1])
    print(f"signal_fs: {signal_fs}")
    signal_second = signal_length // signal_fs
    print(f"signal_second: {signal_second}")

    # Truncate to a whole number of seconds.
    signal_data_raw = signal_data_raw[:signal_second * signal_fs]

    # ---- Filtering -------------------------------------------------------
    # 50 Hz notch filter.
    print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0,
                                     sample_rate=signal_fs)

    # Respiration channel: low-pass -> first downsample -> moving average -> configured filter.
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10,
                                    sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs,
                                               target_fs=resp_fs)
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"],
                                    order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)
    print("Begin plotting signal data...")

    # BCG channel: configured filter applied at the original sampling rate.
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"],
                                 order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)

    # ---- Downsampling ----------------------------------------------------
    old_resp_fs = resp_fs
    resp_fs = conf["resp"]["downsample_fs_2"]
    resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs,
                                             target_fs=resp_fs)
    bcg_fs = conf["bcg"]["downsample_fs"]
    bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs,
                                            target_fs=bcg_fs)

    # ---- Label-derived masks ---------------------------------------------
    label_data = utils.read_label_csv(path=label_path)
    event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)

    manual_disable_mask = utils.generate_disable_mask(
        signal_second=signal_second,
        disable_df=all_samp_disable_df[all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # ---- Resp: low-amplitude intervals -----------------------------------
    resp_low_amp_conf = conf.get("resp_low_amp", None)
    if resp_low_amp_conf is not None:
        resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_low_amp_conf
        )
        print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}")
    else:
        resp_low_amp_mask, resp_low_amp_position_list = None, None
        print("resp_low_amp_mask is None")

    # ---- Resp: high-amplitude artifact (movement) intervals --------------
    resp_movement_conf = conf.get("resp_movement", None)
    if resp_movement_conf is not None:
        # The raw (first and third) results are not used by this function.
        _, resp_movement_mask, _, resp_movement_position_list = signal_method.detect_movement(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        resp_movement_mask, resp_movement_position_list = None, None
        print("resp_movement_mask is None")

    resp_movement_revise_conf = conf.get("resp_movement_revise", None)
    if resp_movement_mask is not None and resp_movement_revise_conf is not None:
        resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_movement_revise_conf,
            verbose=False
        )
        print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        print("resp_movement_mask revise is skipped")

    # ---- Resp: abrupt amplitude-change intervals -------------------------
    resp_amp_change_conf = conf.get("resp_amp_change", None)
    if resp_amp_change_conf is not None and resp_movement_mask is not None:
        resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_amp_change_conf,
            verbose=True)
        print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
    else:
        resp_amp_change_mask = None
        print("amp_change_mask is None")

    # ---- BCG: low-amplitude intervals ------------------------------------
    bcg_low_amp_conf = conf.get("bcg_low_amp", None)
    if bcg_low_amp_conf is not None:
        bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_low_amp_conf
        )
        print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}")
    else:
        bcg_low_amp_mask, bcg_low_amp_position_list = None, None
        print("bcg_low_amp_mask is None")

    # ---- BCG: high-amplitude artifact (movement) intervals ---------------
    bcg_movement_conf = conf.get("bcg_movement", None)
    if bcg_movement_conf is not None:
        # The raw (first and third) results are not used by this function.
        _, bcg_movement_mask, _, bcg_movement_position_list = signal_method.detect_movement(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_movement_conf
        )
        print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}")
    else:
        bcg_movement_mask = None
        print("bcg_movement_mask is None")

    # ---- BCG: abrupt amplitude-change intervals --------------------------
    if bcg_movement_mask is not None:
        bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
            signal_data=bcg_data,
            movement_mask=bcg_movement_mask,
            sampling_rate=bcg_fs)
        print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}")
    else:
        bcg_amp_change_mask = None
        print("bcg_amp_change_mask is None")

    # If the full-band signal is sampled at 1000 Hz, bring it down to 100 Hz
    # before plotting. The raw copy is downsampled alongside it so both stay
    # at the same rate.
    if signal_fs == 1000:
        signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs,
                                                   target_fs=100)
        signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw,
                                                       original_fs=signal_fs, target_fs=100)
        signal_fs = 100

    draw_tools.draw_signal_with_mask(samp_id=samp_id,
                                     signal_data=signal_data,
                                     signal_fs=signal_fs,
                                     resp_data=resp_data,
                                     resp_fs=resp_fs,
                                     bcg_data=bcg_data,
                                     bcg_fs=bcg_fs,
                                     signal_disable_mask=manual_disable_mask,
                                     resp_low_amp_mask=resp_low_amp_mask,
                                     resp_movement_mask=resp_movement_mask,
                                     resp_change_mask=resp_amp_change_mask,
                                     resp_sa_mask=event_mask,
                                     bcg_low_amp_mask=bcg_low_amp_mask,
                                     bcg_movement_mask=bcg_movement_mask,
                                     bcg_change_mask=bcg_amp_change_mask,
                                     show=show,
                                     save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")

    # Copy the event (label) file next to the processed output.
    # NOTE(review): the name is samp_id directly followed by the original file
    # name, with no separator — confirm this is the intended naming.
    sa_label_save_name = f"{samp_id}" + label_path.name
    shutil.copyfile(label_path, save_samp_path / sa_label_save_name)

    def _mask_or_zeros(mask):
        # A detector that was skipped (mask is None) is saved as an all-zero column.
        return mask if mask is not None else np.zeros(signal_second, dtype=int)

    # Table columns: second index, SA label, SA score, manual disable label,
    # and the low-amplitude / movement / amplitude-change labels of Resp and BCG.
    save_dict = {
        "Second": np.arange(signal_second),
        "SA_Label": event_mask,
        "SA_Score": score_mask,
        "Disable_Label": manual_disable_mask,
        "Resp_LowAmp_Label": _mask_or_zeros(resp_low_amp_mask),
        "Resp_Movement_Label": _mask_or_zeros(resp_movement_mask),
        "Resp_AmpChange_Label": _mask_or_zeros(resp_amp_change_mask),
        "Bcg_LowAmp_Label": _mask_or_zeros(bcg_low_amp_mask),
        "Bcg_Movement_Label": _mask_or_zeros(bcg_movement_mask),
        "Bcg_AmpChange_Label": _mask_or_zeros(bcg_amp_change_mask),
    }

    mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
    utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point. The names assigned here (conf, save_path,
    # org_signal_root_path, label_root_path, all_samp_disable_df) are module
    # globals that process_one_signal reads, so they must keep these names.
    yaml_path = Path("./dataset_config/HYS_config.yaml")
    disable_df_path = Path("./排除区间.xlsx")

    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    save_path = Path(conf["save_path"])

    # Echo the effective configuration.
    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")

    # Input locations below the dataset root.
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"

    # Manually marked unusable intervals for all samples.
    all_samp_disable_df = utils.read_disable_excel(disable_df_path)

    # Process a single sample (index 9 of the configured ids) interactively.
    process_one_signal(select_ids[9], show=True)

    # Batch mode over every configured sample:
    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     process_one_signal(samp_id, show=False)
    #     print(f"Finished processing sample ID: {samp_id}\n\n")
|
||
|