Fix the SA_Score handling logic, remove the unnecessary Alpha conversion, and improve data-mask generation

marques 2025-12-30 16:54:45 +08:00
parent d829f3e43d
commit d09ffecf70
8 changed files with 170 additions and 48 deletions

View File

@@ -1,2 +1,5 @@
# DataPrepare
+## Steps
+1. Signal preprocessing
+2. Dataset construction
+3. Data visualization (optional)

View File

@@ -1,8 +1,8 @@
+import multiprocessing
import sys
from pathlib import Path
import os
import numpy as np
os.environ['DISPLAY'] = "localhost:10.0"
@@ -16,21 +16,23 @@ import draw_tools
import shutil


-def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
+def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
    signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
    if not signal_path:
        raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
    signal_path = signal_path[0]
+    if verbose:
        print(f"Processing OrgBCG_Sync signal file: {signal_path}")
    mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
+    if verbose:
        print(f"mask_excel_path: {mask_excel_path}")
    event_mask, event_list = utils.read_mask_execl(mask_excel_path)
-    bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float)
+    bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
-    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs)
+    bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
    normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
@@ -63,7 +65,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
    # show=show,
    # save_path=None)
-    segment_list = utils.resp_split(dataset_config, event_mask, event_list)
+    segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
+    if verbose:
        print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
@@ -77,6 +80,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
        save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
        shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
    else:
+        if verbose:
            print(f"Warning: {sa_label_corrected_path} does not exist.")
    # Save the processed signals and the extracted segment list
@@ -109,12 +113,14 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
    np.savez_compressed(save_signal_path, **bcg_data)
    np.savez_compressed(save_segment_path,
-                        segment_list=segment_list)
+                        segment_list=segment_list,
+                        disable_segment_list=disable_segment_list)
+    if verbose:
        print(f"Saved processed signals to: {save_signal_path}")
        print(f"Saved segment list to: {save_segment_path}")
    if draw_segment:
-        psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8])
+        psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
        psg_data["HR"] = {
            "name": "HR",
            "data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
@@ -124,14 +130,86 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
        }
-        psg_label = utils.read_psg_label(sa_label_corrected_path)
+        psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
        psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
+        total_len = len(segment_list) + len(disable_segment_list)
+        if verbose:
+            print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
        draw_tools.draw_psg_bcg_label(psg_data=psg_data,
                                      psg_label=psg_event_mask,
                                      bcg_data=bcg_data,
                                      event_mask=event_mask,
                                      segment_list=segment_list,
-                                      save_path=visual_path / f"{samp_id}")
+                                      save_path=visual_path / f"{samp_id}" / "enable",
+                                      verbose=verbose,
+                                      multi_p=multi_p,
+                                      multi_task_id=multi_task_id
+                                      )
+        draw_tools.draw_psg_bcg_label(
+            psg_data=psg_data,
+            psg_label=psg_event_mask,
+            bcg_data=bcg_data,
+            event_mask=event_mask,
+            segment_list=disable_segment_list,
+            save_path=visual_path / f"{samp_id}" / "disable",
+            verbose=verbose,
+            multi_p=multi_p,
+            multi_task_id=multi_task_id
+        )
+def multiprocess_entry(_progress, task_id, _id):
+    build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
+
+
+def multiprocess_with_tqdm(args_list, n_processes):
+    from concurrent.futures import ProcessPoolExecutor
+    from rich import progress
+    with progress.Progress(
+            "[progress.description]{task.description}",
+            progress.BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            progress.MofNCompleteColumn(),
+            progress.TimeRemainingColumn(),
+            progress.TimeElapsedColumn(),
+            refresh_per_second=1,  # bit slower updates
+            transient=False
+    ) as progress:
+        futures = []
+        with multiprocessing.Manager() as manager:
+            _progress = manager.dict()
+            overall_progress_task = progress.add_task("[green]All jobs progress:")
+            with ProcessPoolExecutor(max_workers=n_processes) as executor:
+                for i_args in range(len(args_list)):
+                    args = args_list[i_args]
+                    task_id = progress.add_task(f"task {i_args}", visible=True)
+                    futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
+                # monitor the progress:
+                while (n_finished := sum([future.done() for future in futures])) < len(futures):
+                    progress.update(
+                        overall_progress_task, completed=n_finished, total=len(futures)
+                    )
+                    for task_id, update_data in _progress.items():
+                        desc = update_data.get("desc", "")
+                        # update the progress bar for this task:
+                        progress.update(
+                            task_id,
+                            completed=update_data.get("progress", 0),
+                            total=update_data.get("total", 0),
+                            description=desc
+                        )
+                # raise any errors:
+                for future in futures:
+                    future.result()
if __name__ == '__main__':
@@ -156,16 +234,18 @@ if __name__ == '__main__':
    save_processed_label_path = save_path / "Labels"
    save_processed_label_path.mkdir(parents=True, exist_ok=True)
-    print(f"select_ids: {select_ids}")
-    print(f"root_path: {root_path}")
-    print(f"save_path: {save_path}")
-    print(f"visual_path: {visual_path}")
+    # print(f"select_ids: {select_ids}")
+    # print(f"root_path: {root_path}")
+    # print(f"save_path: {save_path}")
+    # print(f"visual_path: {visual_path}")
    org_signal_root_path = root_path / "OrgBCG_Aligned"
    psg_signal_root_path = root_path / "PSG_Aligned"
-    # build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
-    for samp_id in select_ids:
-        print(f"Processing sample ID: {samp_id}")
-        build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
+    build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
+    # for samp_id in select_ids:
+    #     print(f"Processing sample ID: {samp_id}")
+    #     build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
+    # multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
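
Side note on the new multiprocessing path: `multiprocess_with_tqdm` only reads per-task state from the shared `Manager` dict, so whatever runs in a worker is expected to write `progress`, `total`, and `desc` keys, as `draw_psg_bcg_label` now does through `multi_p`. A minimal sketch of that reporting convention follows; `example_worker` is hypothetical and not part of the repo.

```python
# Hypothetical worker illustrating the shared-dict convention the monitor loop reads.
# Real workers go through multiprocess_entry -> build_HYS_dataset_segment -> draw_psg_bcg_label.
def example_worker(_progress, task_id, items):
    for i, item in enumerate(items):
        ...  # per-item work would happen here
        _progress[task_id] = {
            "progress": i + 1,        # completed count shown on this task's bar
            "total": len(items),      # bar total
            "desc": f"task_id:{task_id} processing {item}",  # bar description
        }

# The driver itself is invoked as in the (currently commented) __main__ line:
# multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
```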

View File

@@ -19,8 +19,8 @@ resp:
  resp_filter:
    filter_type: bandpass
-    low_cut: 0.01
-    high_cut: 0.7
+    low_cut: 0.05
+    high_cut: 0.6
    order: 3
  resp_low_amp:
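
The respiration band is narrowed here from 0.01–0.7 Hz to 0.05–0.6 Hz; these values are passed to `utils.butterworth` in `signal_filter_split`. A minimal sketch of an equivalent zero-phase band-pass, assuming a SciPy-based implementation (the sampling rate and test signal are illustrative, and the repo's own helper may differ in detail):

```python
import numpy as np
from scipy import signal


def bandpass(data, low_cut, high_cut, order, sample_rate):
    # Zero-phase Butterworth band-pass with the same parameters as conf["resp_filter"]
    sos = signal.butter(order, [low_cut, high_cut], btype="bandpass", fs=sample_rate, output="sos")
    return signal.sosfiltfilt(sos, data)


resp_fs = 100                              # hypothetical respiration sampling rate
resp_raw = np.random.randn(resp_fs * 60)   # one minute of fake signal
resp_filtered = bandpass(resp_raw, low_cut=0.05, high_cut=0.6, order=3, sample_rate=resp_fs)
```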

View File

@@ -72,7 +72,6 @@ def create_psg_bcg_figure():
    axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
    axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
    axes[psg_chn_name2ax["Stage"]].grid(True)
    # axes[7].xaxis.set_major_formatter(Params.FORMATTER)
@@ -183,7 +182,8 @@ def score_mask2alpha(score_mask):
    return alpha_mask


-def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
+def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
+                       multi_p=None, multi_task_id=None):
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)
@@ -197,7 +197,7 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
-    event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
+    # event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
    fig, axes = create_psg_bcg_figure()
-    for segment_start, segment_end in tqdm(segment_list):
+    for i, (segment_start, segment_end) in enumerate(segment_list):
        for ax in axes:
            ax.cla()
@@ -213,17 +213,21 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
                               psg_label, event_codes=[1, 2, 3, 4])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
-                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
+                               event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
+                               ax2=axes[psg_chn_name2ax["resp_twinx"]])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
-                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])
+                               event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
+                               ax2=axes[psg_chn_name2ax["bcg_twinx"]])
        if save_path is not None:
            fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
-            tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
            # print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
+        if multi_p is not None:
+            multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
def draw_resp_label(resp_data, resp_label, segment_list): def draw_resp_label(resp_data, resp_label, segment_list):
for mask in resp_label.keys(): for mask in resp_label.keys():
if mask.startswith("Resp_"): if mask.startswith("Resp_"):

View File

@@ -52,7 +52,7 @@ def process_one_signal(samp_id, show=False):
    save_samp_path = save_path / f"{samp_id}"
    save_samp_path.mkdir(parents=True, exist_ok=True)
-    signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True)
+    signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True)
    signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)

View File

@@ -2,10 +2,11 @@ import numpy as np
import utils


-def signal_filter_split(conf, signal_data_raw, signal_fs):
+def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
    # Filtering
    # 50 Hz notch filter
    # signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
+    if verbose:
        print("Applying 50Hz notch filter...")
    signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
@@ -17,6 +18,7 @@ def signal_filter_split(conf, signal_data_raw, signal_fs):
                                   low_cut=conf["resp_filter"]["low_cut"],
                                   high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
                                   sample_rate=resp_fs)
+    if verbose:
        print("Begin plotting signal data...")
    # fig = plt.figure(figsize=(12, 8))
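
For reference, the 50 Hz notch applied before the respiration/BCG split corresponds to a standard IIR notch. A hedged sketch of what `utils.notch_filter` is assumed to do, mirroring the call above (`notch_freq=50.0`, `quality_factor=30.0`); the actual helper may be implemented differently:

```python
import numpy as np
from scipy import signal


def notch_filter_sketch(data, notch_freq, quality_factor, sample_rate):
    # Assumed shape of the mains notch: a narrow IIR notch applied forward-backward
    b, a = signal.iirnotch(w0=notch_freq, Q=quality_factor, fs=sample_rate)
    return signal.filtfilt(b, a, data)


fs = 1000                       # hypothetical raw BCG sampling rate
raw = np.random.randn(fs * 10)  # ten seconds of fake signal
clean = notch_filter_sketch(raw, notch_freq=50.0, quality_factor=30.0, sample_rate=fs)
```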

View File

@@ -20,6 +20,7 @@ def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
    Read a txt file and return the first column as a numpy array.
    Args:
+    :param is_peak:
    :param path:
    :param verbose:
    :param dtype:
@@ -217,14 +218,15 @@ def read_mask_execl(path: Union[str, Path]):
    event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
                  "BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
-                  "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),}
+                  "EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
+                  "DisableSegment": event_mask_2_list(event_mask["Disable_Label"])}
    return event_mask, event_list
-def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
+def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
    """
    Read the data of specific channels from the PSG files
@@ -254,16 +256,16 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
        if ch_id == 8:
            # Special case: sleep stage is read as integers
-            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True)
+            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
            # Convert to an integer array
            for stage_str, stage_number in utils.Stage2N.items():
                np.place(ch_signal, ch_signal == stage_str, stage_number)
            ch_signal = ch_signal.astype(int)
        elif ch_id == 1:
            # Special case: Rpeak is read as integers
-            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True)
+            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
        else:
-            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True)
+            ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)
        channel_data[ch_name] = {
            "name": ch_name,
            "path": ch_path[0],

View File

@@ -1,27 +1,58 @@
+def check_split(event_mask, current_start, window_sec, verbose=False):
+    # Check whether the current window falls inside a disabled or low-amplitude interval
+    resp_movement_mask = event_mask["Resp_Movement_Label"][current_start: current_start + window_sec]
+    resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start: current_start + window_sec]
+    # Combine the body-movement mask with the low-amplitude mask
+    low_move_mask = resp_movement_mask | resp_low_amp_mask
+    if low_move_mask.sum() > 2 / 3 * window_sec:
+        if verbose:
+            print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3")
+        return False
+    return True
+
+
-def resp_split(dataset_config, event_mask, event_list):
+def resp_split(dataset_config, event_mask, event_list, verbose=False):
    # Extract the body-movement and low-amplitude respiration intervals
    enable_list = event_list["EnableSegment"]
+    disable_list = event_list["DisableSegment"]
    # Read the dataset configuration
    window_sec = dataset_config["window_sec"]
    stride_sec = dataset_config["stride_sec"]
    segment_list = []
+    disable_segment_list = []
    # Iterate over each enable interval; if the last window covers less than 1/2 of the stride, discard it, otherwise take one window ending at enable_end
    for enable_start, enable_end in enable_list:
        current_start = enable_start
        while current_start + window_sec <= enable_end:
+            if check_split(event_mask, current_start, window_sec, verbose):
                segment_list.append((current_start, current_start + window_sec))
+            else:
+                disable_segment_list.append((current_start, current_start + window_sec))
            current_start += stride_sec
        # Check whether the last window should be added
        if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec):
+            if check_split(event_mask, enable_end - window_sec, window_sec, verbose):
                segment_list.append((enable_end - window_sec, enable_end))
+            else:
+                disable_segment_list.append((enable_end - window_sec, enable_end))
-    return segment_list
+    # Iterate over each disable interval; if the last window covers less than 1/2 of the stride, discard it, otherwise take one window ending at disable_end
+    for disable_start, disable_end in disable_list:
+        current_start = disable_start
+        while current_start + window_sec <= disable_end:
+            disable_segment_list.append((current_start, current_start + window_sec))
+            current_start += stride_sec
+        # Check whether the last window should be added
+        if (disable_end - current_start >= stride_sec / 2) and (disable_end - current_start >= window_sec):
+            disable_segment_list.append((disable_end - window_sec, disable_end))
+    return segment_list, disable_segment_list