Optimize the data-processing module, add PSG signal plotting, and refactor several functions to improve readability and maintainability

This commit is contained in:
marques 2026-01-19 14:27:26 +08:00
parent d09ffecf70
commit 097c9cbf0b
15 changed files with 1228 additions and 83 deletions

View File

@ -0,0 +1,395 @@
import multiprocessing
import sys
from pathlib import Path
import os
import numpy as np
from utils import N2Chn
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import signal_method
import draw_tools
import shutil
import gc
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
total_seconds = min(
psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
)
for i in N2Chn.values():
if i == "Rpeak":
continue
length = int(total_seconds * psg_data[i]["fs"])
psg_data[i]["data"] = psg_data[i]["data"][:length]
psg_data[i]["length"] = length
psg_data[i]["second"] = total_seconds
psg_data["HR"] = {
"name": "HR",
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
psg_data["Rpeak"]["fs"]),
"fs": psg_data["ECG_Sync"]["fs"],
"length": psg_data["ECG_Sync"]["length"],
"second": psg_data["ECG_Sync"]["second"]
}
# Preprocessing and filtering
tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])
rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
if verbose:
print(f"mask_excel_path: {mask_excel_path}")
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
enable_list = [[0, psg_data["Effort Tho"]["second"]]]
normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)
# Resample everything to a 100 Hz sampling rate
target_fs = 100
normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)
# Truncate everything to a common length
min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal), len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal), len(rri))
min_length = min_length - min_length % target_fs  # keep a whole number of seconds
normalized_tho_signal = normalized_tho_signal[:min_length]
normalized_abd_signal = normalized_abd_signal[:min_length]
normalized_flowp_signal = normalized_flowp_signal[:min_length]
normalized_flowt_signal = normalized_flowt_signal[:min_length]
spo2_data_filt = spo2_data_filt[:min_length]
normalized_effort_signal = normalized_effort_signal[:min_length]
rri = rri[:min_length]
tho_second = min_length / target_fs
for i in event_mask.keys():
event_mask[i] = event_mask[i][:int(tho_second)]
spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
spo2_fs=target_fs,
max_fill_duration=30,
min_gap_duration=10)
draw_tools.draw_psg_signal(
samp_id=samp_id,
tho_signal=normalized_tho_signal,
abd_signal=normalized_abd_signal,
flowp_signal=normalized_flowp_signal,
flowt_signal=normalized_flowt_signal,
spo2_signal=spo2_data_filt,
effort_signal=normalized_effort_signal,
rri_signal=rri,
fs=target_fs,
event_mask=event_mask["SA_Label"],
save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
show=show
)
draw_tools.draw_psg_signal(
samp_id=samp_id,
tho_signal=normalized_tho_signal,
abd_signal=normalized_abd_signal,
flowp_signal=normalized_flowp_signal,
flowt_signal=normalized_flowt_signal,
spo2_signal=spo2_data_filt_fill,
effort_signal=normalized_effort_signal,
rri_signal=rri,
fs=target_fs,
event_mask=event_mask["SA_Label"],
save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
show=show
)
spo2_disable_mask = spo2_disable_mask[::target_fs]
min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
print(f"Warning: Data length mismatch! Truncating to {min_len}.")
event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]
event_list = {
"EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
"DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95)
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
if verbose:
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
# Copy the mask file into the processed Labels folder
save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
shutil.copyfile(mask_excel_path, save_mask_excel_path)
# Copy SA Label_corrected.csv into the processed Labels folder
sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
if sa_label_corrected_path.exists():
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
else:
if verbose:
print(f"Warning: {sa_label_corrected_path} does not exist.")
# Save the processed signals and the extracted segment list
save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
# Update psg_data to hold the processed signals
# Spaces in the channel names are replaced with underscores
psg_data = {
"Effort Tho": {
"name": "Effort_Tho",
"data": normalized_tho_signal,
"fs": target_fs,
"length": len(normalized_tho_signal),
"second": len(normalized_tho_signal) / target_fs
},
"Effort Abd": {
"name": "Effort_Abd",
"data": normalized_abd_signal,
"fs": target_fs,
"length": len(normalized_abd_signal),
"second": len(normalized_abd_signal) / target_fs
},
"Effort": {
"name": "Effort",
"data": normalized_effort_signal,
"fs": target_fs,
"length": len(normalized_effort_signal),
"second": len(normalized_effort_signal) / target_fs
},
"Flow P": {
"name": "Flow_P",
"data": normalized_flowp_signal,
"fs": target_fs,
"length": len(normalized_flowp_signal),
"second": len(normalized_flowp_signal) / target_fs
},
"Flow T": {
"name": "Flow_T",
"data": normalized_flowt_signal,
"fs": target_fs,
"length": len(normalized_flowt_signal),
"second": len(normalized_flowt_signal) / target_fs
},
"SpO2": {
"name": "SpO2",
"data": spo2_data_filt_fill,
"fs": target_fs,
"length": len(spo2_data_filt_fill),
"second": len(spo2_data_filt_fill) / target_fs
},
"HR": {
"name": "HR",
"data": psg_data["HR"]["data"],
"fs": psg_data["HR"]["fs"],
"length": psg_data["HR"]["length"],
"second": psg_data["HR"]["second"]
},
"RRI": {
"name": "RRI",
"data": rri,
"fs": target_fs,
"length": len(rri),
"second": len(rri) / target_fs
},
"5_class": {
"name": "Stage",
"data": psg_data["5_class"]["data"],
"fs": psg_data["5_class"]["fs"],
"length": psg_data["5_class"]["length"],
"second": psg_data["5_class"]["second"]
}
}
np.savez_compressed(save_signal_path, **psg_data)
np.savez_compressed(save_segment_path,
segment_list=segment_list,
disable_segment_list=disable_segment_list)
if verbose:
print(f"Saved processed signals to: {save_signal_path}")
print(f"Saved segment list to: {save_segment_path}")
if draw_segment:
total_len = len(segment_list) + len(disable_segment_list)
if verbose:
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
draw_tools.draw_psg_label(
psg_data=psg_data,
psg_label=event_mask["SA_Label"],
segment_list=segment_list,
save_path=visual_path / f"{samp_id}" / "enable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
draw_tools.draw_psg_label(
psg_data=psg_data,
psg_label=event_mask["SA_Label"],
segment_list=disable_segment_list,
save_path=visual_path / f"{samp_id}" / "disable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
# Explicitly delete large objects
try:
del psg_data
del normalized_tho_signal, normalized_abd_signal
del normalized_flowp_signal, normalized_flowt_signal
del normalized_effort_signal
del spo2_data_filt, spo2_data_filt_fill
del rri
del event_mask, event_list
del segment_list, disable_segment_list
except NameError:
pass
# Force garbage collection
gc.collect()
def multiprocess_entry(_progress, task_id, _id):
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def multiprocess_with_tqdm(args_list, n_processes):
from concurrent.futures import ProcessPoolExecutor
from rich import progress
with progress.Progress(
"[progress.description]{task.description}",
progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
progress.MofNCompleteColumn(),
progress.TimeRemainingColumn(),
progress.TimeElapsedColumn(),
refresh_per_second=1, # bit slower updates
transient=False
) as progress:
futures = []
with multiprocessing.Manager() as manager:
_progress = manager.dict()
overall_progress_task = progress.add_task("[green]All jobs progress:")
with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
for i_args in range(len(args_list)):
args = args_list[i_args]
task_id = progress.add_task(f"task {i_args}", visible=True)
futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
# monitor the progress:
while (n_finished := sum(future.done() for future in futures)) < len(futures):
progress.update(
overall_progress_task, completed=n_finished, total=len(futures)
)
for task_id, update_data in _progress.items():
desc = update_data.get("desc", "")
# update the progress bar for this task:
progress.update(
task_id,
completed=update_data.get("progress", 0),
total=update_data.get("total", 0),
description=desc
)
# raise any errors:
for future in futures:
future.result()
def multiprocess_with_pool(args_list, n_processes):
"""使用Pool每个worker处理固定数量任务后重启"""
from multiprocessing import Pool
# maxtasksperchild: restart each worker after it has handled this many tasks to release memory
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
results = []
for samp_id in args_list:
result = pool.apply_async(
build_HYS_dataset_segment,
args=(samp_id, False, True, False, None, None)
)
results.append(result)
# Wait for all tasks to finish
for i, result in enumerate(results):
try:
result.get()
print(f"Completed {i + 1}/{len(args_list)}")
except Exception as e:
print(f"Error processing task {i}: {e}")
pool.close()
pool.join()
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
save_segment_list_path = save_path / "Segments_List"
save_segment_list_path.mkdir(parents=True, exist_ok=True)
save_processed_label_path = save_path / "Labels"
save_processed_label_path.mkdir(parents=True, exist_ok=True)
# print(f"select_ids: {select_ids}")
# print(f"root_path: {root_path}")
# print(f"save_path: {save_path}")
# print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
print(select_ids)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
multiprocess_with_pool(args_list=select_ids, n_processes=8)
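Note on the switch to `multiprocess_with_pool`: `maxtasksperchild=2` makes each worker process exit and respawn after two samples, so per-sample memory growth cannot accumulate across the whole run. A self-contained sketch of the pattern (the `work` function and IDs below are placeholders, not part of this repo):

import os
from multiprocessing import Pool

def work(samp_id):
    # Stand-in for build_HYS_dataset_segment: allocate a large buffer,
    # process it, and return a summary.
    buf = [0.0] * 1_000_000
    return samp_id, len(buf), os.getpid()

if __name__ == "__main__":
    # Each worker is torn down and replaced after 2 tasks,
    # so leaked references cannot pile up over the run.
    with Pool(processes=2, maxtasksperchild=2) as pool:
        for samp_id, n, pid in pool.imap_unordered(work, range(8)):
            print(f"sample {samp_id}: {n} values in worker {pid}")

The printed PIDs change every two tasks, confirming the worker restarts.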

View File

@ -33,7 +33,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
# Downsample signal_data if its sampling rate is too high
@ -123,7 +123,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
psg_data["HR"] = {
"name": "HR",
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]),
"fs": psg_data["ECG_Sync"]["fs"],
"length": psg_data["ECG_Sync"]["length"],
"second": psg_data["ECG_Sync"]["second"]
@ -136,28 +136,28 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
if verbose:
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
draw_tools.draw_psg_bcg_label(psg_data=psg_data,
psg_label=psg_event_mask,
bcg_data=bcg_data,
event_mask=event_mask,
segment_list=segment_list,
save_path=visual_path / f"{samp_id}" / "enable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
draw_tools.draw_psg_bcg_label(
psg_data=psg_data,
psg_label=psg_event_mask,
bcg_data=bcg_data,
event_mask=event_mask,
segment_list=disable_segment_list,
save_path=visual_path / f"{samp_id}" / "disable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
# draw_tools.draw_psg_bcg_label(psg_data=psg_data,
# psg_label=psg_event_mask,
# bcg_data=bcg_data,
# event_mask=event_mask,
# segment_list=segment_list,
# save_path=visual_path / f"{samp_id}" / "enable",
# verbose=verbose,
# multi_p=multi_p,
# multi_task_id=multi_task_id
# )
#
# draw_tools.draw_psg_bcg_label(
# psg_data=psg_data,
# psg_label=psg_event_mask,
# bcg_data=bcg_data,
# event_mask=event_mask,
# segment_list=disable_segment_list,
# save_path=visual_path / f"{samp_id}" / "disable",
# verbose=verbose,
# multi_p=multi_p,
# multi_task_id=multi_task_id
# )
@ -241,11 +241,12 @@ if __name__ == '__main__':
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
print(select_ids)
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
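The `rpeak2hr` signature change in this file fixes a unit bug: `rri` is measured in samples, so converting to bpm needs the actual ECG sampling rate rather than an implicit 1000 Hz. A quick check of the arithmetic with example numbers:

ecg_fs = 200                      # example sampling rate
rri_samples = 160                 # 0.8 s between beats at 200 Hz
print(60 * 1000 / rri_samples)    # old formula: 375.0 bpm (wrong unless fs == 1000)
print(60 * ecg_fs / rri_samples)  # fixed formula: 75.0 bpm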

View File

@ -0,0 +1,129 @@
select_ids:
- 54
- 88
- 220
- 221
- 229
- 282
- 286
- 541
- 579
- 582
- 670
- 671
- 683
- 684
- 735
- 933
- 935
- 950
- 952
- 960
- 962
- 967
- 1302
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization
effort:
downsample_fs: 10
effort_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
flow:
downsample_fs: 10
flow_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
#ecg:
# downsample_fs: 100
#
#ecg_filter:
# filter_type: bandpass
# low_cut: 0.5
# high_cut: 40
# order: 5
#resp:
# downsample_fs_1: None
# downsample_fs_2: 10
#
#resp_filter:
# filter_type: bandpass
# low_cut: 0.05
# high_cut: 0.5
# order: 3
#
#resp_low_amp:
# window_size_sec: 30
# stride_sec:
# amplitude_threshold: 3
# merge_gap_sec: 60
# min_duration_sec: 60
#
#resp_movement:
# window_size_sec: 20
# stride_sec: 1
# std_median_multiplier: 4
# compare_intervals_sec:
# - 60
# - 120
## - 180
# interval_multiplier: 3
# merge_gap_sec: 30
# min_duration_sec: 1
#
#resp_movement_revise:
# up_interval_multiplier: 3
# down_interval_multiplier: 2
# compare_intervals_sec: 30
# merge_gap_sec: 10
# min_duration_sec: 1
#
#resp_amp_change:
# mav_calc_window_sec: 4
# threshold_amplitude: 0.25
# threshold_energy: 0.4
#
#
#bcg:
# downsample_fs: 100
#
#bcg_filter:
# filter_type: bandpass
# low_cut: 1
# high_cut: 10
# order: 10
#
#bcg_low_amp:
# window_size_sec: 1
# stride_sec:
# amplitude_threshold: 8
# merge_gap_sec: 20
# min_duration_sec: 3
#
#
#bcg_movement:
# window_size_sec: 2
# stride_sec:
# merge_gap_sec: 20
# min_duration_sec: 4
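For reference, this config is consumed by `utils.load_dataset_conf` in the scripts above. That helper is not part of this diff; assuming it is a thin PyYAML wrapper, a minimal equivalent looks like:

from pathlib import Path
import yaml

def load_dataset_conf(yaml_path):
    # Hypothetical minimal stand-in for utils.load_dataset_conf.
    with open(yaml_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

conf = load_dataset_conf(Path("dataset_config/HYS_PSG_config.yaml"))
print(conf["select_ids"][:3])                # [54, 88, 220]
print(conf["dataset_config"]["window_sec"])  # 180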

View File

@ -1,2 +1,2 @@
from .draw_statics import draw_signal_with_mask
from .draw_label import draw_psg_bcg_label, draw_resp_label
from .draw_statics import draw_signal_with_mask, draw_psg_signal
from .draw_label import draw_psg_bcg_label, draw_psg_label

View File

@ -7,10 +7,10 @@ import seaborn as sns
import numpy as np
from tqdm.rich import tqdm
import utils
import gc
# Added a with_prediction parameter
psg_chn_name2ax = {
psg_bcg_chn_name2ax = {
"SpO2": 0,
"Flow T": 1,
"Flow P": 2,
@ -24,6 +24,19 @@ psg_chn_name2ax = {
"bcg_twinx": 10,
}
psg_chn_name2ax = {
"SpO2": 0,
"Flow T": 1,
"Flow P": 2,
"Effort Tho": 3,
"Effort Abd": 4,
"Effort": 5,
"HR": 6,
"RRI": 7,
"Stage": 8,
}
resp_chn_name2ax = {
"resp": 0,
"bcg": 1,
@ -39,6 +52,54 @@ def create_psg_bcg_figure():
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[psg_bcg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100))
axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Flow T"]].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Flow P"]].grid(True)
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True)
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True)
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["HR"]].grid(True)
axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["resp"]].grid(True)
axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx())
axes[psg_bcg_chn_name2ax["bcg"]].grid(True)
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_bcg_chn_name2ax["bcg"]].twinx())
axes[psg_bcg_chn_name2ax["Stage"]].grid(True)
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes
def create_psg_figure():
fig = plt.figure(figsize=(12, 8), dpi=200)
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = []
for i in range(9):
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[psg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
@ -60,24 +121,21 @@ def create_psg_bcg_figure():
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort"]].grid(True)
axes[psg_chn_name2ax["Effort"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["HR"]].grid(True)
axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["resp"]].grid(True)
axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_chn_name2ax["resp"]].twinx())
axes[psg_chn_name2ax["RRI"]].grid(True)
axes[psg_chn_name2ax["RRI"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["bcg"]].grid(True)
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
axes[psg_chn_name2ax["Stage"]].grid(True)
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes
def create_resp_figure():
fig = plt.figure(figsize=(12, 6), dpi=100)
gs = GridSpec(2, 1, height_ratios=[3, 2])
@ -150,8 +208,8 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
stage_signal = stage_data["data"]
stage_fs = stage_data["fs"]
time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start)
ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"])
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * stage_fs)
ax.plot(time_axis, stage_signal[segment_start * stage_fs:segment_end * stage_fs], color='black', label=stage_data["name"])
ax.set_ylim(0, 6)
ax.set_yticks([1, 2, 3, 4, 5])
ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
@ -162,11 +220,11 @@ def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
spo2_signal = spo2_data["data"]
spo2_fs = spo2_data["fs"]
time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start)
ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"])
time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * spo2_fs)
ax.plot(time_axis, spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs], color='black', label=spo2_data["name"])
if spo2_signal[segment_start:segment_end].min() < 85:
ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100))
if spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() < 85:
ax.set_ylim((spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() - 5, 100))
else:
ax.set_ylim((85, 100))
ax.set_ylabel("SpO2 (%)")
@ -197,6 +255,56 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
# event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
fig, axes = create_psg_bcg_figure()
for i, (segment_start, segment_end) in enumerate(segment_list):
for ax in axes:
ax.cla()
plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]])
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
if multi_p is not None:
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
plt.close(fig)
plt.close('all')
gc.collect()
def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True,
multi_p=None, multi_task_id=None):
if save_path is not None:
save_path.mkdir(parents=True, exist_ok=True)
if multi_p is None:
# Print the data length of every channel in psg_data
for i in range(len(psg_data.keys())):
chn_name = list(psg_data.keys())[i]
print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}")
# Length of psg_label
print(f"psg_label length: {len(psg_label)}")
fig, axes = create_psg_figure()
for i, (segment_start, segment_end) in enumerate(segment_list):
for ax in axes:
ax.cla()
@ -211,14 +319,10 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
ax2=axes[psg_chn_name2ax["resp_twinx"]])
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
ax2=axes[psg_chn_name2ax["bcg_twinx"]])
plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end)
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
@ -226,23 +330,8 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
if multi_p is not None:
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
plt.close(fig)
plt.close('all')
gc.collect()
def draw_resp_label(resp_data, resp_label, segment_list):
for mask in resp_label.keys():
if mask.startswith("Resp_"):
resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0)
# resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
# resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0)
fig, axes = create_resp_figure()
for segment_start, segment_end in segment_list:
for ax in axes:
ax.cla()
plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])
plt.show()
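The time-axis changes in this file move the plotting helpers to a consistent convention: `segment_start`/`segment_end` are in seconds, slices are scaled by `fs`, and the x axis stays in seconds. A minimal illustration of that convention (synthetic data, standalone matplotlib):

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for this demo
import matplotlib.pyplot as plt

fs, segment_start, segment_end = 100, 10, 20   # bounds in seconds
sig = 96 + 0.3 * np.random.randn(60 * fs)      # 60 s of fake SpO2 at 100 Hz
time_axis = np.linspace(segment_start, segment_end,
                        (segment_end - segment_start) * fs)
fig, ax = plt.subplots()
ax.plot(time_axis, sig[segment_start * fs:segment_end * fs], color="black")
ax.set_xlabel("Time (s)")
fig.savefig("segment_demo.png")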

View File

@ -247,6 +247,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
ax1.set_title(f'Sample {samp_id} - Respiration Component')
ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
ax2.set_ylabel('Amplitude')
@ -300,5 +302,70 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
plt.show()
def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, spo2_signal, effort_signal, rri_signal, event_mask, fs,
show=False, save_path=None):
sa_mask = event_mask.repeat(fs)
fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True)
time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal))
axs[0].plot(time_axis, tho_signal, label='THO', color='black')
axs[0].set_title(f'Sample {samp_id} - PSG Signal Data')
axs[0].set_ylabel('THO Amplitude')
axs[0].legend(loc='upper right')
ax0_twin = axs[0].twinx()
ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax0_twin.autoscale(enable=False, axis='y', tight=True)
ax0_twin.set_ylim((-4, 5))
axs[1].plot(time_axis, abd_signal, label='ABD', color='black')
axs[1].set_ylabel('ABD Amplitude')
axs[1].legend(loc='upper right')
ax1_twin = axs[1].twinx()
ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax1_twin.autoscale(enable=False, axis='y', tight=True)
ax1_twin.set_ylim((-4, 5))
axs[2].plot(time_axis, effort_signal, label='EFFO', color='black')
axs[2].set_ylabel('EFFO Amplitude')
axs[2].legend(loc='upper right')
ax2_twin = axs[2].twinx()
ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax2_twin.autoscale(enable=False, axis='y', tight=True)
ax2_twin.set_ylim((-4, 5))
axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black')
axs[3].set_ylabel('FLOWP Amplitude')
axs[3].legend(loc='upper right')
ax3_twin = axs[3].twinx()
ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax3_twin.autoscale(enable=False, axis='y', tight=True)
ax3_twin.set_ylim((-4, 5))
axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black')
axs[4].set_ylabel('FLOWT Amplitude')
axs[4].legend(loc='upper right')
ax4_twin = axs[4].twinx()
ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax4_twin.autoscale(enable=False, axis='y', tight=True)
ax4_twin.set_ylim((-4, 5))
axs[5].plot(time_axis, rri_signal, label='RRI', color='black')
axs[5].set_ylabel('RRI Amplitude')
axs[5].legend(loc='upper right')
axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black')
axs[6].set_ylabel('SPO2 Amplitude')
axs[6].set_xlabel('Time (s)')
axs[6].legend(loc='upper right')
if save_path is not None:
plt.savefig(save_path, dpi=300)
if show:
plt.show()
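`draw_psg_signal` expects one mask value per second (it calls `event_mask.repeat(fs)` internally) and seven signals already resampled to a common `fs`. A smoke test with entirely synthetic data:

import numpy as np
from draw_tools import draw_psg_signal  # exported in draw_tools/__init__.py above

fs, seconds = 100, 60
t = np.arange(seconds * fs) / fs
resp = np.sin(2 * np.pi * 0.3 * t)                 # ~18 breaths per minute
spo2 = 96 + 0.3 * np.random.randn(seconds * fs)
rri = 0.9 + 0.05 * np.sin(2 * np.pi * 0.1 * t)
mask = np.zeros(seconds, dtype=int)
mask[20:35] = 2                                    # fabricated apnea label

draw_psg_signal(samp_id=0, tho_signal=resp, abd_signal=resp,
                flowp_signal=resp, flowt_signal=resp, spo2_signal=spo2,
                effort_signal=resp, rri_signal=rri, event_mask=mask,
                fs=fs, show=False, save_path="demo_psg.png")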

View File

@ -0,0 +1,183 @@
"""
This script processes the HYS (呼研所) data and provides the following functionality:
1. Data reading and preprocessing
Read data and labels from the given paths and apply initial preprocessing,
including filtering and denoising.
2. Data cleaning and outlier handling
3. Output statistics after cleaning
4. Data saving
Save the processed data to the given path for later use,
mainly the positions and labels of the extracted segments.
5. Visualization
Provide before/after visual comparisons to make the changes easy to inspect,
and plot usability trend curves showing usable intervals, body-movement
intervals, low-amplitude intervals, and so on.
# Rule-based labeling and removal of low-amplitude intervals
# Rule-based labeling and removal of sustained high-amplitude body movement
# Removal of manually labeled unusable intervals
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id, show=False):
tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))
if not tho_signal_path:
raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
tho_signal_path = tho_signal_path[0]
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
if not abd_signal_path:
raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
abd_signal_path = abd_signal_path[0]
print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
if not flowp_signal_path:
raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
flowp_signal_path = flowp_signal_path[0]
print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
if not flowt_signal_path:
raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
flowt_signal_path = flowt_signal_path[0]
print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
if not spo2_signal_path:
raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
spo2_signal_path = spo2_signal_path[0]
print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
if not stage_signal_path:
raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
stage_signal_path = stage_signal_path[0]
print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
label_path = list((label_root_path / f"{samp_id}").glob("SA Label_Sync.csv"))
if not label_path:
raise FileNotFoundError(f"SA Label_Sync file not found for sample ID: {samp_id}")
label_path = label_path[0]
print(f"Processing SA Label_Sync file: {label_path}")
#
# Save the processed data and labels
save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True)
# Read the signal data
tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
# abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
# flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
# flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
# spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
#
# Preprocessing and filtering
# tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs)
# abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs)
# flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs)
# flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs)
# Downsampling
# old_tho_fs = tho_fs
# tho_fs = conf["effort"]["downsample_fs"]
# tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs)
# old_abd_fs = abd_fs
# abd_fs = conf["effort"]["downsample_fs"]
# abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs)
# old_flowp_fs = flowp_fs
# flowp_fs = conf["effort"]["downsample_fs"]
# flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs)
# old_flowt_fs = flowt_fs
# flowt_fs = conf["effort"]["downsample_fs"]
# flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs)
# SpO2 is not downsampled
# spo2_data_filt = spo2_data_raw
# spo2_fs = spo2_fs
label_data = utils.read_raw_psg_label(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
# score_mask: 1 where event_mask > 0, 0 elsewhere
score_mask = np.where(event_mask > 0, 1, 0)
# Generate unusable intervals from the sleep stages
wake_mask = utils.get_wake_mask(stage_data_raw)
# Remove wake intervals shorter than 60 s
wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60)
# Merge wake intervals separated by gaps shorter than 60 s
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
disable_label = wake_mask
disable_label = disable_label[:tho_second]
# Copy the event file to the save path
sa_label_save_name = f"{samp_id}_" + label_path.name
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
#
# Build a new DataFrame with the second index and the SA labels
save_dict = {
"Second": np.arange(tho_second),
"SA_Label": event_mask,
"SA_Score": score_mask,
"Disable_Label": disable_label,
"Resp_LowAmp_Label": np.zeros_like(event_mask),
"Resp_Movement_Label": np.zeros_like(event_mask),
"Resp_AmpChange_Label": np.zeros_like(event_mask),
"BCG_LowAmp_Label": np.zeros_like(event_mask),
"BCG_Movement_Label": np.zeros_like(event_mask),
"BCG_AmpChange_Label": np.zeros_like(event_mask)
}
mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
# disable_df_path = project_root_path / "排除区间.xlsx"
#
conf = utils.load_dataset_conf(yaml_path)
root_path = Path(conf["root_path"])
save_path = Path(conf["mask_save_path"])
select_ids = conf["select_ids"]
#
print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
#
org_signal_root_path = root_path / "PSG_Aligned"
label_root_path = root_path / "PSG_Aligned"
#
# all_samp_disable_df = utils.read_disable_excel(disable_df_path)
#
# process_one_signal(select_ids[0], show=True)
# #
for samp_id in select_ids:
print(f"Processing sample ID: {samp_id}")
process_one_signal(samp_id, show=False)
print(f"Finished processing sample ID: {samp_id}\n\n")
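`utils.save_process_label` is not shown in this commit; given the `save_dict` layout above (one row per second, one column per mask), a minimal stand-in would be:

import numpy as np
import pandas as pd

def save_process_label(save_path, save_dict):
    # Hypothetical stand-in for utils.save_process_label:
    # one row per second, one column per label mask.
    pd.DataFrame(save_dict).to_csv(save_path, index=False)

demo = {
    "Second": np.arange(5),
    "SA_Label": np.array([0, 0, 2, 2, 0]),
    "Disable_Label": np.array([1, 0, 0, 0, 0]),
}
save_process_label("demo_Processed_Labels.csv", demo)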

View File

@ -2,5 +2,5 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
from .rule_base_event import movement_revise
from .time_metrics import calc_mav_by_slide_windows
from .signal_process import signal_filter_split, rpeak2hr
from .normalize_method import normalize_resp_signal
from .signal_process import signal_filter_split, rpeak2hr, psg_effort_filter, rpeak2rri_interpolation
from .normalize_method import normalize_resp_signal_by_segment

View File

@ -3,7 +3,7 @@ import pandas as pd
import numpy as np
from scipy import signal
def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
# Z-score each segment of the respiratory signal according to its amplitude-change intervals
normalized_resp_signal = np.zeros_like(resp_signal)
# Fill everything with NaN first
@ -33,4 +33,20 @@ def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enabl
raw_segment = resp_signal[enable_start:enable_end]
normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std
# If the enable intervals do not start at 0, normalize the leading part as well
if enable_list[0][0] > 0:
new_enable_start = 0
enable_start = enable_list[0][0] * resp_fs
enable_end = enable_list[0][1] * resp_fs
segment = resp_signal_no_movement[enable_start:enable_end]
segment_mean = np.nanmean(segment)
segment_std = np.nanstd(segment)
if segment_std == 0:
raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")
raw_segment = resp_signal[new_enable_start:enable_start]
normalized_resp_signal[new_enable_start:enable_start] = (raw_segment - segment_mean) / segment_std
return normalized_resp_signal
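The renamed `normalize_resp_signal_by_segment` applies a z-score per enable interval rather than over the whole night, so amplitude drift between intervals does not skew the scaling. A simplified sketch of the idea (no movement mask, synthetic data):

import numpy as np

def zscore_by_segment(sig, fs, enable_list):
    # Per-interval z-score: the core idea of normalize_resp_signal_by_segment.
    out = np.full(len(sig), np.nan)
    for start_sec, end_sec in enable_list:
        seg = sig[start_sec * fs:end_sec * fs]
        std = seg.std()
        if std == 0:
            raise ValueError("segment_std is zero")
        out[start_sec * fs:end_sec * fs] = (seg - seg.mean()) / std
    return out

x = np.concatenate([5.0 * np.random.randn(1000), 0.5 * np.random.randn(1000)])
y = zscore_by_segment(x, fs=100, enable_list=[[0, 10], [10, 20]])
print(y[:1000].std(), y[1000:].std())  # both ~1.0 after per-segment scaling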

View File

@ -1,4 +1,5 @@
import numpy as np
from scipy.interpolate import interp1d
import utils
@ -44,14 +45,24 @@ def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs
def psg_effort_filter(conf, effort_data_raw, effort_fs):
# Filtering
effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"],
low_cut=conf["effort_filter"]["low_cut"],
high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
sample_rate=effort_fs)
# Moving average
effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20)
return effort_data_raw, effort_data_2, effort_fs
def rpeak2hr(rpeak_indices, signal_length):
def rpeak2hr(rpeak_indices, signal_length, ecg_fs):
hr_signal = np.zeros(signal_length)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
if rri == 0:
continue
hr = 60 * 1000 / rri  # heart rate in bpm
hr = 60 * ecg_fs / rri  # heart rate in bpm
if hr > 120:
hr = 120
elif hr < 30:
@ -62,3 +73,35 @@ def rpeak2hr(rpeak_indices, signal_length):
hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
return hr_signal
def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs):
rri_signal = np.zeros(signal_length)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri
# Fill the RRI values after the last R peak
if len(rpeak_indices) > 1:
rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]]
# Scan for outliers
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs:
rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0
return rri_signal
def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs):
r_peak_time = np.asarray(rpeak_indices) / ecg_fs
rri = np.diff(r_peak_time)
t_rri = r_peak_time[1:]
mask = (rri > 0.3) & (rri < 2.0)
rri_clean = rri[mask]
t_rri_clean = t_rri[mask]
t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1/rri_fs)
f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate")
rri_uniform = f(t_uniform)
return rri_uniform, rri_fs
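Unlike `rpeak2rri_repeat`, which holds each RR interval constant between beats, `rpeak2rri_interpolation` resamples the irregular beat-to-beat series onto a uniform grid, which is what lets RRI be aligned with the other 100 Hz channels. A worked example on synthetic R peaks:

import numpy as np
from scipy.interpolate import interp1d

ecg_fs, rri_fs = 1000, 100
# ~20 beats, 0.8 s apart with slight jitter, expressed as sample indices.
rpeaks = np.cumsum(0.8 + 0.02 * np.random.randn(20)) * ecg_fs
t = rpeaks / ecg_fs
rri, t_rri = np.diff(t), t[1:]
keep = (rri > 0.3) & (rri < 2.0)   # same plausibility gate as above
t_uniform = np.arange(t_rri[keep][0], t_rri[keep][-1], 1 / rri_fs)
rri_uniform = interp1d(t_rri[keep], rri[keep], kind="linear",
                       fill_value="extrapolate")(t_uniform)
print(rri_uniform.mean())  # ~0.8 s, now sampled uniformly at 100 Hz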

View File

@ -178,6 +178,55 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
return df
def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame:
"""
Read a CSV file and return it as a pandas DataFrame.
Args:
path (str | Path): Path to the CSV file.
verbose (bool): Whether to print reading statistics.
Returns:
pd.DataFrame: The content of the CSV file as a pandas DataFrame.
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
# Read directly with pandas; the file contains Chinese text, so specify the encoding
df = pd.read_csv(path, encoding="gbk")
if verbose:
print(f"Label file read from {path}, number of rows: {len(df)}")
num_psg_events = np.sum(df["Event type"].notna())
# Event statistics
num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
num_psg_csa = np.sum(df["Event type"] == "Central apnea")
num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")
if verbose:
print("Event Statistics:")
# Formatted output with fixed widths: total / from PSG / manual / deleted / unlabeled
print(f"Type {'Total':^8s}")
print(f"Hyp: {num_psg_hyp:^8d}")
print(f"CSA: {num_psg_csa:^8d}")
print(f"OSA: {num_psg_osa:^8d}")
print(f"MSA: {num_psg_msa:^8d}")
print(f"All: {num_psg_events:^8d}")
df["Start"] = df["Start"].astype(int)
df["End"] = df["End"].astype(int)
return df
def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
"""
Read an Excel file and return it as a pandas DataFrame.
@ -225,6 +274,15 @@ def read_mask_execl(path: Union[str, Path]):
return event_mask, event_list
def read_psg_mask_excel(path: Union[str, Path]):
df = pd.read_csv(path)
event_mask = df.to_dict(orient="list")
for key in event_mask:
event_mask[key] = np.array(event_mask[key])
return event_mask
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
"""

View File

@ -1,10 +1,12 @@
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label
from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label, read_raw_psg_label, read_psg_mask_excel
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
from .operation_tools import merge_short_gaps, remove_short_durations
from .operation_tools import collect_values
from .operation_tools import save_process_label
from .operation_tools import none_to_nan_mask
from .operation_tools import get_wake_mask
from .operation_tools import fill_spo2_anomaly
from .split_method import resp_split
from .HYS_FileReader import read_mask_execl, read_psg_channel
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel, adjust_sample_rate

View File

@ -20,6 +20,7 @@ def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=10
raise ValueError("Please choose a type of fliter")
@timing_decorator()
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
if _type == "lowpass": # 低通滤波处理
b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
@ -89,6 +90,52 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1
return downsampled_signal
def upsample_signal(original_signal, original_fs, target_fs):
"""
Upsample a signal.
Parameters:
original_signal : array-like, the original signal
original_fs : float, original sampling rate (Hz)
target_fs : float, target sampling rate (Hz)
Returns:
upsampled_signal : array-like, the upsampled signal
"""
if not isinstance(original_signal, np.ndarray):
original_signal = np.array(original_signal)
if target_fs <= original_fs:
raise ValueError("目标采样率必须大于原始采样率")
if target_fs <= 0 or original_fs <= 0:
raise ValueError("采样率必须为正数")
upsample_factor = target_fs / original_fs
num_output_samples = int(len(original_signal) * upsample_factor)
upsampled_signal = signal.resample(original_signal, num_output_samples)
return upsampled_signal
def adjust_sample_rate(signal_data, original_fs, target_fs):
"""
Automatically choose upsampling or downsampling based on the original and target sampling rates.
Parameters:
signal_data : array-like, the original signal
original_fs : float, original sampling rate (Hz)
target_fs : float, target sampling rate (Hz)
Returns:
adjusted_signal : array-like, the signal at the adjusted sampling rate
"""
if original_fs == target_fs:
return signal_data
elif original_fs > target_fs:
return downsample_signal_fast(signal_data, original_fs, target_fs)
else:
return upsample_signal(signal_data, original_fs, target_fs)
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):
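`adjust_sample_rate` dispatches between the repo's `downsample_signal_fast` and the new `upsample_signal`. As a self-contained check of the dispatch idea, here is an equivalent using `scipy.signal.resample_poly` as a stand-in for both directions (not the repo's implementation):

import numpy as np
from math import gcd
from scipy import signal

def adjust_sample_rate_demo(x, original_fs, target_fs):
    # Stand-in with the same pass-through/resample dispatch.
    if original_fs == target_fs:
        return x
    g = gcd(int(target_fs), int(original_fs))
    return signal.resample_poly(x, int(target_fs) // g, int(original_fs) // g)

x = np.sin(2 * np.pi * 0.3 * np.arange(0, 10, 1 / 25))  # 10 s at 25 Hz
y = adjust_sample_rate_demo(x, 25, 100)                 # upsample to 100 Hz
print(len(x), len(y))                                   # 250 -> 1000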

View File

@ -6,6 +6,7 @@ import pandas as pd
from matplotlib import pyplot as plt
import yaml
from numpy.ma.core import append
from scipy.interpolate import PchipInterpolator
from utils.event_map import E2N
plt.rcParams['font.sans-serif'] = ['SimHei']  # needed to render Chinese labels correctly
@ -198,9 +199,12 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
return disable_mask
def generate_event_mask(signal_second: int, event_df, use_correct=True):
def generate_event_mask(signal_second: int, event_df, use_correct=True, with_score=True):
event_mask = np.zeros(signal_second, dtype=int)
if with_score:
score_mask = np.zeros(signal_second, dtype=int)
else:
score_mask = None
if use_correct:
start_name = "correct_Start"
end_name = "correct_End"
@ -217,6 +221,7 @@ def generate_event_mask(signal_second: int, event_df, use_correct=True):
start = row[start_name]
end = row[end_name] + 1
event_mask[start:end] = E2N[row[event_type_name]]
if with_score:
score_mask[start:end] = row["score"]
return event_mask, score_mask
@ -261,3 +266,115 @@ def none_to_nan_mask(mask, ref):
# Replace 0s in the mask with NaN and keep everything else
mask = np.where(mask == 0, np.nan, mask)
return mask
def get_wake_mask(sleep_stage_mask):
# Treat N1, N2, N3, and REM as sleep (0) and everything else as wake (1)
# The input is an array of strings such as 'W', 'N1', 'N2', 'N3', 'R'
wake_mask = np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)
return wake_mask
def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
anomaly = np.zeros(len(spo2), dtype=bool)
# Physiological range
anomaly |= (spo2 < 50) | (spo2 > 100)
# Sudden jumps
diff = np.abs(np.diff(spo2, prepend=spo2[0]))
anomaly |= diff > diff_thresh
# NaN
anomaly |= np.isnan(spo2)
return anomaly
def merge_close_anomalies(anomaly, fs, min_gap_duration):
min_gap = int(min_gap_duration * fs)
merged = anomaly.copy()
i = 0
n = len(anomaly)
while i < n:
if not anomaly[i]:
i += 1
continue
# Current anomaly segment
start = i
while i < n and anomaly[i]:
i += 1
end = i
# Look ahead across the gap
j = end
while j < n and not anomaly[j]:
j += 1
if j < n and (j - end) < min_gap:
merged[end:j] = True
return merged
def fill_spo2_anomaly(
spo2_data,
spo2_fs,
max_fill_duration,
min_gap_duration,
):
spo2 = spo2_data.astype(float).copy()
n = len(spo2)
anomaly = detect_spo2_anomaly(spo2, spo2_fs)
anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)
max_len = int(max_fill_duration * spo2_fs)
valid_mask = ~anomaly
i = 0
while i < n:
if not anomaly[i]:
i += 1
continue
start = i
while i < n and anomaly[i]:
i += 1
end = i
seg_len = end - start
# Anomaly segment too long to fill
if seg_len > max_len:
spo2[start:end] = np.nan
valid_mask[start:end] = False
continue
has_left = start > 0 and valid_mask[start - 1]
has_right = end < n and valid_mask[end]
# Anomaly at the start: one-sided fill
if not has_left and has_right:
spo2[start:end] = spo2[end]
continue
# Anomaly at the end: one-sided fill
if has_left and not has_right:
spo2[start:end] = spo2[start - 1]
continue
# Valid samples on both sides: PCHIP interpolation
if has_left and has_right:
x = np.array([start - 1, end])
y = np.array([spo2[start - 1], spo2[end]])
interp = PchipInterpolator(x, y)
spo2[start:end] = interp(np.arange(start, end))
continue
# No valid samples on either side (extreme case)
spo2[start:end] = np.nan
valid_mask[start:end] = False
return spo2, valid_mask
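Behavior check for `fill_spo2_anomaly`: dropouts shorter than `max_fill_duration` are bridged (one-sided fill or PCHIP), longer ones become NaN and are flagged in the returned validity mask, which is what feeds `spo2_disable_mask` upstream. A synthetic run, assuming it is executed inside this repo:

import numpy as np
from utils import fill_spo2_anomaly  # exported in utils/__init__.py above

fs = 1
spo2 = np.full(120, 96.0)
spo2[40:50] = 0.0    # 10 s dropout: short enough to fill
spo2[80:] = 0.0      # 40 s dropout: exceeds max_fill_duration

filled, valid = fill_spo2_anomaly(spo2, spo2_fs=fs,
                                  max_fill_duration=30, min_gap_duration=10)
print(filled[40:50])     # bridged back toward ~96
print(valid[80:].any())  # False: too long, left as NaN and marked invalid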

View File

@ -54,5 +54,3 @@ def resp_split(dataset_config, event_mask, event_list, verbose=False):
return segment_list, disable_segment_list