更新数据处理模块,添加信号标准化和绘图功能,重构部分函数以提高可读性

This commit is contained in:
marques 2025-11-14 18:39:50 +08:00
parent 1a0761c6c8
commit ed4205f5b8
17 changed files with 774 additions and 382 deletions

2
.gitignore vendored
View File

@ -253,3 +253,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
output/*
!output/

View File

@ -0,0 +1,166 @@
import sys
from pathlib import Path
import os
import numpy as np
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import signal_method
import draw_tools
import shutil
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
    """Build one subject's segment dataset from the HYS recordings.

    Reads the synced raw BCG txt and the processed label CSV for ``samp_id``,
    filters/splits the signal (via ``signal_method.signal_filter_split``) into
    a notch-filtered signal, a respiration band and a BCG band, normalizes the
    respiration signal, splits the recording into usable segments, then saves
    the signals, the segment list, and copies of the label files under the
    configured dataset output directory.

    Relies on module-level globals assigned in ``__main__``: ``conf``,
    ``dataset_config``, ``mask_path``, ``org_signal_root_path``,
    ``psg_signal_root_path`` and the three ``save_*`` output paths.

    Args:
        samp_id: sample/subject identifier used to locate the input files.
        show: plotting flag; currently only forwarded to the commented-out
            draw call below (no effect as written).
        draw_segment: when True, also load the aligned PSG channels and draw
            each extracted segment with its labels for visual inspection.

    Raises:
        FileNotFoundError: if the OrgBCG_Sync txt file is missing.
    """
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
    print(f"Processing OrgBCG_Sync signal file: {signal_path}")
    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    print(f"mask_excel_path: {mask_excel_path}")
    event_mask, event_list = utils.read_mask_execl(mask_excel_path)
    bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float)
    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs)
    normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
    # Downsample to 100 Hz when the raw sampling rate is high (1000 Hz).
    if signal_fs == 1000:
        bcg_signal_notch = utils.downsample_signal_fast(original_signal=bcg_signal_notch, original_fs=signal_fs, target_fs=100)
        bcg_signal_raw = utils.downsample_signal_fast(original_signal=bcg_signal_raw, original_fs=signal_fs,
                                                      target_fs=100)
        signal_fs = 100
    if bcg_fs == 1000:
        bcg_signal = utils.downsample_signal_fast(original_signal=bcg_signal, original_fs=bcg_fs, target_fs=100)
        bcg_fs = 100
    # draw_tools.draw_signal_with_mask(samp_id=samp_id,
    #                                  signal_data=resp_signal,
    #                                  signal_fs=resp_fs,
    #                                  resp_data=normalized_resp_signal,
    #                                  resp_fs=resp_fs,
    #                                  bcg_data=bcg_signal,
    #                                  bcg_fs=bcg_fs,
    #                                  signal_disable_mask=event_mask["Disable_Label"],
    #                                  resp_low_amp_mask=event_mask["Resp_LowAmp_Label"],
    #                                  resp_movement_mask=event_mask["Resp_Movement_Label"],
    #                                  resp_change_mask=event_mask["Resp_AmpChange_Label"],
    #                                  resp_sa_mask=event_mask["SA_Label"],
    #                                  bcg_low_amp_mask=event_mask["BCG_LowAmp_Label"],
    #                                  bcg_movement_mask=event_mask["BCG_Movement_Label"],
    #                                  bcg_change_mask=event_mask["BCG_AmpChange_Label"],
    #                                  show=show,
    #                                  save_path=None)
    segment_list = utils.resp_split(dataset_config, event_mask, event_list)
    print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
    # Copy the processed-label mask CSV into the dataset Labels folder.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)
    # Copy "<id>_SA Label_corrected.csv" into the Labels folder as well.
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_corrected.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        print(f"Warning: {sa_label_corrected_path} does not exist.")
    # Save the processed signals and the extracted segment list as .npz files.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
    # NOTE(review): "bcg_signal" is stored under name "BCG_Signal_Raw" although
    # it holds the filtered/split BCG band — confirm the intended naming
    # downstream before relying on the "name" fields.
    bcg_data = {
        "bcg_signal_notch": {
            "name": "BCG_Signal_Notch",
            "data": bcg_signal_notch,
            "fs": signal_fs,
            "length": len(bcg_signal_notch),
            "second": len(bcg_signal_notch) // signal_fs
        },
        "bcg_signal":{
            "name": "BCG_Signal_Raw",
            "data": bcg_signal,
            "fs": bcg_fs,
            "length": len(bcg_signal),
            "second": len(bcg_signal) // bcg_fs
        },
        "resp_signal": {
            "name": "Resp_Signal",
            "data": normalized_resp_signal,
            "fs": resp_fs,
            "length": len(normalized_resp_signal),
            "second": len(normalized_resp_signal) // resp_fs
        }
    }
    np.savez_compressed(save_signal_path, **bcg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list)
    print(f"Saved processed signals to: {save_signal_path}")
    print(f"Saved segment list to: {save_segment_path}")
    if draw_segment:
        # Load PSG channels 1-8 and derive an HR trace from the R-peak channel.
        # NOTE(review): this reads sa_label_corrected_path without checking
        # existence — the earlier branch only warned; confirm it always exists
        # when draw_segment is used.
        psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8])
        psg_data["HR"] = {
            "name": "HR",
            "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
            "fs": psg_data["ECG_Sync"]["fs"],
            "length": psg_data["ECG_Sync"]["length"],
            "second": psg_data["ECG_Sync"]["second"]
        }
        psg_label = utils.read_psg_label(sa_label_corrected_path)
        psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
        draw_tools.draw_psg_bcg_label(psg_data=psg_data,
                                      psg_label=psg_event_mask,
                                      bcg_data=bcg_data,
                                      event_mask=event_mask,
                                      segment_list=segment_list)
if __name__ == '__main__':
    # Load the HYS dataset configuration and resolve all I/O directories.
    yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(conf["dataset_config"]["dataset_save_path"])
    dataset_config = conf["dataset_config"]
    # Output layout under save_path: Signals/, Segments_List/, Labels/.
    save_processed_signal_path = save_path / "Signals"
    save_processed_signal_path.mkdir(parents=True, exist_ok=True)
    save_segment_list_path = save_path / "Segments_List"
    save_segment_list_path.mkdir(parents=True, exist_ok=True)
    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)
    print(f"select_ids: {select_ids}")
    print(f"root_path: {root_path}")
    print(f"save_path: {save_path}")
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
    # Currently processes only the first selected subject, with segment
    # drawing enabled; the batch loop below is kept for later use.
    build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
    #
    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     build_HYS_dataset_segment(samp_id, show=False)

View File

@ -11,7 +11,7 @@ select_ids:
- 960 - 960
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
save_path: /mnt/disk_code/marques/dataprepare/output/HYS mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS
resp: resp:
downsample_fs_1: 100 downsample_fs_1: 100
@ -78,3 +78,8 @@ bcg_movement:
merge_gap_sec: 20 merge_gap_sec: 20
min_duration_sec: 4 min_duration_sec: 4
dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_dataset

View File

@ -1 +1,2 @@
from .draw_statics import draw_signal_with_mask from .draw_statics import draw_signal_with_mask
from .draw_label import draw_psg_bcg_label, draw_resp_label

230
draw_tools/draw_label.py Normal file
View File

@ -0,0 +1,230 @@
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
import utils
# TODO: add a with_prediction parameter.
# Maps each PSG/BCG channel name to its row index in the 9-row figure
# produced by create_psg_bcg_figure().
psg_chn_name2ax = {
    "SpO2": 0,
    "Flow T": 1,
    "Flow P": 2,
    "Effort Tho": 3,
    "Effort Abd": 4,
    "HR": 5,
    "resp": 6,
    "bcg": 7,
    "Stage": 8
}
# Row indices for the 2-row respiration-only figure (create_resp_figure()).
resp_chn_name2ax = {
    "resp": 0,
    "bcg": 1,
}
def create_psg_bcg_figure():
    """Create the 9-row PSG+BCG review figure.

    Rows (top to bottom): SpO2, Flow T, Flow P, Effort Tho, Effort Abd,
    HR, resp (taller), bcg, Stage — see ``psg_chn_name2ax``.

    Returns:
        (fig, axes): the figure and the list of its 9 axes.
    """
    fig = plt.figure(figsize=(12, 8), dpi=100)
    layout = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(layout[row]) for row in range(9)]
    for row, ax in enumerate(axes):
        ax.grid(True)
        if row == 0:
            # Default SpO2 display range (may be widened later by the caller).
            ax.set_ylim((85, 100))
        if row < 8:
            # Paint x tick labels white on all but the bottom row, which
            # visually hides them on the default white background.
            ax.tick_params(axis='x', colors="white")
    return fig, axes
def create_resp_figure():
    """Create the 2-row respiration review figure (resp on top, bcg below).

    Returns:
        (fig, axes): the figure and the list of its 2 axes.
    """
    fig = plt.figure(figsize=(12, 6), dpi=100)
    layout = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(layout[row]) for row in range(2)]
    for ax in axes:
        ax.grid(True)
        # White x tick labels visually hide them on a white background.
        ax.tick_params(axis='x', colors="white")
    return fig, axes
def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None):
    """Plot one signal channel on *ax*, optionally overlaying event masks.

    Args:
        ax: target axes.
        signal_data: dict with keys "data" (1-D signal), "fs" (Hz), "name".
        segment_start: segment start, in seconds.
        segment_end: segment end, in seconds.
        event_mask: per-second mask array (single-label mode) or a dict of
            per-second mask arrays (when ``multi_labels`` is "resp"/"bcg").
        event_codes: event codes to highlight with per-code colors.
        multi_labels: None for a single mask; "resp" or "bcg" to draw the
            matching low-amplitude / movement / amplitude-change masks on a
            twin y-axis (plus SA events weighted by score alpha for "resp").
    """
    signal_fs = signal_data["fs"]
    chn_signal = signal_data["data"]
    time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * signal_fs)
    ax.plot(time_axis, chn_signal[segment_start * signal_fs:segment_end * signal_fs], color='black',
            label=signal_data["name"])
    if event_mask is not None:
        if multi_labels is None and event_codes is not None:
            # Single mask: recolor the samples where the mask equals each code.
            for event_code in event_codes:
                mask = event_mask[segment_start:segment_end].repeat(signal_fs) == event_code
                y = (chn_signal[segment_start * signal_fs:segment_end * signal_fs] * mask).astype(float)
                # Zeros (unmasked samples) become NaN so they are not drawn.
                np.place(y, y == 0, np.nan)
                ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
        elif multi_labels == "resp" and event_codes is not None:
            ax.set_ylim(-6, 6)
            # Create a second y-axis for the mask overlays.
            ax2 = ax.twinx()
            ax2.plot(time_axis, event_mask["Resp_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
                     color='blue', alpha=0.8, label='Low Amplitude Mask')
            ax2.plot(time_axis, event_mask["Resp_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
                     color='orange', alpha=0.8, label='Movement Mask')
            ax2.plot(time_axis, event_mask["Resp_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
                     color='green', alpha=0.8, label='Amplitude Change Mask')
            for event_code in event_codes:
                sa_mask = event_mask["SA_Label"][segment_start:segment_end].repeat(signal_fs) == event_code
                score_mask = event_mask["SA_Score_Alpha"][segment_start:segment_end].repeat(signal_fs)
                y = (sa_mask * score_mask).astype(float)
                np.place(y, y == 0, np.nan)
                ax2.plot(time_axis, y, color=utils.ColorCycle[event_code])
            ax2.set_ylim(-4, 5)
        elif multi_labels == "bcg" and event_codes is not None:
            # Create a second y-axis for the mask overlays.
            ax2 = ax.twinx()
            ax2.plot(time_axis, event_mask["BCG_LowAmp_Label"][segment_start:segment_end].repeat(signal_fs) * -1,
                     color='blue', alpha=0.8, label='Low Amplitude Mask')
            ax2.plot(time_axis, event_mask["BCG_Movement_Label"][segment_start:segment_end].repeat(signal_fs) * -2,
                     color='orange', alpha=0.8, label='Movement Mask')
            ax2.plot(time_axis, event_mask["BCG_AmpChange_Label"][segment_start:segment_end].repeat(signal_fs) * -3,
                     color='green', alpha=0.8, label='Amplitude Change Mask')
            ax2.set_ylim(-4, 4)
    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage (hypnogram) channel on *ax* over the given sample span."""
    stage_fs = stage_data["fs"]
    n_points = segment_end - segment_start
    t = np.linspace(segment_start / stage_fs, segment_end / stage_fs, n_points)
    ax.plot(t, stage_data["data"][segment_start:segment_end], color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    # NOTE(review): encoding assumed 1->"N3" ... 5->"Awake"; confirm upstream.
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
    """Plot SpO2 on *ax*, widening the lower y-limit when desaturations dip below 85."""
    spo2_fs = spo2_data["fs"]
    window = spo2_data["data"][segment_start:segment_end]
    t = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start)
    ax.plot(t, window, color='black', label=spo2_data["name"])
    lowest = window.min()
    # Keep the fixed 85-100 range unless the signal actually drops below it.
    ax.set_ylim((lowest - 5, 100) if lowest < 85 else (85, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)
def score_mask2alpha(score_mask):
    """Map integer SA quality scores to plot alpha (opacity) values.

    Score 1 -> 0.9, 2 -> 0.6, 3 -> 0.1; every other value (including
    non-positive scores) maps to 0.0.
    """
    score_to_alpha = {1: 0.9, 2: 0.6, 3: 0.1}
    alpha_mask = np.zeros_like(score_mask, dtype=float)
    for score, alpha in score_to_alpha.items():
        alpha_mask[score_mask == score] = alpha
    return alpha_mask
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
    """Interactively review PSG channels against BCG-derived signals and masks.

    For each ``(start, end)`` second pair in ``segment_list`` the shared
    9-panel figure is cleared, redrawn, and shown (``plt.show()`` blocks per
    segment in interactive use).

    Args:
        psg_data: dict of PSG channel dicts ("data"/"fs"/...), must include
            "SpO2", "5_class", "Flow T", "Flow P", "Effort Tho", "Effort Abd",
            "HR".
        psg_label: per-second PSG event mask array.
        bcg_data: dict with "resp_signal" and "bcg_signal" channel dicts.
        event_mask: dict of per-second BCG/resp mask arrays; mutated in place
            (zeros become NaN, "SA_Score_Alpha" is added).
        segment_list: iterable of (segment_start, segment_end) second pairs.
    """
    # Convert zero entries of the resp/BCG masks to NaN so they are not drawn.
    # BUG FIX: the original tested ``mask.endswith("BCG_")``, which can never
    # match keys that *start* with "BCG_" (e.g. "BCG_LowAmp_Label"), so the
    # BCG masks were left unconverted.
    for mask in event_mask.keys():
        if mask.startswith("Resp_") or mask.startswith("BCG_"):
            event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
    event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
    event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
    fig, axes = create_psg_bcg_figure()
    for segment_start, segment_end in segment_list:
        print(f"Drawing segment: {segment_start} to {segment_end} seconds")
        for ax in axes:
            ax.cla()
        plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4])
        plt.show()
        print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")
def draw_resp_label(resp_data, resp_label, segment_list):
    """Interactively review the resp/bcg channels with their label masks.

    Args:
        resp_data: dict with "resp_signal" and "bcg_signal" channel dicts.
        resp_label: dict of per-second mask arrays; mutated in place (zeros
            of "Resp_*" masks become NaN, "Resp_Score_Alpha" is added).
        segment_list: iterable of (segment_start, segment_end) second pairs.
    """
    for mask in resp_label.keys():
        if mask.startswith("Resp_"):
            resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0)
    resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
    # BUG FIX: the original read the never-created key "Resp_Label_Alpha"
    # (guaranteed KeyError); apply the NaN conversion to the alpha mask just
    # built, mirroring draw_psg_bcg_label's handling of "SA_Score_Alpha".
    resp_label["Resp_Score_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Score_Alpha"], 0)
    fig, axes = create_resp_figure()
    for segment_start, segment_end in segment_list:
        for ax in axes:
            ax.cla()
        plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
                               resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
                               resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])
        plt.show()

View File

@ -18,15 +18,19 @@
# 高幅值连续体动规则标定与剔除 # 高幅值连续体动规则标定与剔除
# 手动标定不可用区间提剔除 # 手动标定不可用区间提剔除
""" """
import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil import shutil
import draw_tools import draw_tools
import utils import utils
import numpy as np import numpy as np
import signal_method import signal_method
import os import os
from matplotlib import pyplot as plt
os.environ['DISPLAY'] = "localhost:10.0" os.environ['DISPLAY'] = "localhost:10.0"
@ -48,56 +52,14 @@ def process_one_signal(samp_id, show=False):
save_samp_path = save_path / f"{samp_id}" save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True) save_samp_path.mkdir(parents=True, exist_ok=True)
signal_data_raw = utils.read_signal_txt(signal_path) signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True)
signal_length = len(signal_data_raw)
print(f"signal_length: {signal_length}")
signal_fs = int(signal_path.stem.split("_")[-1])
print(f"signal_fs: {signal_fs}")
signal_second = signal_length // signal_fs
print(f"signal_second: {signal_second}")
# 根据采样率进行截断 signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)
signal_data_raw = signal_data_raw[:signal_second * signal_fs]
# 滤波
# 50Hz陷波滤波器
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
print("Applying 50Hz notch filter...")
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
resp_fs = conf["resp"]["downsample_fs_1"]
resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
low_cut=conf["resp_filter"]["low_cut"],
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
sample_rate=resp_fs)
print("Begin plotting signal data...")
# fig = plt.figure(figsize=(12, 8))
# # 绘制三个图raw_data、resp_data_1、resp_data_2
# ax0 = fig.add_subplot(3, 1, 1)
# ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
# ax0.set_title('Raw Signal Data')
# ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
# ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange')
# ax1.set_title('Resp Data after Average Filtering')
# ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
# ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green')
# ax2.set_title('Resp Data after Butterworth Filtering')
# plt.tight_layout()
# plt.show()
bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
low_cut=conf["bcg_filter"]["low_cut"],
high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
sample_rate=signal_fs)
# 降采样 # 降采样
old_resp_fs = resp_fs old_resp_fs = resp_fs
resp_fs = conf["resp"]["downsample_fs_2"] resp_fs = conf["resp"]["downsample_fs_2"]
resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs) resp_data = utils.downsample_signal_fast(original_signal=resp_data, original_fs=old_resp_fs, target_fs=resp_fs)
bcg_fs = conf["bcg"]["downsample_fs"] bcg_fs = conf["bcg"]["downsample_fs"]
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs) bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
@ -233,7 +195,7 @@ def process_one_signal(samp_id, show=False):
save_path=save_samp_path / f"{samp_id}_Signal_Plots.png") save_path=save_samp_path / f"{samp_id}_Signal_Plots.png")
# 复制事件文件 到保存路径 # 复制事件文件 到保存路径
sa_label_save_name = f"{samp_id}" + label_path.name sa_label_save_name = f"{samp_id}_" + label_path.name
shutil.copyfile(label_path, save_samp_path / sa_label_save_name) shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
# 新建一个dataframe分别是秒数、SA标签SA质量标签禁用标签Resp低幅值标签Resp体动标签Resp幅值突变标签Bcg低幅值标签Bcg体动标签Bcg幅值突变标签 # 新建一个dataframe分别是秒数、SA标签SA质量标签禁用标签Resp低幅值标签Resp体动标签Resp幅值突变标签Bcg低幅值标签Bcg体动标签Bcg幅值突变标签
@ -247,10 +209,10 @@ def process_one_signal(samp_id, show=False):
dtype=int), dtype=int),
"Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, "Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second,
dtype=int), dtype=int),
"Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int), "BCG_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
"Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, "BCG_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second,
dtype=int), dtype=int),
"Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, "BCG_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second,
dtype=int) dtype=int)
} }
@ -259,13 +221,13 @@ def process_one_signal(samp_id, show=False):
if __name__ == '__main__': if __name__ == '__main__':
yaml_path = Path("../dataset_config/HYS_config.yaml") yaml_path = project_root_path / "dataset_config/HYS_config.yaml"
disable_df_path = Path("../排除区间.xlsx") disable_df_path = project_root_path / "排除区间.xlsx"
conf = utils.load_dataset_conf(yaml_path) conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"] select_ids = conf["select_ids"]
root_path = Path(conf["root_path"]) root_path = Path(conf["root_path"])
save_path = Path(conf["save_path"]) save_path = Path(conf["mask_save_path"])
print(f"select_ids: {select_ids}") print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}") print(f"root_path: {root_path}")
@ -276,9 +238,9 @@ if __name__ == '__main__':
all_samp_disable_df = utils.read_disable_excel(disable_df_path) all_samp_disable_df = utils.read_disable_excel(disable_df_path)
process_one_signal(select_ids[6], show=True) # process_one_signal(select_ids[6], show=True)
# #
# for samp_id in select_ids: for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}") print(f"Processing sample ID: {samp_id}")
# process_one_signal(samp_id, show=False) process_one_signal(samp_id, show=False)
# print(f"Finished processing sample ID: {samp_id}\n\n") print(f"Finished processing sample ID: {samp_id}\n\n")

View File

@ -1,277 +0,0 @@
"""
本脚本完成对呼研所数据的处理包含以下功能
1. 数据读取与预处理
从传入路径中进行数据和标签的读取并进行初步的预处理
预处理包括为数据进行滤波去噪等操作
2. 数据清洗与异常值处理
3. 输出清晰后的统计信息
4. 数据保存
将处理后的数据保存到指定路径便于后续使用
主要是保存切分后的数据位置和标签
5. 可视化
提供数据处理前后的可视化对比帮助理解数据变化
绘制多条可用性趋势图展示数据的可用区间体动区间低幅值区间等
todo: 使用mask 屏蔽无用区间
# 低幅值区间规则标定与剔除
# 高幅值连续体动规则标定与剔除
# 手动标定不可用区间提剔除
"""
from pathlib import Path
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
from matplotlib import pyplot as plt
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id, show=False):
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
if not signal_path:
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
signal_path = signal_path[0]
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
label_path = (label_root_path / f"{samp_id}").glob("SA Label_corrected.csv")
if not label_path:
raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
label_path = list(label_path)[0]
print(f"Processing Label_corrected file: {label_path}")
signal_data_raw = utils.read_signal_txt(signal_path)
signal_length = len(signal_data_raw)
print(f"signal_length: {signal_length}")
signal_fs = int(signal_path.stem.split("_")[-1])
print(f"signal_fs: {signal_fs}")
signal_second = signal_length // signal_fs
print(f"signal_second: {signal_second}")
# 根据采样率进行截断
signal_data_raw = signal_data_raw[:signal_second * signal_fs]
# 滤波
# 50Hz陷波滤波器
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
print("Applying 50Hz notch filter...")
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
resp_fs = conf["resp"]["downsample_fs_1"]
resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
low_cut=conf["resp_filter"]["low_cut"],
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
sample_rate=resp_fs)
print("Begin plotting signal data...")
# fig = plt.figure(figsize=(12, 8))
# # 绘制三个图raw_data、resp_data_1、resp_data_2
# ax0 = fig.add_subplot(3, 1, 1)
# ax0.plot(np.linspace(0, len(signal_data) // signal_fs, len(signal_data)), signal_data, color='blue')
# ax0.set_title('Raw Signal Data')
# ax1 = fig.add_subplot(3, 1, 2, sharex=ax0)
# ax1.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_1)), resp_data_1, color='orange')
# ax1.set_title('Resp Data after Average Filtering')
# ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
# ax2.plot(np.linspace(0, len(resp_data_1) // resp_fs, len(resp_data_2)), resp_data_2, color='green')
# ax2.set_title('Resp Data after Butterworth Filtering')
# plt.tight_layout()
# plt.show()
bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
low_cut=conf["bcg_filter"]["low_cut"],
high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
sample_rate=signal_fs)
# 降采样
old_resp_fs = resp_fs
resp_fs = conf["resp"]["downsample_fs_2"]
resp_data = utils.downsample_signal_fast(original_signal=resp_data_2, original_fs=old_resp_fs, target_fs=resp_fs)
bcg_fs = conf["bcg"]["downsample_fs"]
bcg_data = utils.downsample_signal_fast(original_signal=bcg_data, original_fs=signal_fs, target_fs=bcg_fs)
label_data = utils.read_label_csv(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=signal_second, event_df=label_data)
manual_disable_mask = utils.generate_disable_mask(signal_second=signal_second, disable_df=all_samp_disable_df[
all_samp_disable_df["id"] == samp_id])
print(f"disable_mask_shape: {manual_disable_mask.shape}, num_disable: {np.sum(manual_disable_mask == 0)}")
# 分析Resp的低幅值区间
resp_low_amp_conf = conf.get("resp_low_amp", None)
if resp_low_amp_conf is not None:
resp_low_amp_mask, resp_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=resp_data,
sampling_rate=resp_fs,
**resp_low_amp_conf
)
print(f"resp_low_amp_mask_shape: {resp_low_amp_mask.shape}, num_low_amp: {np.sum(resp_low_amp_mask == 1)}, count_low_amp_positions: {len(resp_low_amp_position_list)}")
else:
resp_low_amp_mask, resp_low_amp_position_list = None, None
print("resp_low_amp_mask is None")
# 分析Resp的高幅值伪迹区间
resp_movement_conf = conf.get("resp_movement", None)
if resp_movement_conf is not None:
raw_resp_movement_mask, resp_movement_mask, raw_resp_movement_position_list, resp_movement_position_list = signal_method.detect_movement(
signal_data=resp_data,
sampling_rate=resp_fs,
**resp_movement_conf
)
print(f"resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
else:
resp_movement_mask, resp_movement_position_list = None, None
print("resp_movement_mask is None")
resp_movement_revise_conf = conf.get("resp_movement_revise", None)
if resp_movement_mask is not None and resp_movement_revise_conf is not None:
resp_movement_mask, resp_movement_position_list = signal_method.movement_revise(
signal_data=resp_data,
movement_mask=resp_movement_mask,
movement_list=resp_movement_position_list,
sampling_rate=resp_fs,
**resp_movement_revise_conf,
verbose=False
)
print(f"After revise, resp_movement_mask_shape: {resp_movement_mask.shape}, num_movement: {np.sum(resp_movement_mask == 1)}, count_movement_positions: {len(resp_movement_position_list)}")
else:
print("resp_movement_mask revise is skipped")
# 分析Resp的幅值突变区间
resp_amp_change_conf = conf.get("resp_amp_change", None)
if resp_amp_change_conf is not None and resp_movement_mask is not None:
resp_amp_change_mask, resp_amp_change_list = signal_method.position_based_sleep_recognition_v3(
signal_data=resp_data,
movement_mask=resp_movement_mask,
movement_list=resp_movement_position_list,
sampling_rate=resp_fs,
**resp_amp_change_conf)
print(f"amp_change_mask_shape: {resp_amp_change_mask.shape}, num_amp_change: {np.sum(resp_amp_change_mask == 1)}, count_amp_change_positions: {len(resp_amp_change_list)}")
else:
resp_amp_change_mask = None
print("amp_change_mask is None")
# 分析Bcg的低幅值区间
bcg_low_amp_conf = conf.get("bcg_low_amp", None)
if bcg_low_amp_conf is not None:
bcg_low_amp_mask, bcg_low_amp_position_list = signal_method.detect_low_amplitude_signal(
signal_data=bcg_data,
sampling_rate=bcg_fs,
**bcg_low_amp_conf
)
print(f"bcg_low_amp_mask_shape: {bcg_low_amp_mask.shape}, num_low_amp: {np.sum(bcg_low_amp_mask == 1)}, count_low_amp_positions: {len(bcg_low_amp_position_list)}")
else:
bcg_low_amp_mask, bcg_low_amp_position_list = None, None
print("bcg_low_amp_mask is None")
# 分析Bcg的高幅值伪迹区间
bcg_movement_conf = conf.get("bcg_movement", None)
if bcg_movement_conf is not None:
raw_bcg_movement_mask, bcg_movement_mask, raw_bcg_movement_position_list, bcg_movement_position_list = signal_method.detect_movement(
signal_data=bcg_data,
sampling_rate=bcg_fs,
**bcg_movement_conf
)
print(f"bcg_movement_mask_shape: {bcg_movement_mask.shape}, num_movement: {np.sum(bcg_movement_mask == 1)}, count_movement_positions: {len(bcg_movement_position_list)}")
else:
bcg_movement_mask = None
print("bcg_movement_mask is None")
# 分析Bcg的幅值突变区间
if bcg_movement_mask is not None:
bcg_amp_change_mask, bcg_amp_change_list = signal_method.position_based_sleep_recognition_v2(
signal_data=bcg_data,
movement_mask=bcg_movement_mask,
sampling_rate=bcg_fs)
print(f"bcg_amp_change_mask_shape: {bcg_amp_change_mask.shape}, num_amp_change: {np.sum(bcg_amp_change_mask == 1)}, count_amp_change_positions: {len(bcg_amp_change_list)}")
else:
bcg_amp_change_mask = None
print("bcg_amp_change_mask is None")
# 如果signal_data采样率过进行降采样
if signal_fs == 1000:
signal_data = utils.downsample_signal_fast(original_signal=signal_data, original_fs=signal_fs, target_fs=100)
signal_data_raw = utils.downsample_signal_fast(original_signal=signal_data_raw, original_fs=signal_fs, target_fs=100)
signal_fs = 100
if show:
draw_tools.draw_signal_with_mask(samp_id=samp_id,
signal_data=signal_data,
signal_fs=signal_fs,
resp_data=resp_data,
resp_fs=resp_fs,
bcg_data=bcg_data,
bcg_fs=bcg_fs,
signal_disable_mask=manual_disable_mask,
resp_low_amp_mask=resp_low_amp_mask,
resp_movement_mask=resp_movement_mask,
resp_change_mask=resp_amp_change_mask,
resp_sa_mask=event_mask,
bcg_low_amp_mask=bcg_low_amp_mask,
bcg_movement_mask=bcg_movement_mask,
bcg_change_mask=bcg_amp_change_mask)
# 保存处理后的数据和标签
save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True)
# 复制事件文件 到保存路径
sa_label_save_name = f"{samp_id}" + label_path.name
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
# 新建一个dataframe分别是秒数、SA标签SA质量标签禁用标签Resp低幅值标签Resp体动标签Resp幅值突变标签Bcg低幅值标签Bcg体动标签Bcg幅值突变标签
save_dict = {
"Second": np.arange(signal_second),
"SA_Label": event_mask,
"SA_Score": score_mask,
"Disable_Label": manual_disable_mask,
"Resp_LowAmp_Label": resp_low_amp_mask if resp_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
"Resp_Movement_Label": resp_movement_mask if resp_movement_mask is not None else np.zeros(signal_second, dtype=int),
"Resp_AmpChange_Label": resp_amp_change_mask if resp_amp_change_mask is not None else np.zeros(signal_second, dtype=int),
"Bcg_LowAmp_Label": bcg_low_amp_mask if bcg_low_amp_mask is not None else np.zeros(signal_second, dtype=int),
"Bcg_Movement_Label": bcg_movement_mask if bcg_movement_mask is not None else np.zeros(signal_second, dtype=int),
"Bcg_AmpChange_Label": bcg_amp_change_mask if bcg_amp_change_mask is not None else np.zeros(signal_second, dtype=int)
}
mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
if __name__ == '__main__':
yaml_path = Path("../dataset_config/ZD5Y_config.yaml")
disable_df_path = Path("../排除区间.xlsx")
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
save_path = Path(conf["save_path"])
print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
label_root_path = root_path / "Label"
all_samp_disable_df = utils.read_disable_excel(disable_df_path)
process_one_signal(select_ids[1], show=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# process_one_signal(samp_id, show=False)
# print(f"Finished processing sample ID: {samp_id}\n\n")

View File

@ -2,3 +2,5 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3 from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
from .rule_base_event import movement_revise from .rule_base_event import movement_revise
from .time_metrics import calc_mav_by_slide_windows from .time_metrics import calc_mav_by_slide_windows
from .signal_process import signal_filter_split, rpeak2hr
from .normalize_method import normalize_resp_signal

View File

@ -0,0 +1,36 @@
import utils
import pandas as pd
import numpy as np
from scipy import signal
def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
    """Z-score normalize the respiration signal segment by segment.

    Each entry of ``enable_list`` is a (start_sec, end_sec) pair describing an
    amplitude-stable span. Mean/std for a span are estimated with body-movement
    seconds masked out, then applied to the *raw* signal of that span plus the
    gap up to the next span's start (so the movement gap between spans is
    scaled with the preceding span's statistics). Samples outside every
    normalized range stay NaN.

    :param resp_signal: 1-D respiration signal sampled at ``resp_fs`` Hz.
    :param resp_fs: sampling rate of ``resp_signal`` in samples per second.
    :param movement_mask: per-second 0/1 mask; 1 marks movement seconds that
        are excluded from the mean/std estimation.
        NOTE(review): assumed to span len(resp_signal) // resp_fs seconds — confirm.
    :param enable_list: list of (start_sec, end_sec) amplitude-stable spans.
    :return: float array, same length as ``resp_signal``; NaN where not normalized.
    :raises ValueError: if a span's standard deviation is zero.
    """
    # Explicit float buffer: np.zeros_like would inherit an integer dtype from
    # integer input, and assigning NaN into an int array raises ValueError.
    normalized_resp_signal = np.full(np.shape(resp_signal), np.nan, dtype=float)
    # Blank out movement seconds so nanmean/nanstd skip them.
    resp_signal_no_movement = np.asarray(resp_signal, dtype=float).copy()
    resp_signal_no_movement[np.repeat(np.asarray(movement_mask) == 1, resp_fs)] = np.nan
    for i in range(len(enable_list)):
        enable_start = enable_list[i][0] * resp_fs
        enable_end = enable_list[i][1] * resp_fs
        segment = resp_signal_no_movement[enable_start:enable_end]
        segment_mean = np.nanmean(segment)
        segment_std = np.nanstd(segment)
        if segment_std == 0:
            raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")
        # Normalize up to the start of the next enabled span, so the movement
        # region between spans is standardized with the current statistics.
        if i <= len(enable_list) - 2:
            enable_end = enable_list[i + 1][0] * resp_fs
        raw_segment = resp_signal[enable_start:enable_end]
        normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std
    return normalized_resp_signal

View File

@ -0,0 +1,62 @@
import numpy as np
import utils
def signal_filter_split(conf, signal_data_raw, signal_fs):
    """Filter the raw recording and split it into respiration and BCG bands.

    Pipeline:
      1. 50 Hz notch filter to suppress mains interference.
      2. Respiration branch: anti-alias low-pass, downsample to
         conf["resp"]["downsample_fs_1"], 20 s moving-average smoothing, then a
         Butterworth filter configured by conf["resp_filter"].
      3. BCG branch: Butterworth filter configured by conf["bcg_filter"] at the
         original sampling rate.

    :param conf: dataset configuration dict ("resp", "resp_filter", "bcg_filter").
    :param signal_data_raw: raw 1-D signal array.
    :param signal_fs: sampling rate of the raw signal in Hz.
    :return: (notched signal, respiration signal, respiration fs, BCG signal,
        BCG fs) — BCG fs equals the input ``signal_fs``.
    """
    print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
    # Respiration branch: low-pass as anti-aliasing before downsampling.
    resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
    resp_fs = conf["resp"]["downsample_fs_1"]
    resp_data_1 = utils.downsample_signal_fast(original_signal=resp_data_0, original_fs=signal_fs, target_fs=resp_fs)
    # 20 s moving average suppresses high-frequency residue before band filtering.
    resp_data_1 = utils.average_filter(raw_data=resp_data_1, sample_rate=resp_fs, window_size_sec=20)
    resp_data_2 = utils.butterworth(data=resp_data_1, _type=conf["resp_filter"]["filter_type"],
                                    low_cut=conf["resp_filter"]["low_cut"],
                                    high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
                                    sample_rate=resp_fs)
    # BCG branch keeps the original sampling rate.
    bcg_data = utils.butterworth(data=signal_data, _type=conf["bcg_filter"]["filter_type"],
                                 low_cut=conf["bcg_filter"]["low_cut"],
                                 high_cut=conf["bcg_filter"]["high_cut"], order=conf["bcg_filter"]["order"],
                                 sample_rate=signal_fs)
    return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs
def rpeak2hr(rpeak_indices, signal_length, fs=1000):
    """Convert R-peak sample indices into a sample-wise heart-rate signal.

    :param rpeak_indices: ascending R-peak positions (sample indices).
    :param signal_length: length of the output array in samples.
    :param fs: sampling rate of the index axis in Hz. Defaults to 1000, which
        reproduces the original hard-coded ``60 * 1000 / rri`` (RR in ms).
    :return: float array of length ``signal_length``; each RR interval is
        filled with its instantaneous rate clamped to [30, 120] bpm, samples
        before the first peak stay 0, the tail repeats the last computed rate.
    """
    hr_signal = np.zeros(signal_length)
    for i in range(1, len(rpeak_indices)):
        rri = rpeak_indices[i] - rpeak_indices[i - 1]
        if rri == 0:
            # Duplicate peak index: no interval to rate.
            continue
        hr = 60.0 * fs / rri  # heart rate in bpm
        # Clamp to the physiologically plausible range.
        hr_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = min(max(hr, 30), 120)
    # Extend the last computed rate past the final R peak.
    if len(rpeak_indices) > 1:
        hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
    return hr_signal

View File

@ -1,9 +1,11 @@
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import utils
from .event_map import N2Chn
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from .operation_tools import event_mask_2_list
# 尝试导入 Polars # 尝试导入 Polars
try: try:
import polars as pl import polars as pl
@ -13,15 +15,17 @@ except ImportError:
HAS_POLARS = False HAS_POLARS = False
def read_signal_txt(path: Union[str, Path]) -> np.ndarray: def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
""" """
Read a txt file and return the first column as a numpy array. Read a txt file and return the first column as a numpy array.
Args: Args:
path (str | Path): Path to the txt file. :param path:
:param verbose:
:param dtype:
Returns: Returns:
np.ndarray: The first column of the txt file as a numpy array. np.ndarray: The first column of the txt file as a numpy array.
""" """
path = Path(path) path = Path(path)
if not path.exists(): if not path.exists():
@ -29,10 +33,30 @@ def read_signal_txt(path: Union[str, Path]) -> np.ndarray:
if HAS_POLARS: if HAS_POLARS:
df = pl.read_csv(path, has_header=False, infer_schema_length=0) df = pl.read_csv(path, has_header=False, infer_schema_length=0)
return df[:, 0].to_numpy().astype(float) signal_data_raw = df[:, 0].to_numpy().astype(dtype)
else: else:
df = pd.read_csv(path, header=None, dtype=float) df = pd.read_csv(path, header=None, dtype=dtype)
return df.iloc[:, 0].to_numpy() signal_data_raw = df.iloc[:, 0].to_numpy()
signal_original_length = len(signal_data_raw)
signal_fs = int(path.stem.split("_")[-1])
if is_peak:
signal_second = None
signal_length = None
else:
signal_second = signal_original_length // signal_fs
# 根据采样率进行截断
signal_data_raw = signal_data_raw[:signal_second * signal_fs]
signal_length = len(signal_data_raw)
if verbose:
print(f"Signal file read from {path}")
print(f"signal_fs: {signal_fs}")
print(f"signal_original_length: {signal_original_length}")
print(f"signal_after_cut_off_length: {signal_length}")
print(f"signal_second: {signal_second}")
return signal_data_raw, signal_length, signal_fs, signal_second
def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame: def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
@ -172,3 +196,99 @@ def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
df["start"] = df["start"].astype(int) df["start"] = df["start"].astype(int)
df["end"] = df["end"].astype(int) df["end"] = df["end"].astype(int)
return df return df
def read_mask_execl(path: Union[str, Path]):
    """
    Read the processed per-second label file and build event masks and segment lists.

    Note: despite the historical name, the file is parsed as CSV (pd.read_csv),
    not as an Excel workbook.

    Args:
        path (str | Path): Path to the processed-labels CSV file.

    Returns:
        tuple:
            event_mask (dict[str, np.ndarray]): one per-second array per CSV column.
            event_list (dict[str, list]): segment lists derived from the *inverted*
                masks, i.e. the usable (0-valued) spans: "RespAmpChangeSegment",
                "BCGAmpChangeSegment" and "EnableSegment".

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    df = pd.read_csv(path)
    # One numpy array per label column, keyed by column name.
    event_mask = {key: np.array(values) for key, values in df.to_dict(orient="list").items()}
    # The label writer saves this column as "Bcg_AmpChange_Label", while older
    # files use "BCG_AmpChange_Label"; accept either spelling.
    bcg_amp_key = "BCG_AmpChange_Label" if "BCG_AmpChange_Label" in event_mask else "Bcg_AmpChange_Label"
    # Invert the masks so event_mask_2_list yields the usable (mask == 0) spans.
    event_list = {
        "RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
        "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask[bcg_amp_key]),
        "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
    }
    return event_mask, event_list
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
    """
    Read specific channels from a PSG recording directory.

    Each channel is stored as one txt file named "<channel name>*.txt", where
    the name is looked up via N2Chn.

    Args:
        path_str (Union[str, Path]): directory holding the per-channel txt files.
        channel_number (list[int]): channel ids to read (keys of N2Chn).

    Returns:
        dict: channel name -> {"name", "path", "data", "length", "fs", "second"}.

    Raises:
        FileNotFoundError: if the directory or a channel file is missing.
        NotADirectoryError: if ``path_str`` exists but is not a directory.
    """
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"PSG Dir not found: {path}")
    if not path.is_dir():
        raise NotADirectoryError(f"PSG Dir not found: {path}")
    channel_data = {}
    for ch_id in channel_number:
        ch_name = N2Chn[ch_id]
        ch_path = list(path.glob(f"{ch_name}*.txt"))
        if not ch_path:
            raise FileNotFoundError(f"PSG Channel file not found: {ch_path}")
        if len(ch_path) > 1:
            # Ambiguous match: read the first file but warn the operator.
            print(f"Warning!!! PSG Channel file more than one: {ch_path}")
        if ch_id == 8:
            # Sleep-stage channel: stored as stage strings, mapped to ints
            # via utils.Stage2N after reading.
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True)
            for stage_str, stage_number in utils.Stage2N.items():
                np.place(ch_signal, ch_signal == stage_str, stage_number)
            ch_signal = ch_signal.astype(int)
        elif ch_id == 1:
            # R-peak channel: integer sample indices with no fixed duration,
            # so length/second are None (is_peak=True).
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True)
        else:
            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True)
        channel_data[ch_name] = {
            "name": ch_name,
            "path": ch_path[0],
            "data": ch_signal,
            "length": ch_length,
            "fs": ch_fs,
            "second": ch_second
        }
    return channel_data
def read_psg_label(path: Union[str, Path], verbose=True):
    """Load a PSG event-label CSV and drop rows without an event type.

    :param path: path to the label CSV file (GBK-encoded).
    :param verbose: when True, report the file path and row count after reading.
    :return: DataFrame re-indexed 0..n-1 with empty "Event type" rows removed.
    :raises FileNotFoundError: if ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    # The label files contain Chinese text, hence the explicit GBK encoding.
    label_df = pd.read_csv(path, encoding="gbk")
    if verbose:
        print(f"Label file read from {path}, number of rows: {len(label_df)}")
    # Rows with no event type carry no usable annotation; discard them and
    # renumber the remaining rows.
    kept = label_df.dropna(subset=["Event type"], how='all')
    return kept.reset_index(drop=True)

View File

@ -1,7 +1,10 @@
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
from .operation_tools import merge_short_gaps, remove_short_durations from .operation_tools import merge_short_gaps, remove_short_durations
from .operation_tools import collect_values from .operation_tools import collect_values
from .operation_tools import save_process_label from .operation_tools import save_process_label
from .event_map import E2N from .operation_tools import none_to_nan_mask
from .signal_process import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel from .split_method import resp_split
from .HYS_FileReader import read_mask_execl, read_psg_channel
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel

View File

@ -5,3 +5,38 @@ E2N = {
"Obstructive apnea": 3, "Obstructive apnea": 3,
"Mixed apnea": 4 "Mixed apnea": 4
} }
# PSG channel id -> channel (file name prefix) mapping.
N2Chn = {
    1: "Rpeak",
    2: "ECG_Sync",
    3: "Effort Tho",
    4: "Effort Abd",
    5: "Flow P",
    6: "Flow T",
    7: "SpO2",
    8: "5_class"
}
# Sleep-stage string -> numeric code mapping.
Stage2N = {
    "W": 5,
    "N1": 3,
    "N2": 2,
    "N3": 1,
    "R": 4,
}
# Event codes and their plotting colors (index into ColorCycle):
# event_code  color            event
# 0           black            background
# 1           pink             hypopnea
# 2           blue             central apnea
# 3           red              obstructive apnea
# 4           gray (silver)    mixed apnea
# 5           green            oxygen desaturation
# 6           orange           large body movement
# 7           orange           small body movement
# 8           orange           deep breath
# 9           orange           pulse-like movement
# 10          orange           invalid segment
ColorCycle = ["black", "pink", "blue", "red", "silver", "green", "orange", "orange", "orange", "orange",
              "orange"]

View File

@ -198,16 +198,25 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
return disable_mask return disable_mask
def generate_event_mask(signal_second: int, event_df): def generate_event_mask(signal_second: int, event_df, use_correct=True):
event_mask = np.zeros(signal_second, dtype=int) event_mask = np.zeros(signal_second, dtype=int)
score_mask = np.zeros(signal_second, dtype=int) score_mask = np.zeros(signal_second, dtype=int)
if use_correct:
start_name = "correct_Start"
end_name = "correct_End"
event_type_name = "correct_EventsType"
else:
start_name = "Start"
end_name = "End"
event_type_name = "Event type"
# 剔除start = -1 的行 # 剔除start = -1 的行
event_df = event_df[event_df["correct_Start"] >= 0] event_df = event_df[event_df[start_name] >= 0]
for _, row in event_df.iterrows(): for _, row in event_df.iterrows():
start = row["correct_Start"] start = row[start_name]
end = row["correct_End"] + 1 end = row[end_name] + 1
event_mask[start:end] = E2N[row["correct_EventsType"]] event_mask[start:end] = E2N[row[event_type_name]]
score_mask[start:end] = row["score"] score_mask[start:end] = row["score"]
return event_mask, score_mask return event_mask, score_mask
@ -243,3 +252,12 @@ def collect_values(arr: np.ndarray, index: int, step: int, limit: int, mask=None
def save_process_label(save_path: Path, save_dict: dict): def save_process_label(save_path: Path, save_dict: dict):
save_df = pd.DataFrame(save_dict) save_df = pd.DataFrame(save_dict)
save_df.to_csv(save_path, index=False) save_df.to_csv(save_path, index=False)
def none_to_nan_mask(mask, ref):
    """Return ``mask`` with zeros replaced by NaN, or an all-NaN array if ``mask`` is None.

    Turning zeros into NaN makes inactive mask regions disappear from plots.

    :param mask: optional mask array; 0 entries become NaN, others are kept.
    :param ref: reference array providing the shape for the all-NaN fallback.
    :return: float array shaped like ``ref`` (mask is None) or like ``mask``.
    """
    if mask is None:
        # np.full_like without dtype would inherit an integer dtype from an
        # int-typed ref and then fail to hold NaN, so force a float result.
        return np.full_like(ref, np.nan, dtype=float)
    # np.where promotes the result to float, so NaN is representable.
    return np.where(mask == 0, np.nan, mask)

27
utils/split_method.py Normal file
View File

@ -0,0 +1,27 @@
def resp_split(dataset_config, event_mask, event_list):
    """Split enabled spans into sliding analysis windows.

    Windows of ``window_sec`` seconds are taken every ``stride_sec`` seconds
    inside each enabled segment. Per the original design note, if the tail of
    a segment left uncovered by the sliding windows is at least half a stride
    long, one extra window ending exactly at the segment end is appended;
    shorter tails are dropped.

    :param dataset_config: dict with "window_sec" and "stride_sec" (seconds).
    :param event_mask: unused here; kept for interface symmetry with other
        split helpers.
    :param event_list: dict with "EnableSegment" — list of (start_sec,
        end_sec) spans eligible for windowing.
    :return: list of (start_sec, end_sec) window tuples.
    """
    enable_list = event_list["EnableSegment"]
    window_sec = dataset_config["window_sec"]
    stride_sec = dataset_config["stride_sec"]
    segment_list = []
    for enable_start, enable_end in enable_list:
        current_start = enable_start
        while current_start + window_sec <= enable_end:
            segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # Tail handling. The original compared (enable_end - current_start)
        # against window_sec at a point where the loop guarantees it is
        # smaller, so the extra window was dead code. Measure the genuinely
        # uncovered tail (past the last emitted window) instead.
        if enable_end - enable_start >= window_sec:
            covered_end = current_start - stride_sec + window_sec
            tail = enable_end - covered_end
            if tail >= stride_sec / 2:
                candidate = (enable_end - window_sec, enable_end)
                # Avoid duplicating the last window when the span tiles exactly.
                if not segment_list or segment_list[-1] != candidate:
                    segment_list.append(candidate)
    return segment_list