Optimize the data-processing module, add PSG signal plotting, and refactor several functions for readability and maintainability

marques 2026-01-19 14:27:26 +08:00
parent d09ffecf70
commit 097c9cbf0b
15 changed files with 1228 additions and 83 deletions

View File

@@ -0,0 +1,395 @@
import multiprocessing
import sys
from pathlib import Path
import os
import numpy as np
from utils import N2Chn
os.environ['DISPLAY'] = "localhost:10.0"
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import utils
import signal_method
import draw_tools
import shutil
import gc
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
total_seconds = min(
psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
)
for i in N2Chn.values():
if i == "Rpeak":
continue
length = int(total_seconds * psg_data[i]["fs"])
psg_data[i]["data"] = psg_data[i]["data"][:length]
psg_data[i]["length"] = length
psg_data[i]["second"] = total_seconds
psg_data["HR"] = {
"name": "HR",
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
psg_data["Rpeak"]["fs"]),
"fs": psg_data["ECG_Sync"]["fs"],
"length": psg_data["ECG_Sync"]["length"],
"second": psg_data["ECG_Sync"]["second"]
}
# Preprocessing and filtering
tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])
rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
if verbose:
print(f"mask_excel_path: {mask_excel_path}")
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
enable_list = [[0, psg_data["Effort Tho"]["second"]]]
normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)
# Resample everything to a 100 Hz sampling rate
target_fs = 100
normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)
# Truncate all signals to the same length
min_length = min(len(normalized_tho_signal), len(normalized_abd_signal), len(normalized_flowp_signal),
len(normalized_flowt_signal), len(spo2_data_filt), len(normalized_effort_signal), len(rri))
min_length = min_length - min_length % target_fs # keep a whole number of seconds
normalized_tho_signal = normalized_tho_signal[:min_length]
normalized_abd_signal = normalized_abd_signal[:min_length]
normalized_flowp_signal = normalized_flowp_signal[:min_length]
normalized_flowt_signal = normalized_flowt_signal[:min_length]
spo2_data_filt = spo2_data_filt[:min_length]
normalized_effort_signal = normalized_effort_signal[:min_length]
rri = rri[:min_length]
tho_second = min_length / target_fs
for i in event_mask.keys():
event_mask[i] = event_mask[i][:int(tho_second)]
spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
spo2_fs=target_fs,
max_fill_duration=30,
min_gap_duration=10,)
draw_tools.draw_psg_signal(
samp_id=samp_id,
tho_signal=normalized_tho_signal,
abd_signal=normalized_abd_signal,
flowp_signal=normalized_flowp_signal,
flowt_signal=normalized_flowt_signal,
spo2_signal=spo2_data_filt,
effort_signal=normalized_effort_signal,
rri_signal=rri,
fs=target_fs,
event_mask=event_mask["SA_Label"],
save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
show=show
)
draw_tools.draw_psg_signal(
samp_id=samp_id,
tho_signal=normalized_tho_signal,
abd_signal=normalized_abd_signal,
flowp_signal=normalized_flowp_signal,
flowt_signal=normalized_flowt_signal,
spo2_signal=spo2_data_filt_fill,
effort_signal=normalized_effort_signal,
rri_signal=rri,
fs=target_fs,
event_mask=event_mask["SA_Label"],
save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
show=show
)
spo2_disable_mask = spo2_disable_mask[::target_fs]
min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
print(f"Warning: Data length mismatch! Truncating to {min_len}.")
event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]
event_list = {
"EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
"DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}
spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95)
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
if verbose:
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
# Copy the mask file into the processed Labels folder
save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
shutil.copyfile(mask_excel_path, save_mask_excel_path)
# Copy SA Label_corrected.csv into the processed Labels folder
sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
if sa_label_corrected_path.exists():
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
else:
if verbose:
print(f"Warning: {sa_label_corrected_path} does not exist.")
# Save the processed signals and the extracted segment lists
save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"
# Rebuild psg_data with the processed signals
# (spaces in channel names are replaced with underscores in the "name" field)
psg_data = {
"Effort Tho": {
"name": "Effort_Tho",
"data": normalized_tho_signal,
"fs": target_fs,
"length": len(normalized_tho_signal),
"second": len(normalized_tho_signal) / target_fs
},
"Effort Abd": {
"name": "Effort_Abd",
"data": normalized_abd_signal,
"fs": target_fs,
"length": len(normalized_abd_signal),
"second": len(normalized_abd_signal) / target_fs
},
"Effort": {
"name": "Effort",
"data": normalized_effort_signal,
"fs": target_fs,
"length": len(normalized_effort_signal),
"second": len(normalized_effort_signal) / target_fs
},
"Flow P": {
"name": "Flow_P",
"data": normalized_flowp_signal,
"fs": target_fs,
"length": len(normalized_flowp_signal),
"second": len(normalized_flowp_signal) / target_fs
},
"Flow T": {
"name": "Flow_T",
"data": normalized_flowt_signal,
"fs": target_fs,
"length": len(normalized_flowt_signal),
"second": len(normalized_flowt_signal) / target_fs
},
"SpO2": {
"name": "SpO2",
"data": spo2_data_filt_fill,
"fs": target_fs,
"length": len(spo2_data_filt_fill),
"second": len(spo2_data_filt_fill) / target_fs
},
"HR": {
"name": "HR",
"data": psg_data["HR"]["data"],
"fs": psg_data["HR"]["fs"],
"length": psg_data["HR"]["length"],
"second": psg_data["HR"]["second"]
},
"RRI": {
"name": "RRI",
"data": rri,
"fs": target_fs,
"length": len(rri),
"second": len(rri) / target_fs
},
"5_class": {
"name": "Stage",
"data": psg_data["5_class"]["data"],
"fs": psg_data["5_class"]["fs"],
"length": psg_data["5_class"]["length"],
"second": psg_data["5_class"]["second"]
}
}
np.savez_compressed(save_signal_path, **psg_data)
np.savez_compressed(save_segment_path,
segment_list=segment_list,
disable_segment_list=disable_segment_list)
if verbose:
print(f"Saved processed signals to: {save_signal_path}")
print(f"Saved segment list to: {save_segment_path}")
if draw_segment:
total_len = len(segment_list) + len(disable_segment_list)
if verbose:
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
draw_tools.draw_psg_label(
psg_data=psg_data,
psg_label=event_mask["SA_Label"],
segment_list=segment_list,
save_path=visual_path / f"{samp_id}" / "enable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
draw_tools.draw_psg_label(
psg_data=psg_data,
psg_label=event_mask["SA_Label"],
segment_list=disable_segment_list,
save_path=visual_path / f"{samp_id}" / "disable",
verbose=verbose,
multi_p=multi_p,
multi_task_id=multi_task_id
)
# Explicitly delete large objects
try:
del psg_data
del normalized_tho_signal, normalized_abd_signal
del normalized_flowp_signal, normalized_flowt_signal
del normalized_effort_signal
del spo2_data_filt, spo2_data_filt_fill
del rri
del event_mask, event_list
del segment_list, disable_segment_list
except NameError:
pass
# Force garbage collection
gc.collect()
def multiprocess_entry(_progress, task_id, _id):
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
def multiprocess_with_tqdm(args_list, n_processes):
from concurrent.futures import ProcessPoolExecutor
from rich import progress
with progress.Progress(
"[progress.description]{task.description}",
progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
progress.MofNCompleteColumn(),
progress.TimeRemainingColumn(),
progress.TimeElapsedColumn(),
refresh_per_second=1, # bit slower updates
transient=False
) as progress:
futures = []
with multiprocessing.Manager() as manager:
_progress = manager.dict()
overall_progress_task = progress.add_task("[green]All jobs progress:")
with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
for i_args in range(len(args_list)):
args = args_list[i_args]
task_id = progress.add_task(f"task {i_args}", visible=True)
futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
# monitor the progress:
while (n_finished := sum([future.done() for future in futures])) < len(
futures
):
progress.update(
overall_progress_task, completed=n_finished, total=len(futures)
)
for task_id, update_data in _progress.items():
desc = update_data.get("desc", "")
# update the progress bar for this task:
progress.update(
task_id,
completed=update_data.get("progress", 0),
total=update_data.get("total", 0),
description=desc
)
# raise any errors:
for future in futures:
future.result()
def multiprocess_with_pool(args_list, n_processes):
"""使用Pool每个worker处理固定数量任务后重启"""
from multiprocessing import Pool
# maxtasksperchild: restart each worker after this many tasks to release leaked memory
with Pool(processes=n_processes, maxtasksperchild=2) as pool:
results = []
for samp_id in args_list:
result = pool.apply_async(
build_HYS_dataset_segment,
args=(samp_id, False, True, False, None, None)
)
results.append(result)
# Wait for all tasks to finish
for i, result in enumerate(results):
try:
result.get()
print(f"Completed {i + 1}/{len(args_list)}")
except Exception as e:
print(f"Error processing task {i}: {e}")
pool.close()
pool.join()
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
conf = utils.load_dataset_conf(yaml_path)
select_ids = conf["select_ids"]
root_path = Path(conf["root_path"])
mask_path = Path(conf["mask_save_path"])
save_path = Path(conf["dataset_config"]["dataset_save_path"])
visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
dataset_config = conf["dataset_config"]
visual_path.mkdir(parents=True, exist_ok=True)
save_processed_signal_path = save_path / "Signals"
save_processed_signal_path.mkdir(parents=True, exist_ok=True)
save_segment_list_path = save_path / "Segments_List"
save_segment_list_path.mkdir(parents=True, exist_ok=True)
save_processed_label_path = save_path / "Labels"
save_processed_label_path.mkdir(parents=True, exist_ok=True)
# print(f"select_ids: {select_ids}")
# print(f"root_path: {root_path}")
# print(f"save_path: {save_path}")
# print(f"visual_path: {visual_path}")
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
print(select_ids)
# build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
# multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
multiprocess_with_pool(args_list=select_ids, n_processes=8)
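A note on the multiprocess_with_pool variant above: maxtasksperchild is what keeps memory bounded, since each worker process is recycled after a fixed number of tasks and anything it leaked (matplotlib figures, large arrays) is reclaimed by the OS. A minimal self-contained sketch of the same pattern, with a toy worker standing in for build_HYS_dataset_segment:

import os
from multiprocessing import Pool

def _toy_worker(task_id):
    # Stand-in for build_HYS_dataset_segment; any memory this task leaks
    # dies with the worker once the worker hits maxtasksperchild.
    return task_id, os.getpid()

if __name__ == '__main__':
    with Pool(processes=4, maxtasksperchild=2) as pool:
        # Each worker PID should serve at most two tasks before a fresh PID appears.
        for task_id, pid in pool.map(_toy_worker, range(10)):
            print(f"task {task_id} ran in worker {pid}")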

View File

@@ -33,7 +33,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
- normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
+ normalized_resp_signal = signal_method.normalize_resp_signal_by_segment(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
# Downsample if the sampling rate of signal_data is too high
@@ -123,7 +123,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
psg_data["HR"] = {
"name": "HR",
- "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
+ "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"], psg_data["Rpeak"]["fs"]),
"fs": psg_data["ECG_Sync"]["fs"],
"length": psg_data["ECG_Sync"]["length"],
"second": psg_data["ECG_Sync"]["second"]
@@ -136,28 +136,28 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=T
if verbose:
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
- draw_tools.draw_psg_bcg_label(psg_data=psg_data,
- psg_label=psg_event_mask,
- bcg_data=bcg_data,
- event_mask=event_mask,
- segment_list=segment_list,
- save_path=visual_path / f"{samp_id}" / "enable",
- verbose=verbose,
- multi_p=multi_p,
- multi_task_id=multi_task_id
- )
- draw_tools.draw_psg_bcg_label(
- psg_data=psg_data,
- psg_label=psg_event_mask,
- bcg_data=bcg_data,
- event_mask=event_mask,
- segment_list=disable_segment_list,
- save_path=visual_path / f"{samp_id}" / "disable",
- verbose=verbose,
- multi_p=multi_p,
- multi_task_id=multi_task_id
- )
+ # draw_tools.draw_psg_bcg_label(psg_data=psg_data,
+ # psg_label=psg_event_mask,
+ # bcg_data=bcg_data,
+ # event_mask=event_mask,
+ # segment_list=segment_list,
+ # save_path=visual_path / f"{samp_id}" / "enable",
+ # verbose=verbose,
+ # multi_p=multi_p,
+ # multi_task_id=multi_task_id
+ # )
+ #
+ # draw_tools.draw_psg_bcg_label(
+ # psg_data=psg_data,
+ # psg_label=psg_event_mask,
+ # bcg_data=bcg_data,
+ # event_mask=event_mask,
+ # segment_list=disable_segment_list,
+ # save_path=visual_path / f"{samp_id}" / "disable",
+ # verbose=verbose,
+ # multi_p=multi_p,
+ # multi_task_id=multi_task_id
+ # )
@@ -241,11 +241,12 @@ if __name__ == '__main__':
org_signal_root_path = root_path / "OrgBCG_Aligned"
psg_signal_root_path = root_path / "PSG_Aligned"
+ print(select_ids)
- build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
+ # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)
# for samp_id in select_ids:
# print(f"Processing sample ID: {samp_id}")
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
- # multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
+ multiprocess_with_tqdm(args_list=select_ids, n_processes=16)

View File

@@ -0,0 +1,129 @@
select_ids:
- 54
- 88
- 220
- 221
- 229
- 282
- 286
- 541
- 579
- 582
- 670
- 671
- 683
- 684
- 735
- 933
- 935
- 950
- 952
- 960
- 962
- 967
- 1302
root_path: /mnt/disk_wd/marques_dataset/DataCombine2023/HYS
mask_save_path: /mnt/disk_code/marques/dataprepare/output/HYS_PSG
dataset_config:
window_sec: 180
stride_sec: 60
dataset_save_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset
dataset_visual_path: /mnt/disk_wd/marques_dataset/SA_dataset/HYS_PSG_dataset/visualization
effort:
downsample_fs: 10
effort_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
flow:
downsample_fs: 10
flow_filter:
filter_type: bandpass
low_cut: 0.05
high_cut: 0.5
order: 3
#ecg:
# downsample_fs: 100
#
#ecg_filter:
# filter_type: bandpass
# low_cut: 0.5
# high_cut: 40
# order: 5
#resp:
# downsample_fs_1: None
# downsample_fs_2: 10
#
#resp_filter:
# filter_type: bandpass
# low_cut: 0.05
# high_cut: 0.5
# order: 3
#
#resp_low_amp:
# window_size_sec: 30
# stride_sec:
# amplitude_threshold: 3
# merge_gap_sec: 60
# min_duration_sec: 60
#
#resp_movement:
# window_size_sec: 20
# stride_sec: 1
# std_median_multiplier: 4
# compare_intervals_sec:
# - 60
# - 120
## - 180
# interval_multiplier: 3
# merge_gap_sec: 30
# min_duration_sec: 1
#
#resp_movement_revise:
# up_interval_multiplier: 3
# down_interval_multiplier: 2
# compare_intervals_sec: 30
# merge_gap_sec: 10
# min_duration_sec: 1
#
#resp_amp_change:
# mav_calc_window_sec: 4
# threshold_amplitude: 0.25
# threshold_energy: 0.4
#
#
#bcg:
# downsample_fs: 100
#
#bcg_filter:
# filter_type: bandpass
# low_cut: 1
# high_cut: 10
# order: 10
#
#bcg_low_amp:
# window_size_sec: 1
# stride_sec:
# amplitude_threshold: 8
# merge_gap_sec: 20
# min_duration_sec: 3
#
#
#bcg_movement:
# window_size_sec: 2
# stride_sec:
# merge_gap_sec: 20
# min_duration_sec: 4
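The effort_filter and flow_filter blocks above are consumed by signal_method.psg_effort_filter via utils.bessel. A minimal sketch of how such a block maps onto a scipy Bessel bandpass; the assumed sampling rate and the use of zero-phase filtfilt are illustrative choices, since the actual filtering call inside utils.bessel is not shown in this diff:

import numpy as np
from scipy import signal

filter_conf = {"filter_type": "bandpass", "low_cut": 0.05, "high_cut": 0.5, "order": 3}
fs = 100  # assumed sampling rate for illustration

# Normalize the cutoffs to the Nyquist frequency, as scipy expects.
wn = [filter_conf["low_cut"] / (fs * 0.5), filter_conf["high_cut"] / (fs * 0.5)]
b, a = signal.bessel(filter_conf["order"], wn, btype=filter_conf["filter_type"], norm='mag')

x = np.random.randn(60 * fs)   # one minute of dummy effort signal
y = signal.filtfilt(b, a, x)   # zero-phase filtering (illustrative choice)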

View File

@@ -1,2 +1,2 @@
- from .draw_statics import draw_signal_with_mask
+ from .draw_statics import draw_signal_with_mask, draw_psg_signal
- from .draw_label import draw_psg_bcg_label, draw_resp_label
+ from .draw_label import draw_psg_bcg_label, draw_psg_label

View File

@@ -7,10 +7,10 @@ import seaborn as sns
import numpy as np
from tqdm.rich import tqdm
import utils
+ import gc
# Add a with_prediction parameter
- psg_chn_name2ax = {
+ psg_bcg_chn_name2ax = {
"SpO2": 0,
"Flow T": 1,
"Flow P": 2,
@@ -24,6 +24,19 @@ psg_chn_name2ax = {
"bcg_twinx": 10,
}
psg_chn_name2ax = {
"SpO2": 0,
"Flow T": 1,
"Flow P": 2,
"Effort Tho": 3,
"Effort Abd": 4,
"Effort": 5,
"HR": 6,
"RRI": 7,
"Stage": 8,
}
resp_chn_name2ax = {
"resp": 0,
"bcg": 1,
@@ -39,6 +52,54 @@ def create_psg_bcg_figure():
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[psg_bcg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["SpO2"]].set_ylim((85, 100))
axes[psg_bcg_chn_name2ax["SpO2"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Flow T"]].grid(True)
# axes[1].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Flow T"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Flow P"]].grid(True)
# axes[2].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Flow P"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Effort Tho"]].grid(True)
# axes[3].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Effort Tho"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["Effort Abd"]].grid(True)
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["HR"]].grid(True)
axes[psg_bcg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
axes[psg_bcg_chn_name2ax["resp"]].grid(True)
axes[psg_bcg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_bcg_chn_name2ax["resp"]].twinx())
axes[psg_bcg_chn_name2ax["bcg"]].grid(True)
# axes[5].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_bcg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
axes.append(axes[psg_bcg_chn_name2ax["bcg"]].twinx())
axes[psg_bcg_chn_name2ax["Stage"]].grid(True)
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes
def create_psg_figure():
fig = plt.figure(figsize=(12, 8), dpi=200)
gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1])
fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
axes = []
for i in range(9):
ax = fig.add_subplot(gs[i])
axes.append(ax)
axes[psg_chn_name2ax["SpO2"]].grid(True)
# axes[0].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))
@@ -60,24 +121,21 @@ def create_psg_bcg_figure():
# axes[4].xaxis.set_major_formatter(Params.FORMATTER)
axes[psg_chn_name2ax["Effort Abd"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Effort"]].grid(True)
axes[psg_chn_name2ax["Effort"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["HR"]].grid(True)
axes[psg_chn_name2ax["HR"]].tick_params(axis='x', colors="white")
- axes[psg_chn_name2ax["resp"]].grid(True)
- axes[psg_chn_name2ax["resp"]].tick_params(axis='x', colors="white")
- axes.append(axes[psg_chn_name2ax["resp"]].twinx())
- axes[psg_chn_name2ax["bcg"]].grid(True)
- # axes[5].xaxis.set_major_formatter(Params.FORMATTER)
- axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
- axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
+ axes[psg_chn_name2ax["RRI"]].grid(True)
+ axes[psg_chn_name2ax["RRI"]].tick_params(axis='x', colors="white")
axes[psg_chn_name2ax["Stage"]].grid(True)
- # axes[7].xaxis.set_major_formatter(Params.FORMATTER)
return fig, axes
def create_resp_figure():
fig = plt.figure(figsize=(12, 6), dpi=100)
gs = GridSpec(2, 1, height_ratios=[3, 2])
@@ -150,8 +208,8 @@ def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, ev
def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
stage_signal = stage_data["data"]
stage_fs = stage_data["fs"]
- time_axis = np.linspace(segment_start / stage_fs, segment_end / stage_fs, segment_end - segment_start)
- ax.plot(time_axis, stage_signal[segment_start:segment_end], color='black', label=stage_data["name"])
+ time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * stage_fs)
+ ax.plot(time_axis, stage_signal[segment_start * stage_fs:segment_end * stage_fs], color='black', label=stage_data["name"])
ax.set_ylim(0, 6)
ax.set_yticks([1, 2, 3, 4, 5])
ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
@@ -162,11 +220,11 @@ def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
spo2_signal = spo2_data["data"]
spo2_fs = spo2_data["fs"]
- time_axis = np.linspace(segment_start / spo2_fs, segment_end / spo2_fs, segment_end - segment_start)
- ax.plot(time_axis, spo2_signal[segment_start:segment_end], color='black', label=spo2_data["name"])
- if spo2_signal[segment_start:segment_end].min() < 85:
- ax.set_ylim((spo2_signal[segment_start:segment_end].min() - 5, 100))
+ time_axis = np.linspace(segment_start, segment_end, (segment_end - segment_start) * spo2_fs)
+ ax.plot(time_axis, spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs], color='black', label=spo2_data["name"])
+ if spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() < 85:
+ ax.set_ylim((spo2_signal[segment_start * spo2_fs:segment_end * spo2_fs].min() - 5, 100))
else:
ax.set_ylim((85, 100))
ax.set_ylabel("SpO2 (%)")
@@ -197,6 +255,56 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
# event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
fig, axes = create_psg_bcg_figure()
for i, (segment_start, segment_end) in enumerate(segment_list):
for ax in axes:
ax.cla()
plt_spo2_on_ax(axes[psg_bcg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
plt_stage_on_ax(axes[psg_bcg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow T"]], psg_data["Flow T"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Flow P"]], psg_data["Flow P"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Tho"]], psg_data["Effort Tho"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
ax2=axes[psg_bcg_chn_name2ax["resp_twinx"]])
plt_signal_label_on_ax(axes[psg_bcg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
ax2=axes[psg_bcg_chn_name2ax["bcg_twinx"]])
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
if multi_p is not None:
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
plt.close(fig)
plt.close('all')
gc.collect()
def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True,
multi_p=None, multi_task_id=None):
if save_path is not None:
save_path.mkdir(parents=True, exist_ok=True)
if multi_p is None:
# Print the length of every channel in psg_data
for i in range(len(psg_data.keys())):
chn_name = list(psg_data.keys())[i]
print(f"{chn_name} data length: {len(psg_data[chn_name]['data'])}, fs: {psg_data[chn_name]['fs']}")
# Length of psg_label
print(f"psg_label length: {len(psg_label)}")
fig, axes = create_psg_figure()
for i, (segment_start, segment_end) in enumerate(segment_list):
for ax in axes:
ax.cla()
@@ -211,14 +319,10 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort Abd"]], psg_data["Effort Abd"], segment_start, segment_end,
psg_label, event_codes=[1, 2, 3, 4])
+ plt_signal_label_on_ax(axes[psg_chn_name2ax["Effort"]], psg_data["Effort"], segment_start, segment_end,
+ psg_label, event_codes=[1, 2, 3, 4])
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
- plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
- event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
- ax2=axes[psg_chn_name2ax["resp_twinx"]])
- plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
- event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
- ax2=axes[psg_chn_name2ax["bcg_twinx"]])
+ plt_signal_label_on_ax(axes[psg_chn_name2ax["RRI"]], psg_data["RRI"], segment_start, segment_end)
if save_path is not None:
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
@@ -226,23 +330,8 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
if multi_p is not None:
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
plt.close(fig)
plt.close('all')
gc.collect()
- def draw_resp_label(resp_data, resp_label, segment_list):
- for mask in resp_label.keys():
- if mask.startswith("Resp_"):
- resp_label[mask] = utils.none_to_nan_mask(resp_label[mask], 0)
- # resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
- # resp_label["Resp_Label_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Label_Alpha"], 0)
- fig, axes = create_resp_figure()
- for segment_start, segment_end in segment_list:
- for ax in axes:
- ax.cla()
- plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
- resp_label, multi_labels="resp", event_codes=[1, 2, 3, 4])
- plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
- resp_label, multi_labels="bcg", event_codes=[1, 2, 3, 4])
- plt.show()

View File

@@ -247,6 +247,8 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
ax1_twin.legend(['Low Amplitude Mask', 'Movement Mask', 'Amplitude Change Mask', 'SA Mask'], loc='upper right')
ax1.set_title(f'Sample {samp_id} - Respiration Component')
ax2 = fig.add_subplot(3, 1, 3, sharex=ax0)
ax2.plot(np.linspace(0, len(bcg_data) // bcg_fs, len(bcg_data)), bcg_data, color='green')
ax2.set_ylabel('Amplitude')
@@ -300,5 +302,70 @@ def draw_signal_with_mask(samp_id, signal_data, resp_data, bcg_data, signal_fs,
plt.show()
def draw_psg_signal(samp_id, tho_signal, abd_signal, flowp_signal, flowt_signal, spo2_signal, effort_signal, rri_signal, event_mask, fs,
show=False, save_path=None):
sa_mask = event_mask.repeat(fs)
fig, axs = plt.subplots(7, 1, figsize=(18, 12), sharex=True)
time_axis = np.linspace(0, len(tho_signal) / fs, len(tho_signal))
axs[0].plot(time_axis, tho_signal, label='THO', color='black')
axs[0].set_title(f'Sample {samp_id} - PSG Signal Data')
axs[0].set_ylabel('THO Amplitude')
axs[0].legend(loc='upper right')
ax0_twin = axs[0].twinx()
ax0_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax0_twin.autoscale(enable=False, axis='y', tight=True)
ax0_twin.set_ylim((-4, 5))
axs[1].plot(time_axis, abd_signal, label='ABD', color='black')
axs[1].set_ylabel('ABD Amplitude')
axs[1].legend(loc='upper right')
ax1_twin = axs[1].twinx()
ax1_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax1_twin.autoscale(enable=False, axis='y', tight=True)
ax1_twin.set_ylim((-4, 5))
axs[2].plot(time_axis, effort_signal, label='EFFO', color='black')
axs[2].set_ylabel('EFFO Amplitude')
axs[2].legend(loc='upper right')
ax2_twin = axs[2].twinx()
ax2_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax2_twin.autoscale(enable=False, axis='y', tight=True)
ax2_twin.set_ylim((-4, 5))
axs[3].plot(time_axis, flowp_signal, label='FLOWP', color='black')
axs[3].set_ylabel('FLOWP Amplitude')
axs[3].legend(loc='upper right')
ax3_twin = axs[3].twinx()
ax3_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax3_twin.autoscale(enable=False, axis='y', tight=True)
ax3_twin.set_ylim((-4, 5))
axs[4].plot(time_axis, flowt_signal, label='FLOWT', color='black')
axs[4].set_ylabel('FLOWT Amplitude')
axs[4].legend(loc='upper right')
ax4_twin = axs[4].twinx()
ax4_twin.plot(time_axis, sa_mask, color='purple', alpha=0.5, label='SA Mask')
ax4_twin.autoscale(enable=False, axis='y', tight=True)
ax4_twin.set_ylim((-4, 5))
axs[5].plot(time_axis, rri_signal, label='RRI', color='black')
axs[5].set_ylabel('RRI Amplitude')
axs[5].legend(loc='upper right')
axs[6].plot(time_axis, spo2_signal, label='SPO2', color='black')
axs[6].set_ylabel('SPO2 Amplitude')
axs[6].set_xlabel('Time (s)')
axs[6].legend(loc='upper right')
if save_path is not None:
plt.savefig(save_path, dpi=300)
if show:
plt.show()

View File

@@ -0,0 +1,183 @@
"""
本脚本完成对呼研所数据的处理包含以下功能
1. 数据读取与预处理
从传入路径中进行数据和标签的读取并进行初步的预处理
预处理包括为数据进行滤波去噪等操作
2. 数据清洗与异常值处理
3. 输出清晰后的统计信息
4. 数据保存
将处理后的数据保存到指定路径便于后续使用
主要是保存切分后的数据位置和标签
5. 可视化
提供数据处理前后的可视化对比帮助理解数据变化
绘制多条可用性趋势图展示数据的可用区间体动区间低幅值区间等
# 低幅值区间规则标定与剔除
# 高幅值连续体动规则标定与剔除
# 手动标定不可用区间提剔除
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
project_root_path = Path(__file__).resolve().parent.parent
import shutil
import draw_tools
import utils
import numpy as np
import signal_method
import os
os.environ['DISPLAY'] = "localhost:10.0"
def process_one_signal(samp_id, show=False):
tho_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Tho_Sync_*.txt"))
abd_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Effort Abd_Sync_*.txt"))
flowp_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow P_Sync_*.txt"))
flowt_signal_path = list((org_signal_root_path / f"{samp_id}").glob("Flow T_Sync_*.txt"))
spo2_signal_path = list((org_signal_root_path / f"{samp_id}").glob("SpO2_Sync_*.txt"))
stage_signal_path = list((org_signal_root_path / f"{samp_id}").glob("5_class_Sync_*.txt"))
if not tho_signal_path:
raise FileNotFoundError(f"Effort Tho_Sync file not found for sample ID: {samp_id}")
tho_signal_path = tho_signal_path[0]
print(f"Processing Effort Tho_Sync signal file: {tho_signal_path}")
if not abd_signal_path:
raise FileNotFoundError(f"Effort Abd_Sync file not found for sample ID: {samp_id}")
abd_signal_path = abd_signal_path[0]
print(f"Processing Effort Abd_Sync signal file: {abd_signal_path}")
if not flowp_signal_path:
raise FileNotFoundError(f"Flow P_Sync file not found for sample ID: {samp_id}")
flowp_signal_path = flowp_signal_path[0]
print(f"Processing Flow P_Sync signal file: {flowp_signal_path}")
if not flowt_signal_path:
raise FileNotFoundError(f"Flow T_Sync file not found for sample ID: {samp_id}")
flowt_signal_path = flowt_signal_path[0]
print(f"Processing Flow T_Sync signal file: {flowt_signal_path}")
if not spo2_signal_path:
raise FileNotFoundError(f"SpO2_Sync file not found for sample ID: {samp_id}")
spo2_signal_path = spo2_signal_path[0]
print(f"Processing SpO2_Sync signal file: {spo2_signal_path}")
if not stage_signal_path:
raise FileNotFoundError(f"5_class_Sync file not found for sample ID: {samp_id}")
stage_signal_path = stage_signal_path[0]
print(f"Processing 5_class_Sync signal file: {stage_signal_path}")
label_path = (label_root_path / f"{samp_id}").glob("SA Label_Sync.csv")
if not label_path:
raise FileNotFoundError(f"Label_corrected file not found for sample ID: {samp_id}")
label_path = list(label_path)[0]
print(f"Processing Label_corrected file: {label_path}")
# Save the processed data and labels
save_samp_path = save_path / f"{samp_id}"
save_samp_path.mkdir(parents=True, exist_ok=True)
# Read signal data
tho_data_raw, tho_length, tho_fs, tho_second = utils.read_signal_txt(tho_signal_path, dtype=float, verbose=True)
# abd_data_raw, abd_length, abd_fs, abd_second = utils.read_signal_txt(abd_signal_path, dtype=float, verbose=True)
# flowp_data_raw, flowp_length, flowp_fs, flowp_second = utils.read_signal_txt(flowp_signal_path, dtype=float, verbose=True)
# flowt_data_raw, flowt_length, flowt_fs, flowt_second = utils.read_signal_txt(flowt_signal_path, dtype=float, verbose=True)
# spo2_data_raw, spo2_length, spo2_fs, spo2_second = utils.read_signal_txt(spo2_signal_path, dtype=float, verbose=True)
stage_data_raw, stage_length, stage_fs, stage_second = utils.read_signal_txt(stage_signal_path, dtype=str, verbose=True)
# Preprocessing and filtering
# tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=tho_data_raw, effort_fs=tho_fs)
# abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=abd_data_raw, effort_fs=abd_fs)
# flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowp_data_raw, effort_fs=flowp_fs)
# flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=flowt_data_raw, effort_fs=flowt_fs)
# Downsampling
# old_tho_fs = tho_fs
# tho_fs = conf["effort"]["downsample_fs"]
# tho_data_filt = utils.downsample_signal_fast(original_signal=tho_data_filt, original_fs=old_tho_fs, target_fs=tho_fs)
# old_abd_fs = abd_fs
# abd_fs = conf["effort"]["downsample_fs"]
# abd_data_filt = utils.downsample_signal_fast(original_signal=abd_data_filt, original_fs=old_abd_fs, target_fs=abd_fs)
# old_flowp_fs = flowp_fs
# flowp_fs = conf["effort"]["downsample_fs"]
# flowp_data_filt = utils.downsample_signal_fast(original_signal=flowp_data_filt, original_fs=old_flowp_fs, target_fs=flowp_fs)
# old_flowt_fs = flowt_fs
# flowt_fs = conf["effort"]["downsample_fs"]
# flowt_data_filt = utils.downsample_signal_fast(original_signal=flowt_data_filt, original_fs=old_flowt_fs, target_fs=flowt_fs)
# Do not downsample SpO2
# spo2_data_filt = spo2_data_raw
# spo2_fs = spo2_fs
label_data = utils.read_raw_psg_label(path=label_path)
event_mask, score_mask = utils.generate_event_mask(signal_second=tho_second, event_df=label_data, use_correct=False, with_score=False)
# Set positions where event_mask > 0 to 1 and everything else to 0
score_mask = np.where(event_mask > 0, 1, 0)
# Generate unusable intervals from the sleep stages
wake_mask = utils.get_wake_mask(stage_data_raw)
# Remove wake intervals shorter than 60 seconds
wake_mask = utils.remove_short_durations(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), min_duration_sec=60)
# Merge wake intervals separated by gaps shorter than 60 seconds
wake_mask = utils.merge_short_gaps(wake_mask, time_points=np.arange(len(wake_mask) * stage_fs), max_gap_sec=60)
disable_label = wake_mask
disable_label = disable_label[:tho_second]
# Copy the event file to the save path
sa_label_save_name = f"{samp_id}_" + label_path.name
shutil.copyfile(label_path, save_samp_path / sa_label_save_name)
# Build a new dataframe with the seconds and the SA labels
save_dict = {
"Second": np.arange(tho_second),
"SA_Label": event_mask,
"SA_Score": score_mask,
"Disable_Label": disable_label,
"Resp_LowAmp_Label": np.zeros_like(event_mask),
"Resp_Movement_Label": np.zeros_like(event_mask),
"Resp_AmpChange_Label": np.zeros_like(event_mask),
"BCG_LowAmp_Label": np.zeros_like(event_mask),
"BCG_Movement_Label": np.zeros_like(event_mask),
"BCG_AmpChange_Label": np.zeros_like(event_mask)
}
mask_label_save_name = f"{samp_id}_Processed_Labels.csv"
utils.save_process_label(save_path=save_samp_path / mask_label_save_name, save_dict=save_dict)
if __name__ == '__main__':
yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"
# disable_df_path = project_root_path / "排除区间.xlsx"
#
conf = utils.load_dataset_conf(yaml_path)
root_path = Path(conf["root_path"])
save_path = Path(conf["mask_save_path"])
select_ids = conf["select_ids"]
#
print(f"select_ids: {select_ids}")
print(f"root_path: {root_path}")
print(f"save_path: {save_path}")
#
org_signal_root_path = root_path / "PSG_Aligned"
label_root_path = root_path / "PSG_Aligned"
#
# all_samp_disable_df = utils.read_disable_excel(disable_df_path)
#
# process_one_signal(select_ids[0], show=True)
# #
for samp_id in select_ids:
print(f"Processing sample ID: {samp_id}")
process_one_signal(samp_id, show=False)
print(f"Finished processing sample ID: {samp_id}\n\n")

View File

@@ -2,5 +2,5 @@ from .rule_base_event import detect_low_amplitude_signal, detect_movement
from .rule_base_event import position_based_sleep_recognition_v2, position_based_sleep_recognition_v3
from .rule_base_event import movement_revise
from .time_metrics import calc_mav_by_slide_windows
- from .signal_process import signal_filter_split, rpeak2hr
+ from .signal_process import signal_filter_split, rpeak2hr, psg_effort_filter, rpeak2rri_interpolation
- from .normalize_method import normalize_resp_signal
+ from .normalize_method import normalize_resp_signal_by_segment

View File

@@ -3,7 +3,7 @@ import pandas as pd
import numpy as np
from scipy import signal
- def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
+ def normalize_resp_signal_by_segment(resp_signal: np.ndarray, resp_fs, movement_mask, enable_list):
# Z-score each segment according to the amplitude-change intervals of the respiration signal
normalized_resp_signal = np.zeros_like(resp_signal)
# Fill everything with NaN first
@@ -33,4 +33,20 @@ def normalize_resp_signal(resp_signal: np.ndarray, resp_fs, movement_mask, enabl
raw_segment = resp_signal[enable_start:enable_end]
normalized_resp_signal[enable_start:enable_end] = (raw_segment - segment_mean) / segment_std
# If the first enable interval does not start at 0, normalize the leading part with the first segment's statistics
if enable_list[0][0] > 0:
new_enable_start = 0
enable_start = enable_list[0][0] * resp_fs
enable_end = enable_list[0][1] * resp_fs
segment = resp_signal_no_movement[enable_start:enable_end]
segment_mean = np.nanmean(segment)
segment_std = np.nanstd(segment)
if segment_std == 0:
raise ValueError(f"segment_std is zero! segment_start: {enable_start}, segment_end: {enable_end}")
raw_segment = resp_signal[new_enable_start:enable_start]
normalized_resp_signal[new_enable_start:enable_start] = (raw_segment - segment_mean) / segment_std
return normalized_resp_signal
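For context, the function above z-scores each enable interval independently, and the new tail block extends the first interval's statistics to any leading region that precedes it. A toy sketch of the per-segment z-score idea (hypothetical helper, not the repo function, which additionally masks movement samples before computing the statistics):

import numpy as np

def zscore_segment(x):
    # Z-score one segment; a flat segment would make the score undefined.
    std = np.nanstd(x)
    if std == 0:
        raise ValueError("segment_std is zero!")
    return (x - np.nanmean(x)) / std

resp = np.concatenate([5.0 * np.sin(np.linspace(0, 20, 3000)),  # high-amplitude stretch
                       np.sin(np.linspace(0, 20, 3000))])       # low-amplitude stretch
segments = [(0, 3000), (3000, 6000)]
normalized = np.concatenate([zscore_segment(resp[s:e]) for s, e in segments])
# Both halves now have unit variance, so amplitude drift no longer dominates.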

View File

@@ -1,4 +1,5 @@
import numpy as np
+ from scipy.interpolate import interp1d
import utils
@@ -44,14 +45,24 @@ def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
return signal_data, resp_data_2, resp_fs, bcg_data, signal_fs
def psg_effort_filter(conf, effort_data_raw, effort_fs):
# Filtering
effort_data_1 = utils.bessel(data=effort_data_raw, _type=conf["effort_filter"]["filter_type"],
low_cut=conf["effort_filter"]["low_cut"],
high_cut=conf["effort_filter"]["high_cut"], order=conf["effort_filter"]["order"],
sample_rate=effort_fs)
# Moving-average smoothing
effort_data_2 = utils.average_filter(raw_data=effort_data_1, sample_rate=effort_fs, window_size_sec=20)
return effort_data_raw, effort_data_2, effort_fs
- def rpeak2hr(rpeak_indices, signal_length):
+ def rpeak2hr(rpeak_indices, signal_length, ecg_fs):
hr_signal = np.zeros(signal_length)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
if rri == 0:
continue
- hr = 60 * 1000 / rri  # heart rate in bpm
+ hr = 60 * ecg_fs / rri  # heart rate in bpm (rri is in samples)
if hr > 120:
hr = 120
elif hr < 30:
@@ -62,3 +73,35 @@ def rpeak2hr(rpeak_indices, signal_length):
hr_signal[rpeak_indices[-1]:] = hr_signal[rpeak_indices[-2]]
return hr_signal
def rpeak2rri_repeat(rpeak_indices, signal_length, ecg_fs):
rri_signal = np.zeros(signal_length)
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = rri
# Fill the RRI values after the last R peak
if len(rpeak_indices) > 1:
rri_signal[rpeak_indices[-1]:] = rri_signal[rpeak_indices[-2]]
# Scan for physiologically implausible intervals and zero them out
for i in range(1, len(rpeak_indices)):
rri = rpeak_indices[i] - rpeak_indices[i - 1]
if rri < 0.3 * ecg_fs or rri > 2 * ecg_fs:
rri_signal[rpeak_indices[i - 1]:rpeak_indices[i]] = 0
return rri_signal
def rpeak2rri_interpolation(rpeak_indices, ecg_fs, rri_fs):
r_peak_time = np.asarray(rpeak_indices) / ecg_fs
rri = np.diff(r_peak_time)
t_rri = r_peak_time[1:]
mask = (rri > 0.3) & (rri < 2.0)
rri_clean = rri[mask]
t_rri_clean = t_rri[mask]
t_uniform = np.arange(t_rri_clean[0], t_rri_clean[-1], 1/rri_fs)
f = interp1d(t_rri_clean, rri_clean, kind='linear', fill_value="extrapolate")
rri_uniform = f(t_uniform)
return rri_uniform, rri_fs
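rpeak2rri_interpolation resamples the irregular RR series onto a uniform grid so downstream code can slice it by sample index like any other channel. A usage sketch with synthetic peaks, assuming the function above is in scope and an ECG sampled at 250 Hz (toy numbers):

import numpy as np

ecg_fs, rri_fs = 250, 100
# Synthetic R peaks 0.8 s apart, expressed as sample indices.
rpeak_indices = np.cumsum(np.full(20, int(0.8 * ecg_fs)))

rri_uniform, fs = rpeak2rri_interpolation(rpeak_indices, ecg_fs=ecg_fs, rri_fs=rri_fs)
print(len(rri_uniform), float(rri_uniform.mean()))  # values hover around 0.8 s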

View File

@@ -178,6 +178,55 @@ def read_label_csv(path: Union[str, Path], verbose=True) -> pd.DataFrame:
return df
def read_raw_psg_label(path: Union[str, Path], verbose=True) -> pd.DataFrame:
"""
Read a CSV file and return it as a pandas DataFrame.
Args:
path (str | Path): Path to the CSV file.
verbose (bool):
Returns:
pd.DataFrame: The content of the CSV file as a pandas DataFrame.
:param path:
:param verbose:
"""
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
# Read with pandas; the file contains Chinese text, so specify the GBK encoding
df = pd.read_csv(path, encoding="gbk")
if verbose:
print(f"Label file read from {path}, number of rows: {len(df)}")
num_psg_events = np.sum(df["Event type"].notna())
# Count events by type
num_psg_hyp = np.sum(df["Event type"] == "Hypopnea")
num_psg_csa = np.sum(df["Event type"] == "Central apnea")
num_psg_osa = np.sum(df["Event type"] == "Obstructive apnea")
num_psg_msa = np.sum(df["Event type"] == "Mixed apnea")
if verbose:
print("Event Statistics:")
# Fixed-width formatted output
print(f"Type {'Total':^8s}")
print(f"Hyp: {num_psg_hyp:^8d}")
print(f"CSA: {num_psg_csa:^8d}")
print(f"OSA: {num_psg_osa:^8d}")
print(f"MSA: {num_psg_msa:^8d}")
print(f"All: {num_psg_events:^8d}")
df["Start"] = df["Start"].astype(int)
df["End"] = df["End"].astype(int)
return df
def read_disable_excel(path: Union[str, Path]) -> pd.DataFrame:
"""
Read an Excel file and return it as a pandas DataFrame.
@@ -225,6 +274,15 @@ def read_mask_execl(path: Union[str, Path]):
return event_mask, event_list
def read_psg_mask_excel(path: Union[str, Path]):
df = pd.read_csv(path)
event_mask = df.to_dict(orient="list")
for key in event_mask:
event_mask[key] = np.array(event_mask[key])
return event_mask
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
"""

View File

@@ -1,10 +1,12 @@
- from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label
+ from .HYS_FileReader import read_label_csv, read_signal_txt, read_disable_excel, read_psg_label, read_raw_psg_label, read_psg_mask_excel
from .operation_tools import load_dataset_conf, generate_disable_mask, generate_event_mask, event_mask_2_list
from .operation_tools import merge_short_gaps, remove_short_durations
from .operation_tools import collect_values
from .operation_tools import save_process_label
from .operation_tools import none_to_nan_mask
+ from .operation_tools import get_wake_mask
+ from .operation_tools import fill_spo2_anomaly
from .split_method import resp_split
from .HYS_FileReader import read_mask_execl, read_psg_channel
from .event_map import E2N, N2Chn, Stage2N, ColorCycle
- from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel
+ from .filter_func import butterworth, average_filter, downsample_signal_fast, notch_filter, bessel, adjust_sample_rate

View File

@@ -20,6 +20,7 @@ def butterworth(data, _type, low_cut=0.0, high_cut=0.0, order=10, sample_rate=10
raise ValueError("Please choose a type of filter")
+ @timing_decorator()
def bessel(data, _type, low_cut=0.0, high_cut=0.0, order=4, sample_rate=1000):
if _type == "lowpass":  # low-pass filtering
b, a = signal.bessel(order, low_cut / (sample_rate * 0.5), btype='lowpass', analog=False, norm='mag')
@@ -89,6 +90,52 @@ def downsample_signal_fast(original_signal, original_fs, target_fs, chunk_size=1
return downsampled_signal
def upsample_signal(original_signal, original_fs, target_fs):
"""
Upsample a signal.
Args:
original_signal : array-like, original signal
original_fs : float, original sampling rate (Hz)
target_fs : float, target sampling rate (Hz)
Returns:
upsampled_signal : array-like, the upsampled signal
"""
if not isinstance(original_signal, np.ndarray):
original_signal = np.array(original_signal)
if target_fs <= original_fs:
raise ValueError("The target sampling rate must be greater than the original sampling rate")
if target_fs <= 0 or original_fs <= 0:
raise ValueError("Sampling rates must be positive")
upsample_factor = target_fs / original_fs
num_output_samples = int(len(original_signal) * upsample_factor)
upsampled_signal = signal.resample(original_signal, num_output_samples)
return upsampled_signal
def adjust_sample_rate(signal_data, original_fs, target_fs):
"""
Automatically choose up- or downsampling based on the original and target sampling rates.
Args:
signal_data : array-like, original signal
original_fs : float, original sampling rate (Hz)
target_fs : float, target sampling rate (Hz)
Returns:
adjusted_signal : array-like, the signal at the target sampling rate
"""
if original_fs == target_fs:
return signal_data
elif original_fs > target_fs:
return downsample_signal_fast(signal_data, original_fs, target_fs)
else:
return upsample_signal(signal_data, original_fs, target_fs)
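adjust_sample_rate dispatches on the rate ratio, which is what lets the dataset builder bring every channel to a common 100 Hz. A small usage sketch with toy signals, assuming upsample_signal and downsample_signal_fast above are in scope:

import numpy as np

fs_slow, fs_fast, target_fs = 25, 500, 100
slow = np.sin(np.linspace(0, 2 * np.pi, 10 * fs_slow))   # 10 s at 25 Hz
fast = np.sin(np.linspace(0, 2 * np.pi, 10 * fs_fast))   # 10 s at 500 Hz

up = adjust_sample_rate(slow, fs_slow, target_fs)    # upsampled 25 -> 100 Hz
down = adjust_sample_rate(fast, fs_fast, target_fs)  # downsampled 500 -> 100 Hz
print(len(up), len(down))  # both should land near 10 s x 100 Hz = 1000 samples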
@timing_decorator()
def average_filter(raw_data, sample_rate, window_size_sec=20):

View File

@@ -6,6 +6,7 @@ import pandas as pd
from matplotlib import pyplot as plt
import yaml
from numpy.ma.core import append
+ from scipy.interpolate import PchipInterpolator
from utils.event_map import E2N
plt.rcParams['font.sans-serif'] = ['SimHei']  # Display Chinese labels correctly
@@ -198,9 +199,12 @@ def generate_disable_mask(signal_second: int, disable_df) -> np.ndarray:
return disable_mask
- def generate_event_mask(signal_second: int, event_df, use_correct=True):
+ def generate_event_mask(signal_second: int, event_df, use_correct=True, with_score=True):
event_mask = np.zeros(signal_second, dtype=int)
- score_mask = np.zeros(signal_second, dtype=int)
+ if with_score:
+     score_mask = np.zeros(signal_second, dtype=int)
+ else:
+     score_mask = None
if use_correct:
start_name = "correct_Start"
end_name = "correct_End"
@@ -217,7 +221,8 @@ def generate_event_mask(signal_second: int, event_df, use_correct=True):
start = row[start_name]
end = row[end_name] + 1
event_mask[start:end] = E2N[row[event_type_name]]
- score_mask[start:end] = row["score"]
+ if with_score:
+     score_mask[start:end] = row["score"]
return event_mask, score_mask
@@ -260,4 +265,116 @@ def none_to_nan_mask(mask, ref):
else:
# Replace zeros in the mask with NaN and keep everything else
mask = np.where(mask == 0, np.nan, mask)
return mask
def get_wake_mask(sleep_stage_mask):
# Treat N1, N2, N3 and REM as sleep (0) and everything else as wake (1)
# The input is an array of stage strings such as 'W', 'N1', 'N2', 'N3', 'R'
wake_mask = np.where(np.isin(sleep_stage_mask, ['N1', 'N2', 'N3', 'REM', 'R']), 0, 1)
return wake_mask
def detect_spo2_anomaly(spo2, fs, diff_thresh=7):
anomaly = np.zeros(len(spo2), dtype=bool)
# Physiologically impossible range
anomaly |= (spo2 < 50) | (spo2 > 100)
# Sudden jumps
diff = np.abs(np.diff(spo2, prepend=spo2[0]))
anomaly |= diff > diff_thresh
# NaN
anomaly |= np.isnan(spo2)
return anomaly
def merge_close_anomalies(anomaly, fs, min_gap_duration):
min_gap = int(min_gap_duration * fs)
merged = anomaly.copy()
i = 0
n = len(anomaly)
while i < n:
if not anomaly[i]:
i += 1
continue
# Current anomaly segment
start = i
while i < n and anomaly[i]:
i += 1
end = i
# Look ahead at the gap that follows
j = end
while j < n and not anomaly[j]:
j += 1
if j < n and (j - end) < min_gap:
merged[end:j] = True
return merged
def fill_spo2_anomaly(
spo2_data,
spo2_fs,
max_fill_duration,
min_gap_duration,
):
spo2 = spo2_data.astype(float).copy()
n = len(spo2)
anomaly = detect_spo2_anomaly(spo2, spo2_fs)
anomaly = merge_close_anomalies(anomaly, spo2_fs, min_gap_duration)
max_len = int(max_fill_duration * spo2_fs)
valid_mask = ~anomaly
i = 0
while i < n:
if not anomaly[i]:
i += 1
continue
start = i
while i < n and anomaly[i]:
i += 1
end = i
seg_len = end - start
# Anomaly segment too long to fill
if seg_len > max_len:
spo2[start:end] = np.nan
valid_mask[start:end] = False
continue
has_left = start > 0 and valid_mask[start - 1]
has_right = end < n and valid_mask[end]
# Anomaly at the start of the record: one-sided fill
if not has_left and has_right:
spo2[start:end] = spo2[end]
continue
# Anomaly at the end of the record: one-sided fill
if has_left and not has_right:
spo2[start:end] = spo2[start - 1]
continue
# Valid data on both sides → PCHIP interpolation
if has_left and has_right:
x = np.array([start - 1, end])
y = np.array([spo2[start - 1], spo2[end]])
interp = PchipInterpolator(x, y)
spo2[start:end] = interp(np.arange(start, end))
continue
# No valid data on either side (extreme case)
spo2[start:end] = np.nan
valid_mask[start:end] = False
return spo2, valid_mask
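fill_spo2_anomaly repairs short dropouts (PCHIP between the valid neighbors, one-sided fill at the record edges) and leaves anything longer than max_fill_duration as NaN. A usage sketch with a synthetic dropout at 1 Hz (toy data; note that repaired samples stay flagged in the returned validity mask, which is how the caller builds spo2_disable_mask):

import numpy as np

fs = 1  # 1 Hz SpO2 for illustration
spo2 = np.full(120, 96.0)
spo2[40:50] = 0.0  # a 10 s sensor dropout, far below the 50% physiological floor

filled, valid = fill_spo2_anomaly(spo2, spo2_fs=fs, max_fill_duration=30, min_gap_duration=10)
print(filled[38:52])       # the dropout is bridged with values interpolated from ~96
print(valid[40:50].any())  # False: repaired samples remain marked invalid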

View File

@@ -54,5 +54,3 @@ def resp_split(dataset_config, event_mask, event_list, verbose=False):
return segment_list, disable_segment_list