修正SA_Score的处理逻辑,移除不必要的Alpha转换,优化数据掩码生成
This commit is contained in:
parent
d829f3e43d
commit
d09ffecf70
@ -1,8 +1,8 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
os.environ['DISPLAY'] = "localhost:10.0"
|
||||
@ -16,21 +16,23 @@ import draw_tools
|
||||
import shutil
|
||||
|
||||
|
||||
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
||||
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
|
||||
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
|
||||
if not signal_path:
|
||||
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
|
||||
signal_path = signal_path[0]
|
||||
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
|
||||
if verbose:
|
||||
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
|
||||
|
||||
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
|
||||
print(f"mask_excel_path: {mask_excel_path}")
|
||||
if verbose:
|
||||
print(f"mask_excel_path: {mask_excel_path}")
|
||||
|
||||
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
|
||||
|
||||
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float)
|
||||
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
|
||||
|
||||
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs)
|
||||
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
|
||||
normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
|
||||
|
||||
|
||||
@ -63,8 +65,9 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
||||
# show=show,
|
||||
# save_path=None)
|
||||
|
||||
segment_list = utils.resp_split(dataset_config, event_mask, event_list)
|
||||
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
|
||||
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
|
||||
if verbose:
|
||||
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
|
||||
|
||||
|
||||
# 复制mask到processed_Labels文件夹
|
||||
@ -77,7 +80,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
||||
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
|
||||
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
|
||||
else:
|
||||
print(f"Warning: {sa_label_corrected_path} does not exist.")
|
||||
if verbose:
|
||||
print(f"Warning: {sa_label_corrected_path} does not exist.")
|
||||
|
||||
# 保存处理后的信号和截取的片段列表
|
||||
save_signal_path = save_processed_signal_path / f"{samp_id}_Processed_Signals.npz"
|
||||
@ -109,12 +113,14 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
||||
|
||||
np.savez_compressed(save_signal_path, **bcg_data)
|
||||
np.savez_compressed(save_segment_path,
|
||||
segment_list=segment_list)
|
||||
print(f"Saved processed signals to: {save_signal_path}")
|
||||
print(f"Saved segment list to: {save_segment_path}")
|
||||
segment_list=segment_list,
|
||||
disable_segment_list=disable_segment_list)
|
||||
if verbose:
|
||||
print(f"Saved processed signals to: {save_signal_path}")
|
||||
print(f"Saved segment list to: {save_segment_path}")
|
||||
|
||||
if draw_segment:
|
||||
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
|
||||
psg_data["HR"] = {
|
||||
"name": "HR",
|
||||
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
|
||||
@ -124,14 +130,86 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
||||
}
|
||||
|
||||
|
||||
psg_label = utils.read_psg_label(sa_label_corrected_path)
|
||||
psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
|
||||
psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
|
||||
total_len = len(segment_list) + len(disable_segment_list)
|
||||
if verbose:
|
||||
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
|
||||
|
||||
draw_tools.draw_psg_bcg_label(psg_data=psg_data,
|
||||
psg_label=psg_event_mask,
|
||||
bcg_data=bcg_data,
|
||||
event_mask=event_mask,
|
||||
segment_list=segment_list,
|
||||
save_path=visual_path / f"{samp_id}")
|
||||
save_path=visual_path / f"{samp_id}" / "enable",
|
||||
verbose=verbose,
|
||||
multi_p=multi_p,
|
||||
multi_task_id=multi_task_id
|
||||
)
|
||||
|
||||
draw_tools.draw_psg_bcg_label(
|
||||
psg_data=psg_data,
|
||||
psg_label=psg_event_mask,
|
||||
bcg_data=bcg_data,
|
||||
event_mask=event_mask,
|
||||
segment_list=disable_segment_list,
|
||||
save_path=visual_path / f"{samp_id}" / "disable",
|
||||
verbose=verbose,
|
||||
multi_p=multi_p,
|
||||
multi_task_id=multi_task_id
|
||||
)
|
||||
|
||||
|
||||
|
||||
def multiprocess_entry(_progress, task_id, _id):
|
||||
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
|
||||
|
||||
|
||||
def multiprocess_with_tqdm(args_list, n_processes):
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from rich import progress
|
||||
|
||||
with progress.Progress(
|
||||
"[progress.description]{task.description}",
|
||||
progress.BarColumn(),
|
||||
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||
progress.MofNCompleteColumn(),
|
||||
progress.TimeRemainingColumn(),
|
||||
progress.TimeElapsedColumn(),
|
||||
refresh_per_second=1, # bit slower updates
|
||||
transient=False
|
||||
) as progress:
|
||||
futures = []
|
||||
with multiprocessing.Manager() as manager:
|
||||
_progress = manager.dict()
|
||||
overall_progress_task = progress.add_task("[green]All jobs progress:")
|
||||
with ProcessPoolExecutor(max_workers=n_processes) as executor:
|
||||
for i_args in range(len(args_list)):
|
||||
args = args_list[i_args]
|
||||
task_id = progress.add_task(f"task {i_args}", visible=True)
|
||||
futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
|
||||
# monitor the progress:
|
||||
while (n_finished := sum([future.done() for future in futures])) < len(
|
||||
futures
|
||||
):
|
||||
progress.update(
|
||||
overall_progress_task, completed=n_finished, total=len(futures)
|
||||
)
|
||||
for task_id, update_data in _progress.items():
|
||||
desc = update_data.get("desc", "")
|
||||
# update the progress bar for this task:
|
||||
progress.update(
|
||||
task_id,
|
||||
completed=update_data.get("progress", 0),
|
||||
total=update_data.get("total", 0),
|
||||
description=desc
|
||||
)
|
||||
|
||||
# raise any errors:
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -156,16 +234,18 @@ if __name__ == '__main__':
|
||||
save_processed_label_path = save_path / "Labels"
|
||||
save_processed_label_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"select_ids: {select_ids}")
|
||||
print(f"root_path: {root_path}")
|
||||
print(f"save_path: {save_path}")
|
||||
print(f"visual_path: {visual_path}")
|
||||
# print(f"select_ids: {select_ids}")
|
||||
# print(f"root_path: {root_path}")
|
||||
# print(f"save_path: {save_path}")
|
||||
# print(f"visual_path: {visual_path}")
|
||||
|
||||
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||
psg_signal_root_path = root_path / "PSG_Aligned"
|
||||
|
||||
# build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
||||
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
||||
|
||||
for samp_id in select_ids:
|
||||
print(f"Processing sample ID: {samp_id}")
|
||||
build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||
# for samp_id in select_ids:
|
||||
# print(f"Processing sample ID: {samp_id}")
|
||||
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||
|
||||
# multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
|
||||
@ -19,8 +19,8 @@ resp:
|
||||
|
||||
resp_filter:
|
||||
filter_type: bandpass
|
||||
low_cut: 0.01
|
||||
high_cut: 0.7
|
||||
low_cut: 0.05
|
||||
high_cut: 0.6
|
||||
order: 3
|
||||
|
||||
resp_low_amp:
|
||||
|
||||
@ -72,7 +72,6 @@ def create_psg_bcg_figure():
|
||||
axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
|
||||
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
|
||||
|
||||
|
||||
axes[psg_chn_name2ax["Stage"]].grid(True)
|
||||
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
||||
|
||||
@ -183,7 +182,8 @@ def score_mask2alpha(score_mask):
|
||||
return alpha_mask
|
||||
|
||||
|
||||
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
|
||||
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
|
||||
multi_p=None, multi_task_id=None):
|
||||
if save_path is not None:
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -191,13 +191,13 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
|
||||
if mask.startswith("Resp_") or mask.startswith("BCG_"):
|
||||
event_mask[mask] = utils.none_to_nan_mask(event_mask[mask], 0)
|
||||
|
||||
event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0)
|
||||
event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0)
|
||||
|
||||
# event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
|
||||
# event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
|
||||
|
||||
fig, axes = create_psg_bcg_figure()
|
||||
for segment_start, segment_end in tqdm(segment_list):
|
||||
for i, (segment_start, segment_end) in enumerate(segment_list):
|
||||
for ax in axes:
|
||||
ax.cla()
|
||||
|
||||
@ -213,17 +213,21 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
|
||||
psg_label, event_codes=[1, 2, 3, 4])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
|
||||
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
|
||||
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
|
||||
ax2=axes[psg_chn_name2ax["resp_twinx"]])
|
||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
|
||||
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])
|
||||
|
||||
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
|
||||
ax2=axes[psg_chn_name2ax["bcg_twinx"]])
|
||||
|
||||
|
||||
if save_path is not None:
|
||||
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
|
||||
tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
||||
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
||||
|
||||
if multi_p is not None:
|
||||
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
|
||||
|
||||
|
||||
def draw_resp_label(resp_data, resp_label, segment_list):
|
||||
for mask in resp_label.keys():
|
||||
if mask.startswith("Resp_"):
|
||||
|
||||
@ -52,7 +52,7 @@ def process_one_signal(samp_id, show=False):
|
||||
save_samp_path = save_path / f"{samp_id}"
|
||||
save_samp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True)
|
||||
signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True)
|
||||
|
||||
signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)
|
||||
|
||||
|
||||
@ -2,11 +2,12 @@ import numpy as np
|
||||
|
||||
import utils
|
||||
|
||||
def signal_filter_split(conf, signal_data_raw, signal_fs):
|
||||
def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
|
||||
# 滤波
|
||||
# 50Hz陷波滤波器
|
||||
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
|
||||
print("Applying 50Hz notch filter...")
|
||||
if verbose:
|
||||
print("Applying 50Hz notch filter...")
|
||||
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
|
||||
|
||||
resp_data_0 = utils.butterworth(data=signal_data, _type="lowpass", low_cut=50, order=10, sample_rate=signal_fs)
|
||||
@ -17,7 +18,8 @@ def signal_filter_split(conf, signal_data_raw, signal_fs):
|
||||
low_cut=conf["resp_filter"]["low_cut"],
|
||||
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
|
||||
sample_rate=resp_fs)
|
||||
print("Begin plotting signal data...")
|
||||
if verbose:
|
||||
print("Begin plotting signal data...")
|
||||
|
||||
# fig = plt.figure(figsize=(12, 8))
|
||||
# # 绘制三个图raw_data、resp_data_1、resp_data_2
|
||||
|
||||
@ -20,6 +20,7 @@ def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
|
||||
Read a txt file and return the first column as a numpy array.
|
||||
|
||||
Args:
|
||||
:param is_peak:
|
||||
:param path:
|
||||
:param verbose:
|
||||
:param dtype:
|
||||
@ -217,14 +218,15 @@ def read_mask_execl(path: Union[str, Path]):
|
||||
|
||||
event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
|
||||
"BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
|
||||
"EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),}
|
||||
"EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
|
||||
"DisableSegment": event_mask_2_list(event_mask["Disable_Label"])}
|
||||
|
||||
|
||||
return event_mask, event_list
|
||||
|
||||
|
||||
|
||||
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
|
||||
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
|
||||
"""
|
||||
读取PSG文件中特定通道的数据。
|
||||
|
||||
@ -254,16 +256,16 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
|
||||
|
||||
if ch_id == 8:
|
||||
# sleep stage 特例 读取为整数
|
||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True)
|
||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
|
||||
# 转换为整数数组
|
||||
for stage_str, stage_number in utils.Stage2N.items():
|
||||
np.place(ch_signal, ch_signal == stage_str, stage_number)
|
||||
ch_signal = ch_signal.astype(int)
|
||||
elif ch_id == 1:
|
||||
# Rpeak 特例 读取为整数
|
||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True)
|
||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
|
||||
else:
|
||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True)
|
||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)
|
||||
channel_data[ch_name] = {
|
||||
"name": ch_name,
|
||||
"path": ch_path[0],
|
||||
|
||||
@ -1,27 +1,58 @@
|
||||
|
||||
def check_split(event_mask, current_start, window_sec, verbose=False):
|
||||
# 检查当前窗口是否包含在禁用区间或低幅值区间内
|
||||
resp_movement_mask = event_mask["Resp_Movement_Label"][current_start : current_start + window_sec]
|
||||
resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start : current_start + window_sec]
|
||||
|
||||
# 体动与低幅值进行与计算
|
||||
low_move_mask = resp_movement_mask | resp_low_amp_mask
|
||||
if low_move_mask.sum() > 2/3 * window_sec:
|
||||
if verbose:
|
||||
print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def resp_split(dataset_config, event_mask, event_list):
|
||||
def resp_split(dataset_config, event_mask, event_list, verbose=False):
|
||||
# 提取体动区间和呼吸低幅值区间
|
||||
enable_list = event_list["EnableSegment"]
|
||||
disable_list = event_list["DisableSegment"]
|
||||
|
||||
# 读取数据集配置
|
||||
window_sec = dataset_config["window_sec"]
|
||||
stride_sec = dataset_config["stride_sec"]
|
||||
|
||||
segment_list = []
|
||||
disable_segment_list = []
|
||||
|
||||
# 遍历每个enable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以enable_end为结尾截取一个窗口
|
||||
for enable_start, enable_end in enable_list:
|
||||
current_start = enable_start
|
||||
while current_start + window_sec <= enable_end:
|
||||
segment_list.append((current_start, current_start + window_sec))
|
||||
if check_split(event_mask, current_start, window_sec, verbose):
|
||||
segment_list.append((current_start, current_start + window_sec))
|
||||
else:
|
||||
disable_segment_list.append((current_start, current_start + window_sec))
|
||||
current_start += stride_sec
|
||||
# 检查最后一个窗口是否需要添加
|
||||
if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec):
|
||||
segment_list.append((enable_end - window_sec, enable_end))
|
||||
if check_split(event_mask, enable_end - window_sec, window_sec, verbose):
|
||||
segment_list.append((enable_end - window_sec, enable_end))
|
||||
else:
|
||||
disable_segment_list.append((enable_end - window_sec, enable_end))
|
||||
|
||||
return segment_list
|
||||
# 遍历每个disable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以disable_end为结尾截取一个窗口
|
||||
for disable_start, disable_end in disable_list:
|
||||
current_start = disable_start
|
||||
while current_start + window_sec <= disable_end:
|
||||
disable_segment_list.append((current_start, current_start + window_sec))
|
||||
current_start += stride_sec
|
||||
# 检查最后一个窗口是否需要添加
|
||||
if (disable_end - current_start >= stride_sec / 2) and (disable_end - current_start >= window_sec):
|
||||
disable_segment_list.append((disable_end - window_sec, disable_end))
|
||||
|
||||
|
||||
return segment_list, disable_segment_list
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user