更新数据处理模块,添加信号标准化和绘图功能,重构部分函数以提高可读性
This commit is contained in:
parent
1a0761c6c8
commit
ed4205f5b8
2
.gitignore
vendored
2
.gitignore
vendored
@ -253,3 +253,5 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
output/*
|
||||
!output/
|
||||
|
||||
@ -0,0 +1,166 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
os.environ['DISPLAY'] = "localhost:10.0"
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
project_root_path = Path(__file__).resolve().parent.parent
|
||||
|
||||
import utils
|
||||
import signal_method
|
||||
import draw_tools
|
||||
import shutil
|
||||
|
||||
|
||||
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
||||
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
|
||||
if not signal_path:
|
||||
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
|
||||
signal_path = signal_path[0]
|
||||
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
|
||||
|
||||
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
|
||||
print(f"mask_excel_path: {mask_excel_path}")
|
||||
|
||||
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
|
||||
|
||||
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float)
|
||||
|
||||
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs)
|
||||
normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
|
||||
|
||||
|
||||
# 如果signal_data采样率过,进行降采样
|
||||
if signal_fs == 1000:
|
||||
bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
|
||||
bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs,
|
||||
target_fs=100)
|
||||
signal_fs = 100
|
||||
|
||||
if bcg_fs == 1000:
|
||||
bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
|
||||
bcg_fs = 100
|
||||
|
||||
# draw_tools.draw_signal_with_mask(samp_id=samp_id,
|
||||
# signal_data=resp_signal,
|
||||
# signal_fs=resp_fs,
|
||||
# resp_data=normalized_resp_signal,
|
||||
# resp_fs=resp_fs,
|
||||
# bcg_data=bcg_signal,
|
||||
# bcg_fs=bcg_fs,
|
||||
# signal_disable_mask=event_mask["Disable_Label"],
|
||||
# resp_low_amp_mask=event_mask["Resp_LowAmp_Label"],
|
||||
# resp_movement_mask=event_mask["Resp_Movement_Label"],
|
||||
# resp_change_mask=event_mask["Resp_AmpChange_Label"],
|
||||
# resp_sa_mask=event_mask["SA_Label"],
|
||||
# bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"],
|
||||
# bcg_movement_mask=event_mask["BCG_Movement_Label"],
|
||||
# bcg_change_mask=event_mask["BCG_AmpChange_Label"],
|
||||
# show=show,
|
||||
# save_path=None)
|
||||
|
||||
segment_list = utils.resp_split(dataset_config, event_mask, event_list)
|
||||
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
|
||||
|
||||
|
||||
# 复制mask到processed_Labels文件夹
|
||||
save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
|
||||
shutil.copyfile(mask_excel_path, save_mask_excel_path)
|
||||
|
||||
# 复制SA Label_corrected.csv到processed_Labels文件夹
|
||||
sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
|
||||
if sa_label_corrected_path.exists():
|
||||
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
|
||||
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
|
||||
else:
|
||||
print(f"Warning: {sa_label_corrected_path} does not exist.")
|
||||
|
||||
# 保存处理后的信号和截取的片段列表
|
||||
save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
|
||||
save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
|
||||
|
||||
bcg_data = {
|
||||
"bcg_signal_notch": {
|
||||
"name": "BCG_Signal_Notch",
|
||||
"data": bcg_signal_notch,
|
||||
"fs": signal_fs,
|
||||
"length": len(bcg_signal_notch),
|
||||
"second": len(bcg_signal_notch) // signal_fs
|
||||
},
|
||||
"bcg_signal":{
|
||||
"name": "BCG_Signal_Raw",
|
||||
"data": bcg_signal,
|
||||
"fs": bcg_fs,
|
||||
"length": len(bcg_signal),
|
||||
"second": len(bcg_signal) // bcg_fs
|
||||
},
|
||||
"resp_signal": {
|
||||
"name": "Resp_Signal",
|
||||
"data": normalized_resp_signal,
|
||||
"fs": resp_fs,
|
||||
"length": len(normalized_resp_signal),
|
||||
"second": len(normalized_resp_signal) // resp_fs
|
||||
}
|
||||
}
|
||||
|
||||
np.savez_compressed(save_signal_path, **bcg_data)
|
||||
np.savez_compressed(save_segment_path,
|
||||
segment_list=segment_list)
|
||||
print(f"Saved processed signals to: {save_signal_path}")
|
||||
print(f"Saved segment list to: {save_segment_path}")
|
||||
|
||||
if draw_segment:
|
||||
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
psg_data["HR"] = {
|
||||
"name": "HR",
|
||||
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
|
||||
"fs": psg_data["ECG_Sync"]["fs"],
|
||||
"length": psg_data["ECG_Sync"]["length"],
|
||||
"second": psg_data["ECG_Sync"]["second"]
|
||||
}
|
||||
|
||||
|
||||
psg_label = utils.read_psg_label(sa_label_corrected_path)
|
||||
psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
|
||||
draw_tools.draw_psg_bcg_label(psg_data=psg_data,
|
||||
psg_label=psg_event_mask,
|
||||
bcg_data=bcg_data,
|
||||
event_mask=event_mask,
|
||||
segment_list=segment_list)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
|
||||
|
||||
conf = utils.load_dataset_conf(yaml_path)
|
||||
select_ids = conf["select_ids"]
|
||||
root_path = Path(conf["root_path"])
|
||||
mask_path = Path(conf["mask_save_path"])
|
||||
save_path = Path(conf["dataset_config"]["dataset_save_path"])
|
||||
dataset_config = conf["dataset_config"]
|
||||
|
||||
save_processed_signal_path = save_path / "Signals"
|
||||
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_segment_list_path = save_path / "Segments_List"
|
||||
save_segment_list_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_processed_label_path = save_path / "Labels"
|
||||
save_processed_label_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"select_ids: {select_ids}")
|
||||
print(f"root_path: {root_path}")
|
||||
print(f"save_path: {save_path}")
|
||||
|
||||
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||
psg_signal_root_path = root_path / "PSG_Aligned"
|
||||
|
||||
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
||||
#
|
||||
# for samp_id in select_ids:
|
||||
# print(f"Processing sample ID: {samp_id}")
|
||||
# build_HYS_dataset_segment(samp_id, show=False)
|
||||
@ -11,7 +11,7 @@ select_ids:
|
||||
- 960
|
||||
|
||||
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
|
||||
save_path: /mnt/disk_code/marques/dataprepare/output/HYS
|
||||
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS
|
||||
|
||||
resp:
|
||||
downsample_fs_1: 100
|
||||
@ -43,11 +43,11 @@ resp_movement:
|
||||
min_duration_sec: 1
|
||||
|
||||
resp_movement_revise:
|
||||
up_interval_multiplier: 3
|
||||
down_interval_multiplier: 2
|
||||
compare_intervals_sec: 30
|
||||
merge_gap_sec: 10
|
||||
min_duration_sec: 1
|
||||
up_interval_multiplier: 3
|
||||
down_interval_multiplier: 2
|
||||
compare_intervals_sec: 30
|
||||
merge_gap_sec: 10
|
||||
min_duration_sec: 1
|
||||
|
||||
resp_amp_change:
|
||||
mav_calc_window_sec: 4
|
||||
@ -56,7 +56,7 @@ resp_amp_change:
|
||||
|
||||
|
||||
bcg:
|
||||
downsample_fs: 100
|
||||
downsample_fs: 100
|
||||
|
||||
bcg_filter:
|
||||
filter_type: bandpass
|
||||
@ -73,8 +73,13 @@ bcg_low_amp:
|
||||
|
||||
|
||||
bcg_movement:
|
||||
window_size_sec: 2
|
||||
stride_sec:
|
||||
merge_gap_sec: 20
|
||||
min_duration_sec: 4
|
||||
window_size_sec: 2
|
||||
stride_sec:
|
||||
merge_gap_sec: 20
|
||||
min_duration_sec: 4
|
||||
|
||||
|
||||
dataset_config:
|
||||
window_sec: 180
|
||||
stride_sec: 60
|
||||
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset
|
||||
|
||||
@ -1 +1,2 @@
|
||||
from .draw_statics import draw_signal_with_mask
|
||||
from .draw_statics import draw_signal_with_mask
|
||||
from .draw_label import draw_psg_bcg_label, draw_resp_label
|
||||
230
draw_tools/draw_label.py
Normal file
230
draw_tools/draw_label.py
Normal file
@ -0,0 +1,230 @@
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.gridspec import GridSpec
|
||||
from matplotlib.colors import PowerNorm
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import seaborn as sns
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
|
||||
# 添加with_prediction参数
|
||||
|
||||
psg_chn_name2ax = {
|
||||
"SpO2": 0,
|
||||
"Flow T": 1,
|
||||
"Flow P": 2,
|
||||
"Effort Tho": 3,
|
||||
"Effort Abd": 4,
|
||||
"HR": 5,
|
||||
"resp": 6,
|
||||
"bcg": 7,
|
||||
"Stage": 8
|
||||
}
|
||||
|
||||
resp_chn_name2ax = {
|
||||
"resp": 0,
|
||||
"bcg": 1,
|
||||
}
|
||||
|
||||
|
||||
def create_psg_bcg_figure():
|
||||
fig = plt.figure(figsize=(12, 8), dpi=100)
|
||||
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
|
||||
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
|
||||
axes = []
|
||||
for i in range(9):
|
||||
ax = fig.add_subplot(gs[i])
|
||||
axes.append(ax)
|
||||
|
||||
axes[0].grid(True)
|
||||
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[0].set_ylim((85, 100))
|
||||
axes[0].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[1].grid(True)
|
||||
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[1].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[2].grid(True)
|
||||
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[2].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[3].grid(True)
|
||||
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[3].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[4].grid(True)
|
||||
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[4].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[5].grid(True)
|
||||
axes[5].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[6].grid(True)
|
||||
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[6].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[7].grid(True)
|
||||
# axes[6].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[7].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[8].grid(True)
|
||||
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def create_resp_figure():
|
||||
fig = plt.figure(figsize=(12, 6), dpi=100)
|
||||
gs = GridSpec(2, 1, height_ratios=[3, 2])
|
||||
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
|
||||
axes = []
|
||||
for i in range(2):
|
||||
ax = fig.add_subplot(gs[i])
|
||||
axes.append(ax)
|
||||
|
||||
axes[0].grid(True)
|
||||
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[0].tick_params(axis='x', colors="white")
|
||||
|
||||
axes[1].grid(True)
|
||||
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
axes[1].tick_params(axis='x', colors="white")
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
|
||||
event_codes: list[int] = None, multi_labels=None):
|
||||
signal_fs = signal_data["fs"]
|
||||
chn_signal = signal_data["data"]
|
||||
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
|
||||
ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black',
|
||||
label=signal_data["name"])
|
||||
if event_mask is not None:
|
||||
if multi_labels is None and event_codes is not None:
|
||||
for event_code in event_codes:
|
||||
mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code
|
||||
y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float)
|
||||
np.place(y, y == 0, np.nan)
|
||||
ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
|
||||
elif multi_labels == "resp" and event_codes is not None:
|
||||
ax.set_ylim(-6, 6)
|
||||
# 建立第二个y轴坐标
|
||||
ax2 = ax.twinx()
|
||||
ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
|
||||
color='blue', alpha=0.8, label='Low Amplitude Mask')
|
||||
ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
|
||||
color='orange', alpha=0.8, label='Movement Mask')
|
||||
ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
|
||||
color='green', alpha=0.8, label='Amplitude Change Mask')
|
||||
for event_code in event_codes:
|
||||
sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
|
||||
score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs)
|
||||
y = (sa_mask * score_mask).astype(float)
|
||||
np.place(y, y == 0, np.nan)
|
||||
ax2.plot(time_axis, y, color=utils.ColorCycle[event_code])
|
||||
ax2.set_ylim(-4, 5)
|
||||
elif multi_labels == "bcg" and event_codes is not None:
|
||||
# 建立第二个y轴坐标
|
||||
ax2 = ax.twinx()
|
||||
ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
|
||||
color='blue', alpha=0.8, label='Low Amplitude Mask')
|
||||
ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
|
||||
color='orange', alpha=0.8, label='Movement Mask')
|
||||
ax2.plot(time_axis, event_mask["BCG_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
|
||||
color='green', alpha=0.8, label='Amplitude Change Mask')
|
||||
|
||||
ax2.set_ylim(-4, 4)
|
||||
|
||||
ax.set_ylabel("Amplitude")
|
||||
ax.legend(loc=1)
|
||||
|
||||
|
||||
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
|
||||
stage_signal = stage_data["data"]
|
||||
stage_fs = stage_data["fs"]
|
||||
time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start)
|
||||
ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"])
|
||||
ax.set_ylim(0, 6)
|
||||
ax.set_yticks([1, 2, 3, 4, 5])
|
||||
ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
|
||||
ax.set_ylabel("Stage")
|
||||
ax.legend(loc=1)
|
||||
|
||||
|
||||
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
|
||||
spo2_signal = spo2_data["data"]
|
||||
spo2_fs = spo2_data["fs"]
|
||||
time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start)
|
||||
ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"])
|
||||
|
||||
if spo2_signal[segment_start:segment_end].min() < 85:
|
||||
ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100))
|
||||
else:
|
||||
ax.set_ylim((85, 100))
|
||||
ax.set_ylabel("SpO2 (%)")
|
||||
ax.legend(loc=1)
|
||||
|
||||
|
||||
def score_mask2alpha(score_mask):
|
||||
alpha_mask = np.zeros_like(score_mask, dtype=float)
|
||||
alpha_mask[score_mask <= 0] = 0
|
||||
alpha_mask[score_mask == 1] = 0.9
|
||||
alpha_mask[score_mask == 2] = 0.6
|
||||
alpha_mask[score_mask == 3] = 0.1
|
||||
return alpha_mask
|
||||
|
||||
|
||||
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
|
||||
for mask in event_mask.keys():
|
||||
if mask.startswith("Resp_") or mask.endswith("BCG_"):
|
||||
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
|
||||
|
||||
event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
|
||||
event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
|
||||
|
||||
fig, axes = create_psg_bcg_figure()
|
||||
for segment_start, segment_end in segment_list:
|
||||
print(f"Drawing segment: {segment_start} to {segment_end} seconds")
|
||||
for ax in axes:
|
||||
ax.cla()
|
||||
|
||||
plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
|
||||
plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
|
||||
psg_label, event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
|
||||
psg_label, event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
|
||||
psg_label, event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
|
||||
psg_label, event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
|
||||
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
|
||||
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4])
|
||||
plt.show()
|
||||
print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
|
||||
|
||||
|
||||
def draw_resp_label(resp_data, resp_label, segment_list):
|
||||
for mask in resp_label.keys():
|
||||
if mask.startswith("Resp_"):
|
||||
resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0)
|
||||
|
||||
resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
|
||||
resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0)
|
||||
|
||||
fig, axes = create_resp_figure()
|
||||
for segment_start, segment_end in segment_list:
|
||||
for ax in axes:
|
||||
ax.cla()
|
||||
|
||||
plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
|
||||
resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
|
||||
resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])
|
||||
plt.show()
|
||||
@ -18,15 +18,19 @@
|
||||
# 高幅值连续体动规则标定与剔除
|
||||
# 手动标定不可用区间提剔除
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
project_root_path = Path(__file__).resolve().parent.parent
|
||||
|
||||
import shutil
|
||||
import draw_tools
|
||||
import utils
|
||||
import numpy as np
|
||||
import signal_method
|
||||
import os
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
os.environ['DISPLAY'] = "localhost:10.0"
|
||||
|
||||
@ -48,56 +52,14 @@ def process_one_signal(samp_id, show=False):
|
||||
save_samp_path = save_path / f"{samp_id}"
|
||||
save_samp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
signal_data_raw = utils.read_signal_txt(signal_path)
|
||||
signal_length = len(signal_data_raw)
|
||||
print(f"signal_length: {signal_length}")
|
||||
signal_fs = int(signal_path.stem.split("_")[-1])
|
||||
print(f"signal_fs: {signal_fs}")
|
||||
signal_second = signal_length // signal_fs
|
||||
print(f"signal_second: {signal_second}")
|
||||
signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True)
|
||||
|
||||
# 根据采样率进行截断
|
||||
signal_data_raw = signal_data_raw[:signal_second * signal_fs]
|
||||
|
||||
# 滤波
|
||||
# 50Hz陷波滤波器
|
||||
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
|
||||
print("Applying 50Hz notch filter...")
|
||||
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
|
||||
|
||||
resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
|
||||
resp_fs = conf["resp"]["downsample_fs_1"]
|
||||
resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
|
||||
resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
|
||||
resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
|
||||
low_cut=conf["resp_filter"]["low_cut"],
|
||||
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
|
||||
sample_rate=resp_fs)
|
||||
print("Begin plotting signal data...")
|
||||
|
||||
# fig = plt.figure(figsize=(12, 8))
|
||||
# # 绘制三个图raw_data、resp_data_1、resp_data_2
|
||||
# ax0 = fig.add_subplot(3, 1, 1)
|
||||
# ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
|
||||
# ax0.set_title('Raw Signal Data')
|
||||
# ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
|
||||
# ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange')
|
||||
# ax1.set_title('Resp Data after Average Filtering')
|
||||
# ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
|
||||
# ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green')
|
||||
# ax2.set_title('Resp Data after Butterworth Filtering')
|
||||
# plt.tight_layout()
|
||||
# plt.show()
|
||||
|
||||
bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
|
||||
low_cut=conf["bcg_filter"]["low_cut"],
|
||||
high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
|
||||
sample_rate=signal_fs)
|
||||
signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)
|
||||
|
||||
# 降采样
|
||||
old_resp_fs = resp_fs
|
||||
resp_fs = conf["resp"]["downsample_fs_2"]
|
||||
resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs)
|
||||
resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs)
|
||||
bcg_fs = conf["bcg"]["downsample_fs"]
|
||||
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
|
||||
|
||||
@ -214,26 +176,26 @@ def process_one_signal(samp_id, show=False):
|
||||
target_fs=100)
|
||||
signal_fs = 100
|
||||
|
||||
draw_tools.draw_signal_with_mask(samp_id=samp_id,
|
||||
signal_data=signal_data,
|
||||
signal_fs=signal_fs,
|
||||
resp_data=resp_data,
|
||||
resp_fs=resp_fs,
|
||||
bcg_data=bcg_data,
|
||||
bcg_fs=bcg_fs,
|
||||
signal_disable_mask=manual_disable_mask,
|
||||
resp_low_amp_mask=resp_low_amp_mask,
|
||||
resp_movement_mask=resp_movement_mask,
|
||||
resp_change_mask=resp_amp_change_mask,
|
||||
resp_sa_mask=event_mask,
|
||||
bcg_low_amp_mask=bcg_low_amp_mask,
|
||||
bcg_movement_mask=bcg_movement_mask,
|
||||
bcg_change_mask=bcg_amp_change_mask,
|
||||
show=show,
|
||||
save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")
|
||||
draw_tools.draw_signal_with_mask(samp_id=samp_id,
|
||||
signal_data=signal_data,
|
||||
signal_fs=signal_fs,
|
||||
resp_data=resp_data,
|
||||
resp_fs=resp_fs,
|
||||
bcg_data=bcg_data,
|
||||
bcg_fs=bcg_fs,
|
||||
signal_disable_mask=manual_disable_mask,
|
||||
resp_low_amp_mask=resp_low_amp_mask,
|
||||
resp_movement_mask=resp_movement_mask,
|
||||
resp_change_mask=resp_amp_change_mask,
|
||||
resp_sa_mask=event_mask,
|
||||
bcg_low_amp_mask=bcg_low_amp_mask,
|
||||
bcg_movement_mask=bcg_movement_mask,
|
||||
bcg_change_mask=bcg_amp_change_mask,
|
||||
show=show,
|
||||
save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")
|
||||
|
||||
# 复制事件文件 到保存路径
|
||||
sa_label_save_name = f"{samp_id}" + label_path.name
|
||||
sa_label_save_name = f"{samp_id}_" + label_path.name
|
||||
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
|
||||
|
||||
# 新建一个dataframe,分别是秒数、SA标签,SA质量标签,禁用标签,Resp低幅值标签,Resp体动标签,Resp幅值突变标签,Bcg低幅值标签,Bcg体动标签,Bcg幅值突变标签
|
||||
@ -247,10 +209,10 @@ def process_one_signal(samp_id, show=False):
|
||||
dtype=int),
|
||||
"Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second,
|
||||
dtype=int),
|
||||
"Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
|
||||
"Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second,
|
||||
"BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
|
||||
"BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second,
|
||||
dtype=int),
|
||||
"Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second,
|
||||
"BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second,
|
||||
dtype=int)
|
||||
}
|
||||
|
||||
@ -259,13 +221,13 @@ def process_one_signal(samp_id, show=False):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
yaml_path = Path("../dataset_config/HYS_config.yaml")
|
||||
disable_df_path = Path("../排除区间.xlsx")
|
||||
yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
|
||||
disable_df_path = project_root_path / "排除区间.xlsx"
|
||||
|
||||
conf = utils.load_dataset_conf(yaml_path)
|
||||
select_ids = conf["select_ids"]
|
||||
root_path = Path(conf["root_path"])
|
||||
save_path = Path(conf["save_path"])
|
||||
save_path = Path(conf["mask_save_path"])
|
||||
|
||||
print(f"select_ids: {select_ids}")
|
||||
print(f"root_path: {root_path}")
|
||||
@ -276,9 +238,9 @@ if __name__ == '__main__':
|
||||
|
||||
all_samp_disable_df = utils.read_disable_excel(disable_df_path)
|
||||
|
||||
process_one_signal(select_ids[6], show=True)
|
||||
# process_one_signal(select_ids[6], show=True)
|
||||
#
|
||||
# for samp_id in select_ids:
|
||||
# print(f"Processing sample ID: {samp_id}")
|
||||
# process_one_signal(samp_id, show=False)
|
||||
# print(f"Finished processing sample ID: {samp_id}\n\n")
|
||||
for samp_id in select_ids:
|
||||
print(f"Processing sample ID: {samp_id}")
|
||||
process_one_signal(samp_id, show=False)
|
||||
print(f"Finished processing sample ID: {samp_id}\n\n")
|
||||
|
||||
@ -1,277 +0,0 @@
|
||||
"""
|
||||
本脚本完成对呼研所数据的处理,包含以下功能:
|
||||
1. 数据读取与预处理
|
||||
从传入路径中,进行数据和标签的读取,并进行初步的预处理
|
||||
预处理包括为数据进行滤波、去噪等操作
|
||||
2. 数据清洗与异常值处理
|
||||
3. 输出清晰后的统计信息
|
||||
4. 数据保存
|
||||
将处理后的数据保存到指定路径,便于后续使用
|
||||
主要是保存切分后的数据位置和标签
|
||||
5. 可视化
|
||||
提供数据处理前后的可视化对比,帮助理解数据变化
|
||||
绘制多条可用性趋势图,展示数据的可用区间、体动区间、低幅值区间等
|
||||
|
||||
todo: 使用mask 屏蔽无用区间
|
||||
|
||||
|
||||
# 低幅值区间规则标定与剔除
|
||||
# 高幅值连续体动规则标定与剔除
|
||||
# 手动标定不可用区间提剔除
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import draw_tools
|
||||
import utils
|
||||
import numpy as np
|
||||
import signal_method
|
||||
import os
|
||||
from matplotlib import pyplot as plt
|
||||
os.environ['DISPLAY'] = "localhost:10.0"
|
||||
|
||||
def process_one_signal(samp_id, show=False):
    """
    Run the full per-sample preprocessing pipeline for one recording.

    Steps: locate the synced raw BCG txt file and the corrected SA label CSV,
    filter and split the raw signal into respiration and BCG bands, build the
    SA event / score / manual-disable masks, run the rule-based low-amplitude,
    movement and amplitude-change detectors on both bands, optionally plot an
    overview, and finally write the per-second label CSV (plus a copy of the
    SA label file) under ``save_path / samp_id``.

    Relies on module-level globals set in ``__main__``: ``conf``,
    ``org_signal_root_path``, ``label_root_path``, ``all_samp_disable_df``,
    ``save_path``.

    Args:
        samp_id: sample identifier used in directory and file names.
        show: when True, draw the diagnostic overview plot.
    """
    # Locate the synchronized raw signal file; its name encodes the sampling rate.
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")

    # NOTE(review): glob() returns a generator, which is always truthy, so this
    # FileNotFoundError can never fire; a missing file surfaces as an
    # IndexError on the list(...)[0] line below instead.
    label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv")
    if not label_path:
        raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
    label_path = list(label_path)[0]
    print(f"Processing Label_corrected file: {label_path}")

    signal_data_raw = utils.read_signal_txt(signal_path)
    signal_length = len(signal_data_raw)
    print(f"signal_length: {signal_length}")
    # The sampling rate is the trailing "_<fs>" token of the file stem.
    signal_fs = int(signal_path.stem.split("_")[-1])
    print(f"signal_fs: {signal_fs}")
    signal_second = signal_length // signal_fs
    print(f"signal_second: {signal_second}")

    # Truncate to a whole number of seconds.
    signal_data_raw = signal_data_raw[:signal_second * signal_fs]

    # Filtering: remove 50 Hz mains interference first.
    # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
    print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)

    # Respiration branch: low-pass, downsample, 20 s moving-average filter,
    # then the configured Butterworth band filter.
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)
    print("Begin plotting signal data...")

    # BCG branch: configured Butterworth filter at the original sampling rate.
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)

    # Downsample both branches to their final analysis rates.
    old_resp_fs = resp_fs
    resp_fs = conf["resp"]["downsample_fs_2"]
    resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs)
    bcg_fs = conf["bcg"]["downsample_fs"]
    bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)

    # Per-second SA event mask and score mask built from the label CSV.
    label_data = utils.read_label_csv(path=label_path)
    event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)

    # Manually annotated unusable intervals for this sample
    # (0 marks a disabled second, per the num_disable count below).
    manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
        all_samp_disable_df["id"] == samp_id])
    print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")

    # Respiration: low-amplitude interval detection (skipped when unconfigured).
    resp_low_amp_conf = conf.get("resp_low_amp", None)
    if resp_low_amp_conf is not None:
        resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_low_amp_conf
        )
        print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}")
    else:
        resp_low_amp_mask, resp_low_amp_position_list = None, None
        print("resp_low_amp_mask is None")

    # Respiration: high-amplitude movement-artifact detection.
    resp_movement_conf = conf.get("resp_movement", None)
    if resp_movement_conf is not None:
        raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
            signal_data=resp_data,
            sampling_rate=resp_fs,
            **resp_movement_conf
        )
        print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        resp_movement_mask, resp_movement_position_list = None, None
        print("resp_movement_mask is None")

    # Optional second pass that refines the detected movement intervals.
    resp_movement_revise_conf = conf.get("resp_movement_revise", None)
    if resp_movement_mask is not None and resp_movement_revise_conf is not None:
        resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_movement_revise_conf,
            verbose=False
        )
        print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
    else:
        print("resp_movement_mask revise is skipped")

    # Respiration: abrupt amplitude-change detection (needs the movement result).
    resp_amp_change_conf = conf.get("resp_amp_change", None)
    if resp_amp_change_conf is not None and resp_movement_mask is not None:
        resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
            signal_data=resp_data,
            movement_mask=resp_movement_mask,
            movement_list=resp_movement_position_list,
            sampling_rate=resp_fs,
            **resp_amp_change_conf)
        print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
    else:
        resp_amp_change_mask = None
        print("amp_change_mask is None")

    # BCG: low-amplitude interval detection.
    bcg_low_amp_conf = conf.get("bcg_low_amp", None)
    if bcg_low_amp_conf is not None:
        bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_low_amp_conf
        )
        print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}")
    else:
        bcg_low_amp_mask, bcg_low_amp_position_list = None, None
        print("bcg_low_amp_mask is None")
    # BCG: high-amplitude movement-artifact detection.
    bcg_movement_conf = conf.get("bcg_movement", None)
    if bcg_movement_conf is not None:
        raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
            signal_data=bcg_data,
            sampling_rate=bcg_fs,
            **bcg_movement_conf
        )
        print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}")
    else:
        bcg_movement_mask = None
        print("bcg_movement_mask is None")
    # BCG: amplitude-change detection (v2 variant, no extra configuration).
    if bcg_movement_mask is not None:
        bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
            signal_data=bcg_data,
            movement_mask=bcg_movement_mask,
            sampling_rate=bcg_fs)
        print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}")
    else:
        bcg_amp_change_mask = None
        print("bcg_amp_change_mask is None")

    # If the raw signal was recorded at 1 kHz, downsample to 100 Hz
    # before (optionally) plotting.
    if signal_fs == 1000:
        signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
        signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100)
        signal_fs = 100
    if show:
        draw_tools.draw_signal_with_mask(samp_id=samp_id,
                                         signal_data=signal_data,
                                         signal_fs=signal_fs,
                                         resp_data=resp_data,
                                         resp_fs=resp_fs,
                                         bcg_data=bcg_data,
                                         bcg_fs=bcg_fs,
                                         signal_disable_mask=manual_disable_mask,
                                         resp_low_amp_mask=resp_low_amp_mask,
                                         resp_movement_mask=resp_movement_mask,
                                         resp_change_mask=resp_amp_change_mask,
                                         resp_sa_mask=event_mask,
                                         bcg_low_amp_mask=bcg_low_amp_mask,
                                         bcg_movement_mask=bcg_movement_mask,
                                         bcg_change_mask=bcg_amp_change_mask)

    # Persist the processed labels next to a copy of the SA event file.
    save_samp_path = save_path / f"{samp_id}"
    save_samp_path.mkdir(parents=True, exist_ok=True)

    # Copy the SA event CSV into the output directory, prefixed with the sample id.
    sa_label_save_name = f"{samp_id}" + label_path.name
    shutil.copyfile(label_path, save_samp_path / sa_label_save_name)

    # One row per second: SA label/score, the manual disable flag, and the six
    # rule-based masks (all-zero placeholders when a detector was skipped).
    save_dict = {
        "Second": np.arange(signal_second),
        "SA_Label": event_mask,
        "SA_Score": score_mask,
        "Disable_Label": manual_disable_mask,
        "Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
        "Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int),
        "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int),
        "Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
        "Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int),
        "Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int)
    }

    mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
    utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Paths are relative to the current working directory — NOTE(review): this
    # assumes the script is launched from its own package directory.
    yaml_path = Path("../dataset_config/ZD5Y_config.yaml")
    disable_df_path = Path("../排除区间.xlsx")  # manually annotated excluded-interval spreadsheet

    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    save_path = Path(conf["save_path"])

    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")

    # Directory layout expected under root_path.
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    label_root_path = root_path / "Label"

    all_samp_disable_df = utils.read_disable_excel(disable_df_path)

    # Process a single sample with plots; the batch loop below is kept for reference.
    process_one_signal(select_ids[1], show=True)

    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     process_one_signal(samp_id, show=False)
    #     print(f"Finished processing sample ID: {samp_id}\n\n")
|
||||
@ -1,4 +1,6 @@
|
||||
from .rule_base_event import detect_low_amplitude_signal, detect_movement
|
||||
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
|
||||
from .rule_base_event import movement_revise
|
||||
from .time_metrics import calc_mav_by_slide_windows
|
||||
from .time_metrics import calc_mav_by_slide_windows
|
||||
from .signal_process import signal_filter_split, rpeak2hr
|
||||
from .normalize_method import normalize_resp_signal
|
||||
36
signal_method/normalize_method.py
Normal file
36
signal_method/normalize_method.py
Normal file
@ -0,0 +1,36 @@
|
||||
import utils
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
    """
    Z-score the respiration signal segment by segment.

    For every enable segment the mean/std are estimated from the movement-free
    samples of that segment only, then applied to the raw samples from the
    segment start up to the start of the next enable segment (so trailing
    movement between segments is normalized with the preceding segment's
    statistics). Samples before the first enable segment stay NaN.

    Args:
        resp_signal: 1-D respiration signal.
        resp_fs: sampling rate in Hz (samples per labelled second).
        movement_mask: per-second mask, 1 == movement; excluded from the
            mean/std estimation.
        enable_list: list of (start_sec, end_sec) usable segments, ascending.

    Returns:
        np.ndarray (float): normalized signal, NaN outside normalized spans.

    Raises:
        ValueError: when a segment's movement-free std is exactly zero.
    """
    # Output starts as all-NaN; anything not covered by a segment stays NaN.
    normalized_resp_signal = np.full(resp_signal.shape, np.nan, dtype=float)

    # Mask out movement seconds so they do not bias the segment statistics.
    resp_signal_no_movement = resp_signal.astype(float, copy=True)
    resp_signal_no_movement[np.array(movement_mask == 1).repeat(resp_fs)] = np.nan

    for i in range(len(enable_list)):
        enable_start = enable_list[i][0] * resp_fs
        enable_end = enable_list[i][1] * resp_fs
        segment = resp_signal_no_movement[enable_start:enable_end]

        segment_mean = np.nanmean(segment)
        segment_std = np.nanstd(segment)
        if segment_std == 0:
            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")

        if i <= len(enable_list) - 2:
            # Extend through the movement gap up to the next enable segment.
            enable_end = enable_list[i + 1][0] * resp_fs
        else:
            # BUGFIX: the last enable segment previously computed its statistics
            # but was never written back, leaving it all-NaN. Normalize it
            # (and any trailing samples) with its own statistics.
            enable_end = len(resp_signal)
        raw_segment = resp_signal[enable_start:enable_end]
        normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std

    return normalized_resp_signal
|
||||
62
signal_method/signal_process.py
Normal file
62
signal_method/signal_process.py
Normal file
@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
|
||||
import utils
|
||||
|
||||
def signal_filter_split(conf, signal_data_raw, signal_fs):
    """
    Split the raw signal into a respiration band and a BCG band.

    Pipeline:
      1. 50 Hz notch filter on the raw signal (mains-interference removal).
      2. Respiration branch: 50 Hz low-pass -> downsample to
         conf["resp"]["downsample_fs_1"] -> 20 s moving-average filter ->
         Butterworth filter per conf["resp_filter"].
      3. BCG branch: Butterworth filter per conf["bcg_filter"] at the original
         sampling rate (not resampled here).

    Args:
        conf: dataset configuration dict (keys "resp", "resp_filter", "bcg_filter").
        signal_data_raw: 1-D raw signal array.
        signal_fs: sampling rate of signal_data_raw in Hz.

    Returns:
        tuple: (notch_filtered_signal, resp_signal, resp_fs, bcg_signal, bcg_fs)
        where bcg_fs equals the input signal_fs.
    """
    # Mains-interference removal first; everything else works on this signal.
    print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)

    # Respiration branch.
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)

    # BCG branch at the original sampling rate.
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)

    return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs
|
||||
|
||||
|
||||
|
||||
def rpeak2hr(rpeak_indices, signal_length, fs=1000):
    """
    Convert R-peak sample indices into a sample-wise heart-rate signal.

    Each inter-beat interval [rpeak[i-1], rpeak[i]) is filled with the heart
    rate computed from that RR interval, clamped to [30, 120] bpm. Samples
    before the first peak stay 0; samples after the last peak repeat the last
    interval's rate.

    Args:
        rpeak_indices: ascending R-peak positions, in samples.
        signal_length: length of the output signal, in samples.
        fs: sampling rate of the indices in Hz. Defaults to 1000, matching the
            previous hard-coded assumption, so existing callers are unchanged.

    Returns:
        np.ndarray of shape (signal_length,): heart rate in bpm.
    """
    hr_signal = np.zeros(signal_length)
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        if rri == 0:
            # Duplicate peak — skip to avoid division by zero.
            continue
        hr = 60 * fs / rri  # heart rate in bpm
        # Clamp to a plausible physiological range.
        hr = min(max(hr, 30), 120)
        hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = hr
    # Extend the last computed rate past the final R-peak.
    if len(rpeak_indices) > 1:
        hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
    return hr_signal
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import utils
|
||||
from .event_map import N2Chn
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .operation_tools import event_mask_2_list
|
||||
# 尝试导入 Polars
|
||||
try:
|
||||
import polars as pl
|
||||
@ -13,15 +15,17 @@ except ImportError:
|
||||
HAS_POLARS = False
|
||||
|
||||
|
||||
def read_signal_txt(path: Union[str, Path]) -> np.ndarray:
|
||||
def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
|
||||
"""
|
||||
Read a txt file and return the first column as a numpy array.
|
||||
|
||||
Args:
|
||||
path (str | Path): Path to the txt file.
|
||||
|
||||
:param path:
|
||||
:param verbose:
|
||||
:param dtype:
|
||||
Returns:
|
||||
np.ndarray: The first column of the txt file as a numpy array.
|
||||
|
||||
"""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
@ -29,10 +33,30 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray:
|
||||
|
||||
if HAS_POLARS:
|
||||
df = pl.read_csv(path, has_header=False, infer_schema_length=0)
|
||||
return df[:, 0].to_numpy().astype(float)
|
||||
signal_data_raw = df[:, 0].to_numpy().astype(dtype)
|
||||
else:
|
||||
df = pd.read_csv(path, header=None, dtype=float)
|
||||
return df.iloc[:, 0].to_numpy()
|
||||
df = pd.read_csv(path, header=None, dtype=dtype)
|
||||
signal_data_raw = df.iloc[:, 0].to_numpy()
|
||||
|
||||
signal_original_length = len(signal_data_raw)
|
||||
signal_fs = int(path.stem.split("_")[-1])
|
||||
if is_peak:
|
||||
signal_second = None
|
||||
signal_length = None
|
||||
else:
|
||||
signal_second = signal_original_length // signal_fs
|
||||
# 根据采样率进行截断
|
||||
signal_data_raw = signal_data_raw[:signal_second * signal_fs]
|
||||
signal_length = len(signal_data_raw)
|
||||
|
||||
if verbose:
|
||||
print(f"Signal file read from {path}")
|
||||
print(f"signal_fs: {signal_fs}")
|
||||
print(f"signal_original_length: {signal_original_length}")
|
||||
print(f"signal_after_cut_off_length: {signal_length}")
|
||||
print(f"signal_second: {signal_second}")
|
||||
|
||||
return signal_data_raw, signal_length, signal_fs, signal_second
|
||||
|
||||
|
||||
def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
|
||||
@ -172,3 +196,99 @@ def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
|
||||
df["start"] = df["start"].astype(int)
|
||||
df["end"] = df["end"].astype(int)
|
||||
return df
|
||||
|
||||
|
||||
def read_mask_execl(path: Union[str, Path]):
    """
    Read a processed per-second label CSV and build masks and segment lists.

    (The file is a CSV despite the function name; the misspelled name is kept
    for backward compatibility with existing callers.)

    Args:
        path (str | Path): path to the "<id>_Processed_Labels.csv" file written
            by ``save_process_label``.

    Returns:
        tuple:
            event_mask (dict[str, np.ndarray]): column name -> per-second array.
            event_list (dict[str, list]): "RespAmpChangeSegment",
                "BCGAmpChangeSegment" and "EnableSegment" segment lists derived
                from the inverted label columns via ``event_mask_2_list``.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")

    df = pd.read_csv(path)
    event_mask = df.to_dict(orient="list")
    for key in event_mask:
        event_mask[key] = np.array(event_mask[key])

    # BUGFIX: save_process_label writes the column as "Bcg_AmpChange_Label",
    # but this reader previously hard-coded "BCG_AmpChange_Label" and raised
    # KeyError on such files. Accept either spelling.
    bcg_amp_change_key = "BCG_AmpChange_Label" if "BCG_AmpChange_Label" in event_mask else "Bcg_AmpChange_Label"

    # NOTE(review): segments are built from `1 - mask`; confirm each column's
    # 0/1 convention (the writer's Disable_Label appears to use 0 == disabled).
    event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
                  "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask[bcg_amp_change_key]),
                  "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]), }

    return event_mask, event_list
|
||||
|
||||
|
||||
|
||||
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
    """
    Read selected PSG channels from a directory of exported txt files.

    Each channel id in ``channel_number`` is mapped to a file-name prefix via
    ``N2Chn`` and the matching "<name>*.txt" file is loaded with
    ``read_signal_txt``. Channel 8 (sleep stage) is read as strings and mapped
    to integer codes via ``utils.Stage2N``; channel 1 (R-peaks) is read as
    integer indices with ``is_peak=True`` (no truncation, so its "length" and
    "second" entries are None).

    Args:
        path_str (Union[str, Path]): directory containing the PSG txt files.
        channel_number (list[int]): channel ids to read (keys of ``N2Chn``).

    Returns:
        dict: channel name -> {"name", "path", "data", "length", "fs", "second"}.

    Raises:
        FileNotFoundError: when the directory or a channel file is missing.
        NotADirectoryError: when the path exists but is not a directory.
    """
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"PSG Dir not found: {path}")

    if not path.is_dir():
        raise NotADirectoryError(f"PSG Dir not found: {path}")
    channel_data = {}
    # Check that each requested channel's file exists, then load it.
    for ch_id in channel_number:
        ch_name = N2Chn[ch_id]
        ch_path = list(path.glob(f"{ch_name}*.txt"))

        if not any(ch_path):
            raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")

        if len(ch_path) > 1:
            # Ambiguous glob match — the first file wins below.
            print(f"Warning!!! PSG Channel file more than one: {ch_path}")

        if ch_id == 8:
            # Sleep-stage special case: values are stage strings, read as str.
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True)
            # Map stage strings (e.g. "W", "N1") to their numeric codes in place.
            for stage_str, stage_number in utils.Stage2N.items():
                np.place(ch_signal, ch_signal == stage_str, stage_number)
            ch_signal = ch_signal.astype(int)
        elif ch_id == 1:
            # R-peak special case: a list of sample indices, not a waveform.
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True)
        else:
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True)
        channel_data[ch_name] = {
            "name": ch_name,
            "path": ch_path[0],
            "data": ch_signal,
            "length": ch_length,
            "fs": ch_fs,
            "second": ch_second
        }

    return channel_data
|
||||
|
||||
|
||||
def read_psg_label(path: Union[str, Path], verbose=True):
    """
    Load a PSG annotation CSV and drop rows without an "Event type" value.

    Args:
        path (str | Path): path to the annotation CSV (GBK-encoded, may
            contain Chinese text).
        verbose: when True, print the path and the raw row count.

    Returns:
        pd.DataFrame: annotation rows with a non-empty "Event type" cell,
        re-indexed from 0.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    label_path = Path(path)
    if not label_path.exists():
        raise FileNotFoundError(f"File not found: {label_path}")

    # The exported file contains Chinese text, hence the GBK encoding.
    frame = pd.read_csv(label_path, encoding="gbk")
    if verbose:
        print(f"Label file read from {label_path}, number of rows: {len(frame)}")

    # Keep only rows that actually carry an event annotation.
    frame = frame.dropna(subset=["Event type"], how='all')
    return frame.reset_index(drop=True)
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel
|
||||
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label
|
||||
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
|
||||
from .operation_tools import merge_short_gaps, remove_short_durations
|
||||
from .operation_tools import collect_values
|
||||
from .operation_tools import save_process_label
|
||||
from .event_map import E2N
|
||||
from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
|
||||
from .operation_tools import none_to_nan_mask
|
||||
from .split_method import resp_split
|
||||
from .HYS_FileReader import read_mask_execl, read_psg_channel
|
||||
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
|
||||
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
|
||||
@ -4,4 +4,39 @@ E2N = {
|
||||
"Central apnea": 2,
|
||||
"Obstructive apnea": 3,
|
||||
"Mixed apnea": 4
|
||||
}
|
||||
}
|
||||
|
||||
# PSG channel number -> exported channel (file-name prefix) mapping.
N2Chn = {
    1: "Rpeak",
    2: "ECG_Sync",
    3: "Effort Tho",
    4: "Effort Abd",
    5: "Flow P",
    6: "Flow T",
    7: "SpO2",
    8: "5_class"
}

# Sleep-stage string (as stored in the 5_class channel file) -> numeric code.
Stage2N = {
    "W": 5,
    "N1": 3,
    "N2": 2,
    "N3": 1,
    "R": 4,
}

# Event code -> plot colour:
#   event_code  colour  event
#   0           black   background
#   1           pink    hypopnea
#   2           blue    central apnea
#   3           red     obstructive apnea
#   4           silver  mixed apnea
#   5           green   SpO2 desaturation
#   6           orange  large body movement
#   7           orange  small body movement
#   8           orange  deep breath
#   9           orange  pulse-like movement
#   10          orange  invalid segment
ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange",
              "orange"]
|
||||
@ -198,16 +198,25 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
|
||||
return disable_mask
|
||||
|
||||
|
||||
def generate_event_mask(signal_second: int, event_df):
|
||||
def generate_event_mask(signal_second: int, event_df, use_correct=True):
|
||||
event_mask = np.zeros(signal_second, dtype=int)
|
||||
score_mask = np.zeros(signal_second, dtype=int)
|
||||
if use_correct:
|
||||
start_name = "correct_Start"
|
||||
end_name = "correct_End"
|
||||
event_type_name = "correct_EventsType"
|
||||
else:
|
||||
start_name = "Start"
|
||||
end_name = "End"
|
||||
event_type_name = "Event type"
|
||||
|
||||
# 剔除start = -1 的行
|
||||
event_df = event_df[event_df["correct_Start"] >= 0]
|
||||
event_df = event_df[event_df[start_name] >= 0]
|
||||
|
||||
for _, row in event_df.iterrows():
|
||||
start = row["correct_Start"]
|
||||
end = row["correct_End"] + 1
|
||||
event_mask[start:end] = E2N[row["correct_EventsType"]]
|
||||
start = row[start_name]
|
||||
end = row[end_name] + 1
|
||||
event_mask[start:end] = E2N[row[event_type_name]]
|
||||
score_mask[start:end] = row["score"]
|
||||
return event_mask, score_mask
|
||||
|
||||
@ -243,3 +252,12 @@ def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None
|
||||
def save_process_label(save_path: Path, save_dict: dict):
    """Persist the per-second label dictionary as a CSV file (no index column)."""
    frame = pd.DataFrame(data=save_dict)
    frame.to_csv(path_or_buf=save_path, index=False)
|
||||
|
||||
def none_to_nan_mask(mask, ref):
    """
    Convert a possibly-None mask into a float mask aligned with ``ref``.

    Args:
        mask: mask array, or None when the detector that produces it was skipped.
        ref: reference array whose shape the result must match when mask is None.

    Returns:
        np.ndarray (float): all-NaN when mask is None; otherwise mask with
        zeros replaced by NaN and all other values kept.
    """
    if mask is None:
        # BUGFIX: np.full_like(ref, np.nan) raises for integer ref arrays
        # (NaN is not representable as int); build an explicit float array.
        return np.full(np.shape(ref), np.nan, dtype=float)
    # np.where upcasts integer masks to float so NaN can be stored.
    return np.where(mask == 0, np.nan, mask)
|
||||
27
utils/split_method.py
Normal file
27
utils/split_method.py
Normal file
@ -0,0 +1,27 @@
|
||||
|
||||
|
||||
|
||||
def resp_split(dataset_config, event_mask, event_list):
    """
    Slice each usable (enabled) interval into fixed-length sliding windows.

    Windows of ``window_sec`` seconds are taken every ``stride_sec`` seconds
    inside every enable segment. If the leftover tail after the last full
    stride is at least stride_sec / 2 long, one extra window ending exactly at
    the segment end is appended so the tail is not discarded.

    Args:
        dataset_config: dict with "window_sec" and "stride_sec" (seconds).
        event_mask: per-second event masks — currently unused; kept for
            interface compatibility with existing callers.
        event_list: dict containing "EnableSegment": list of (start_sec, end_sec).

    Returns:
        list[tuple[int, int]]: (start_sec, end_sec) of every window.
    """
    enable_list = event_list["EnableSegment"]

    window_sec = dataset_config["window_sec"]
    stride_sec = dataset_config["stride_sec"]

    segment_list = []

    for enable_start, enable_end in enable_list:
        current_start = enable_start
        while current_start + window_sec <= enable_end:
            segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # BUGFIX: the old guard also required the remaining tail to be
        # >= window_sec, which is impossible once the loop above has exited,
        # so the tail window promised by the comment was never emitted.
        # Emit it when the tail is >= stride_sec / 2 and the segment can hold
        # a full window ending at enable_end; skip exact duplicates of the
        # last emitted window.
        if (enable_end - current_start >= stride_sec / 2) and (enable_end - enable_start >= window_sec):
            candidate = (enable_end - window_sec, enable_end)
            if not segment_list or segment_list[-1] != candidate:
                segment_list.append(candidate)

    return segment_list
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user