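"""Build the segmented HYS PSG dataset.

Summary inferred from the code below: for each selected sample, load the
aligned PSG channels, filter and normalize the respiratory signals,
resample everything to a common 100 Hz rate, apply event/validity masks,
split the recording into segments, and save the processed signals,
labels, and segment lists. Supports sequential and multiprocess runs.
"""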

import gc
import multiprocessing
import os
import shutil
import sys
import time
from pathlib import Path

import numpy as np

# Point X11 at the forwarded display so plots can render over SSH.
os.environ['DISPLAY'] = "localhost:10.0"

# Make the project root importable before loading project-local modules.
project_root_path = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root_path))

import draw_tools
import signal_method
import utils
from utils import N2Chn


def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    # Load the eight aligned PSG channels for this sample.
    psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)

    # Truncate every channel (except the R-peak index list) to the shortest common duration.
    total_seconds = min(
        psg_data[i]["second"] for i in N2Chn.values() if i != "Rpeak"
    )
    for i in N2Chn.values():
        if i == "Rpeak":
            continue
        length = int(total_seconds * psg_data[i]["fs"])
        psg_data[i]["data"] = psg_data[i]["data"][:length]
        psg_data[i]["length"] = length
        psg_data[i]["second"] = total_seconds

    # Derive a heart-rate channel from the R-peak indices, on the ECG timebase.
    psg_data["HR"] = {
        "name": "HR",
        "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"],
                                       psg_data["Rpeak"]["fs"]),
        "fs": psg_data["ECG_Sync"]["fs"],
        "length": psg_data["ECG_Sync"]["length"],
        "second": psg_data["ECG_Sync"]["second"]
    }

    # Preprocess and filter the respiratory signals (each call returns the
    # raw signal, the filtered signal, and its sampling rate).
    tho_data, tho_data_filt, tho_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Tho"]["data"], effort_fs=psg_data["Effort Tho"]["fs"])
    abd_data, abd_data_filt, abd_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Effort Abd"]["data"], effort_fs=psg_data["Effort Abd"]["fs"])
    flowp_data, flowp_data_filt, flowp_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow P"]["data"], effort_fs=psg_data["Flow P"]["fs"])
    flowt_data, flowt_data_filt, flowt_fs = signal_method.psg_effort_filter(conf=conf, effort_data_raw=psg_data["Flow T"]["data"], effort_fs=psg_data["Flow T"]["fs"])

    # Interpolate R-R intervals from the R-peak indices onto a 100 Hz grid.
    rri, rri_fs = signal_method.rpeak2rri_interpolation(rpeak_indices=psg_data["Rpeak"]["data"], ecg_fs=psg_data["ECG_Sync"]["fs"], rri_fs=100)

    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")

    event_mask, event_list = utils.read_mask_execl(mask_excel_path)

    enable_list = [[0, psg_data["Effort Tho"]["second"]]]
    normalized_tho_signal = signal_method.normalize_resp_signal_by_segment(tho_data_filt, tho_fs, np.zeros(psg_data["Effort Tho"]["second"]), enable_list)
    normalized_abd_signal = signal_method.normalize_resp_signal_by_segment(abd_data_filt, abd_fs, np.zeros(psg_data["Effort Abd"]["second"]), enable_list)
    normalized_flowp_signal = signal_method.normalize_resp_signal_by_segment(flowp_data_filt, flowp_fs, np.zeros(psg_data["Flow P"]["second"]), enable_list)
    normalized_flowt_signal = signal_method.normalize_resp_signal_by_segment(flowt_data_filt, flowt_fs, np.zeros(psg_data["Flow T"]["second"]), enable_list)

    # Resample everything to a common 100 Hz rate.
    target_fs = 100
    normalized_tho_signal = utils.adjust_sample_rate(normalized_tho_signal, tho_fs, target_fs)
    normalized_abd_signal = utils.adjust_sample_rate(normalized_abd_signal, abd_fs, target_fs)
    normalized_flowp_signal = utils.adjust_sample_rate(normalized_flowp_signal, flowp_fs, target_fs)
    normalized_flowt_signal = utils.adjust_sample_rate(normalized_flowt_signal, flowt_fs, target_fs)
    spo2_data_filt = utils.adjust_sample_rate(psg_data["SpO2"]["data"], psg_data["SpO2"]["fs"], target_fs)
    normalized_effort_signal = (normalized_tho_signal + normalized_abd_signal) / 2
    rri = utils.adjust_sample_rate(rri, rri_fs, target_fs)

    # Trim all signals to the same length.
    min_length = min(len(normalized_tho_signal), len(normalized_abd_signal),
                     len(normalized_flowp_signal), len(normalized_flowt_signal),
                     len(spo2_data_filt), len(normalized_effort_signal), len(rri))
    min_length = min_length - min_length % target_fs  # keep a whole number of seconds
    normalized_tho_signal = normalized_tho_signal[:min_length]
    normalized_abd_signal = normalized_abd_signal[:min_length]
    normalized_flowp_signal = normalized_flowp_signal[:min_length]
    normalized_flowt_signal = normalized_flowt_signal[:min_length]
    spo2_data_filt = spo2_data_filt[:min_length]
    normalized_effort_signal = normalized_effort_signal[:min_length]
    rri = rri[:min_length]
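
    # Lengths can differ by a few samples after resampling (rounding in
    # adjust_sample_rate), and the per-second event masks below assume the
    # signals cover a whole number of seconds, hence the trim above.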
    tho_second = min_length / target_fs
    for i in event_mask.keys():
        event_mask[i] = event_mask[i][:int(tho_second)]

    # Fill short SpO2 anomalies; spans that cannot be filled are flagged
    # in spo2_disable_mask.
    spo2_data_filt_fill, spo2_disable_mask = utils.fill_spo2_anomaly(spo2_data=spo2_data_filt,
                                                                     spo2_fs=target_fs,
                                                                     max_fill_duration=30,
                                                                     min_gap_duration=10)

    # Plot the processed signals with the raw (unfilled) SpO2 trace.
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots.png",
        show=show
    )

    # Plot again with the gap-filled SpO2 trace.
    draw_tools.draw_psg_signal(
        samp_id=samp_id,
        tho_signal=normalized_tho_signal,
        abd_signal=normalized_abd_signal,
        flowp_signal=normalized_flowp_signal,
        flowt_signal=normalized_flowt_signal,
        spo2_signal=spo2_data_filt_fill,
        effort_signal=normalized_effort_signal,
        rri_signal=rri,
        fs=target_fs,
        event_mask=event_mask["SA_Label"],
        save_path=mask_path / f"{samp_id}" / f"{samp_id}_Signal_Plots_fill.png",
        show=show
    )

    # Downsample the SpO2 mask to one value per second and combine it with
    # the per-second disable label.
    spo2_disable_mask = spo2_disable_mask[::target_fs]
    min_len = min(len(event_mask["Disable_Label"]), len(spo2_disable_mask))
    if len(event_mask["Disable_Label"]) != len(spo2_disable_mask):
        print(f"Warning: Data length mismatch! Truncating to {min_len}.")
    event_mask["Disable_Label"] = event_mask["Disable_Label"][:min_len] & spo2_disable_mask[:min_len]

    event_list = {
        "EnableSegment": utils.event_mask_2_list(1 - event_mask["Disable_Label"]),
        "DisableSegment": utils.event_mask_2_list(event_mask["Disable_Label"])}

    # Replace any remaining NaNs with a placeholder SpO2 value of 95 (%).
    spo2_data_filt_fill = np.nan_to_num(spo2_data_filt_fill, nan=95)

    # Split the recording into segments (enabled and disabled separately).
    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")

    # Copy the mask file into the processed labels folder.
    save_mask_excel_path = save_processed_label_path / f"{samp_id}_Processed_Labels.csv"
    shutil.copyfile(mask_excel_path, save_mask_excel_path)

    # Copy the synchronized SA label file into the processed labels folder
    # (saved as *_SA Label_corrected.csv), if present.
    sa_label_corrected_path = Path(mask_path, f"{samp_id}", f"{samp_id}_SA Label_Sync.csv")
    if sa_label_corrected_path.exists():
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")

    # Paths for the processed signals and the extracted segment list.
    save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
    save_segment_path = save_segment_list_path / f"{samp_id}_Segment_List.npz"

    # Rebuild psg_data with the processed signals; the "name" fields use
    # underscores in place of spaces. (The dict literal below is evaluated
    # before the assignment, so reading the old psg_data inside it is safe.)
    psg_data = {
        "Effort Tho": {
            "name": "Effort_Tho",
            "data": normalized_tho_signal,
            "fs": target_fs,
            "length": len(normalized_tho_signal),
            "second": len(normalized_tho_signal) / target_fs
        },
        "Effort Abd": {
            "name": "Effort_Abd",
            "data": normalized_abd_signal,
            "fs": target_fs,
            "length": len(normalized_abd_signal),
            "second": len(normalized_abd_signal) / target_fs
        },
        "Effort": {
            "name": "Effort",
            "data": normalized_effort_signal,
            "fs": target_fs,
            "length": len(normalized_effort_signal),
            "second": len(normalized_effort_signal) / target_fs
        },
        "Flow P": {
            "name": "Flow_P",
            "data": normalized_flowp_signal,
            "fs": target_fs,
            "length": len(normalized_flowp_signal),
            "second": len(normalized_flowp_signal) / target_fs
        },
        "Flow T": {
            "name": "Flow_T",
            "data": normalized_flowt_signal,
            "fs": target_fs,
            "length": len(normalized_flowt_signal),
            "second": len(normalized_flowt_signal) / target_fs
        },
        "SpO2": {
            "name": "SpO2",
            "data": spo2_data_filt_fill,
            "fs": target_fs,
            "length": len(spo2_data_filt_fill),
            "second": len(spo2_data_filt_fill) / target_fs
        },
        "HR": {
            "name": "HR",
            "data": psg_data["HR"]["data"],
            "fs": psg_data["HR"]["fs"],
            "length": psg_data["HR"]["length"],
            "second": psg_data["HR"]["second"]
        },
        "RRI": {
            "name": "RRI",
            "data": rri,
            "fs": target_fs,
            "length": len(rri),
            "second": len(rri) / target_fs
        },
        "5_class": {
            "name": "Stage",
            "data": psg_data["5_class"]["data"],
            "fs": psg_data["5_class"]["fs"],
            "length": psg_data["5_class"]["length"],
            "second": psg_data["5_class"]["second"]
        }
    }

    np.savez_compressed(save_signal_path, **psg_data)
    np.savez_compressed(save_segment_path,
                        segment_list=segment_list,
                        disable_segment_list=disable_segment_list)
    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")
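
    # Note: each channel entry above is a Python dict, so np.savez_compressed
    # stores it as a 0-d object array; reading it back needs
    # np.load(save_signal_path, allow_pickle=True) and .item() per key.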

    if draw_segment:
        total_len = len(segment_list) + len(disable_segment_list)
        if verbose:
            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")

        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=segment_list,
            save_path=visual_path / f"{samp_id}" / "enable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )

        draw_tools.draw_psg_label(
            psg_data=psg_data,
            psg_label=event_mask["SA_Label"],
            segment_list=disable_segment_list,
            save_path=visual_path / f"{samp_id}" / "disable",
            verbose=verbose,
            multi_p=multi_p,
            multi_task_id=multi_task_id
        )

    # Explicitly drop large objects so worker processes release memory sooner.
    try:
        del psg_data
        del normalized_tho_signal, normalized_abd_signal
        del normalized_flowp_signal, normalized_flowt_signal
        del normalized_effort_signal
        del spo2_data_filt, spo2_data_filt_fill
        del rri
        del event_mask, event_list
        del segment_list, disable_segment_list
    except NameError:
        pass

    # Force a garbage-collection pass.
    gc.collect()


def multiprocess_entry(_progress, task_id, _id):
    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
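

# Progress protocol assumed by the polling loop in multiprocess_with_tqdm:
# downstream code (reached via multi_p / multi_task_id) is expected to
# publish per-task state into the shared manager dict, e.g.
#     multi_p[multi_task_id] = {"progress": done, "total": n, "desc": "..."}
# Only the keys "progress", "total", and "desc" are read back.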


def multiprocess_with_tqdm(args_list, n_processes):
    """Run all samples in a process pool with a rich live progress display."""
    from concurrent.futures import ProcessPoolExecutor
    from rich import progress

    with progress.Progress(
        "[progress.description]{task.description}",
        progress.BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        progress.MofNCompleteColumn(),
        progress.TimeRemainingColumn(),
        progress.TimeElapsedColumn(),
        refresh_per_second=1,  # slower refresh keeps redraw overhead low
        transient=False
    ) as progress:
        futures = []
        with multiprocessing.Manager() as manager:
            # Shared dict through which workers report per-task progress.
            _progress = manager.dict()
            overall_progress_task = progress.add_task("[green]All jobs progress:")
            with ProcessPoolExecutor(max_workers=n_processes, mp_context=multiprocessing.get_context("spawn")) as executor:
                for i_args, samp_id in enumerate(args_list):
                    task_id = progress.add_task(f"task {i_args}", visible=True)
                    futures.append(executor.submit(multiprocess_entry, _progress, task_id, samp_id))
                # Poll the futures and mirror worker progress into the display.
                while (n_finished := sum(future.done() for future in futures)) < len(futures):
                    progress.update(
                        overall_progress_task, completed=n_finished, total=len(futures)
                    )
                    for task_id, update_data in _progress.items():
                        progress.update(
                            task_id,
                            completed=update_data.get("progress", 0),
                            total=update_data.get("total", 0),
                            description=update_data.get("desc", "")
                        )
                    time.sleep(0.5)  # avoid busy-waiting while polling

                # Re-raise any worker exceptions.
                for future in futures:
                    future.result()


def multiprocess_with_pool(args_list, n_processes):
    """Use a Pool whose workers restart after a fixed number of tasks."""
    from multiprocessing import Pool

    # maxtasksperchild makes each worker restart after that many tasks,
    # releasing any memory the tasks accumulated.
    with Pool(processes=n_processes, maxtasksperchild=2) as pool:
        results = []
        for samp_id in args_list:
            # Positional args: (samp_id, show, draw_segment, verbose, multi_p, multi_task_id).
            result = pool.apply_async(
                build_HYS_dataset_segment,
                args=(samp_id, False, True, False, None, None)
            )
            results.append(result)

        # Wait for every task and surface any errors.
        for i, result in enumerate(results):
            try:
                result.get()
                print(f"Completed {i + 1}/{len(args_list)}")
            except Exception as e:
                print(f"Error processing task {i}: {e}")

        # Graceful shutdown before the context manager terminates the pool.
        pool.close()
        pool.join()
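

# Of the two runners above, multiprocess_with_pool trades the live per-task
# progress display for worker recycling (maxtasksperchild), which keeps the
# long-running job's memory footprint bounded; multiprocess_with_tqdm keeps
# per-task progress bars but reuses the same workers for the whole run.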


if __name__ == '__main__':
    yaml_path = project_root_path / "dataset_config/HYS_PSG_config.yaml"

    conf = utils.load_dataset_conf(yaml_path)
    select_ids = conf["select_ids"]
    root_path = Path(conf["root_path"])
    mask_path = Path(conf["mask_save_path"])
    save_path = Path(conf["dataset_config"]["dataset_save_path"])
    visual_path = Path(conf["dataset_config"]["dataset_visual_path"])
    dataset_config = conf["dataset_config"]

    visual_path.mkdir(parents=True, exist_ok=True)

    save_processed_signal_path = save_path / "Signals"
    save_processed_signal_path.mkdir(parents=True, exist_ok=True)

    save_segment_list_path = save_path / "Segments_List"
    save_segment_list_path.mkdir(parents=True, exist_ok=True)

    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)

    # print(f"select_ids: {select_ids}")
    # print(f"root_path: {root_path}")
    # print(f"save_path: {save_path}")
    # print(f"visual_path: {visual_path}")

    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
    print(select_ids)

    # Single-sample debug run:
    # build_HYS_dataset_segment(select_ids[3], show=False, draw_segment=True)

    # Sequential run over all samples:
    # for samp_id in select_ids:
    #     print(f"Processing sample ID: {samp_id}")
    #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)

    # multiprocess_with_tqdm(args_list=select_ids, n_processes=8)
    multiprocess_with_pool(args_list=select_ids, n_processes=8)