修正SA_Score的处理逻辑,移除不必要的Alpha转换,优化数据掩码生成
This commit is contained in:
parent
d829f3e43d
commit
d09ffecf70
@ -1,2 +1,5 @@
|
|||||||
# DataPrepare
|
# DataPrepare
|
||||||
|
## 操作步骤
|
||||||
|
1. 信号预处理
|
||||||
|
2. 数据集构建
|
||||||
|
3. 数据可视化(可选)
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
|
import multiprocessing
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
os.environ['DISPLAY'] = "localhost:10.0"
|
os.environ['DISPLAY'] = "localhost:10.0"
|
||||||
@ -16,21 +16,23 @@ import draw_tools
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False, verbose=True, multi_p=None, multi_task_id=None):
|
||||||
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
|
signal_path = list((org_signal_root_path / f"{samp_id}").glob("OrgBCG_Sync_*.txt"))
|
||||||
if not signal_path:
|
if not signal_path:
|
||||||
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
|
raise FileNotFoundError(f"OrgBCG_Sync file not found for sample ID: {samp_id}")
|
||||||
signal_path = signal_path[0]
|
signal_path = signal_path[0]
|
||||||
|
if verbose:
|
||||||
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
|
print(f"Processing OrgBCG_Sync signal file: {signal_path}")
|
||||||
|
|
||||||
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
|
mask_excel_path = Path(mask_path, f"{samp_id}", f"{samp_id}_Processed_Labels.csv")
|
||||||
|
if verbose:
|
||||||
print(f"mask_excel_path: {mask_excel_path}")
|
print(f"mask_excel_path: {mask_excel_path}")
|
||||||
|
|
||||||
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
|
event_mask, event_list = utils.read_mask_execl(mask_excel_path)
|
||||||
|
|
||||||
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float)
|
bcg_signal_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=verbose)
|
||||||
|
|
||||||
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs)
|
bcg_signal_notch, resp_signal, resp_fs, bcg_signal, bcg_fs = signal_method.signal_filter_split(conf, bcg_signal_raw, signal_fs, verbose=verbose)
|
||||||
normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
|
normalized_resp_signal = signal_method.normalize_resp_signal(resp_signal, resp_fs, event_mask["Resp_Movement_Label"], event_list["RespAmpChangeSegment"])
|
||||||
|
|
||||||
|
|
||||||
@ -63,7 +65,8 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
|||||||
# show=show,
|
# show=show,
|
||||||
# save_path=None)
|
# save_path=None)
|
||||||
|
|
||||||
segment_list = utils.resp_split(dataset_config, event_mask, event_list)
|
segment_list, disable_segment_list = utils.resp_split(dataset_config, event_mask, event_list, verbose=verbose)
|
||||||
|
if verbose:
|
||||||
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
|
print(f"Total segments extracted for sample ID {samp_id}: {len(segment_list)}")
|
||||||
|
|
||||||
|
|
||||||
@ -77,6 +80,7 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
|||||||
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
|
save_sa_label_corrected_path = save_processed_label_path / f"{samp_id}_SA Label_corrected.csv"
|
||||||
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
|
shutil.copyfile(sa_label_corrected_path, save_sa_label_corrected_path)
|
||||||
else:
|
else:
|
||||||
|
if verbose:
|
||||||
print(f"Warning: {sa_label_corrected_path} does not exist.")
|
print(f"Warning: {sa_label_corrected_path} does not exist.")
|
||||||
|
|
||||||
# 保存处理后的信号和截取的片段列表
|
# 保存处理后的信号和截取的片段列表
|
||||||
@ -109,12 +113,14 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
|||||||
|
|
||||||
np.savez_compressed(save_signal_path, **bcg_data)
|
np.savez_compressed(save_signal_path, **bcg_data)
|
||||||
np.savez_compressed(save_segment_path,
|
np.savez_compressed(save_segment_path,
|
||||||
segment_list=segment_list)
|
segment_list=segment_list,
|
||||||
|
disable_segment_list=disable_segment_list)
|
||||||
|
if verbose:
|
||||||
print(f"Saved processed signals to: {save_signal_path}")
|
print(f"Saved processed signals to: {save_signal_path}")
|
||||||
print(f"Saved segment list to: {save_segment_path}")
|
print(f"Saved segment list to: {save_segment_path}")
|
||||||
|
|
||||||
if draw_segment:
|
if draw_segment:
|
||||||
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8])
|
psg_data = utils.read_psg_channel(psg_signal_root_path / f"{samp_id}", [1, 2, 3, 4, 5, 6, 7, 8], verbose=verbose)
|
||||||
psg_data["HR"] = {
|
psg_data["HR"] = {
|
||||||
"name": "HR",
|
"name": "HR",
|
||||||
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
|
"data": signal_method.rpeak2hr(psg_data["Rpeak"]["data"], psg_data["ECG_Sync"]["length"]),
|
||||||
@ -124,14 +130,86 @@ def build_HYS_dataset_segment(samp_id, show=False, draw_segment=False):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
psg_label = utils.read_psg_label(sa_label_corrected_path)
|
psg_label = utils.read_psg_label(sa_label_corrected_path, verbose=verbose)
|
||||||
psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
|
psg_event_mask, _ = utils.generate_event_mask(event_df=psg_label, signal_second=psg_data["ECG_Sync"]["second"], use_correct=False)
|
||||||
|
total_len = len(segment_list) + len(disable_segment_list)
|
||||||
|
if verbose:
|
||||||
|
print(f"Drawing segments for sample ID {samp_id}, total segments (enable + disable): {total_len}")
|
||||||
|
|
||||||
draw_tools.draw_psg_bcg_label(psg_data=psg_data,
|
draw_tools.draw_psg_bcg_label(psg_data=psg_data,
|
||||||
psg_label=psg_event_mask,
|
psg_label=psg_event_mask,
|
||||||
bcg_data=bcg_data,
|
bcg_data=bcg_data,
|
||||||
event_mask=event_mask,
|
event_mask=event_mask,
|
||||||
segment_list=segment_list,
|
segment_list=segment_list,
|
||||||
save_path=visual_path / f"{samp_id}")
|
save_path=visual_path / f"{samp_id}" / "enable",
|
||||||
|
verbose=verbose,
|
||||||
|
multi_p=multi_p,
|
||||||
|
multi_task_id=multi_task_id
|
||||||
|
)
|
||||||
|
|
||||||
|
draw_tools.draw_psg_bcg_label(
|
||||||
|
psg_data=psg_data,
|
||||||
|
psg_label=psg_event_mask,
|
||||||
|
bcg_data=bcg_data,
|
||||||
|
event_mask=event_mask,
|
||||||
|
segment_list=disable_segment_list,
|
||||||
|
save_path=visual_path / f"{samp_id}" / "disable",
|
||||||
|
verbose=verbose,
|
||||||
|
multi_p=multi_p,
|
||||||
|
multi_task_id=multi_task_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def multiprocess_entry(_progress, task_id, _id):
|
||||||
|
build_HYS_dataset_segment(samp_id=_id, show=False, draw_segment=True, verbose=False, multi_p=_progress, multi_task_id=task_id)
|
||||||
|
|
||||||
|
|
||||||
|
def multiprocess_with_tqdm(args_list, n_processes):
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
from rich import progress
|
||||||
|
|
||||||
|
with progress.Progress(
|
||||||
|
"[progress.description]{task.description}",
|
||||||
|
progress.BarColumn(),
|
||||||
|
"[progress.percentage]{task.percentage:>3.0f}%",
|
||||||
|
progress.MofNCompleteColumn(),
|
||||||
|
progress.TimeRemainingColumn(),
|
||||||
|
progress.TimeElapsedColumn(),
|
||||||
|
refresh_per_second=1, # bit slower updates
|
||||||
|
transient=False
|
||||||
|
) as progress:
|
||||||
|
futures = []
|
||||||
|
with multiprocessing.Manager() as manager:
|
||||||
|
_progress = manager.dict()
|
||||||
|
overall_progress_task = progress.add_task("[green]All jobs progress:")
|
||||||
|
with ProcessPoolExecutor(max_workers=n_processes) as executor:
|
||||||
|
for i_args in range(len(args_list)):
|
||||||
|
args = args_list[i_args]
|
||||||
|
task_id = progress.add_task(f"task {i_args}", visible=True)
|
||||||
|
futures.append(executor.submit(multiprocess_entry, _progress, task_id, args_list[i_args]))
|
||||||
|
# monitor the progress:
|
||||||
|
while (n_finished := sum([future.done() for future in futures])) < len(
|
||||||
|
futures
|
||||||
|
):
|
||||||
|
progress.update(
|
||||||
|
overall_progress_task, completed=n_finished, total=len(futures)
|
||||||
|
)
|
||||||
|
for task_id, update_data in _progress.items():
|
||||||
|
desc = update_data.get("desc", "")
|
||||||
|
# update the progress bar for this task:
|
||||||
|
progress.update(
|
||||||
|
task_id,
|
||||||
|
completed=update_data.get("progress", 0),
|
||||||
|
total=update_data.get("total", 0),
|
||||||
|
description=desc
|
||||||
|
)
|
||||||
|
|
||||||
|
# raise any errors:
|
||||||
|
for future in futures:
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -156,16 +234,18 @@ if __name__ == '__main__':
|
|||||||
save_processed_label_path = save_path / "Labels"
|
save_processed_label_path = save_path / "Labels"
|
||||||
save_processed_label_path.mkdir(parents=True, exist_ok=True)
|
save_processed_label_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"select_ids: {select_ids}")
|
# print(f"select_ids: {select_ids}")
|
||||||
print(f"root_path: {root_path}")
|
# print(f"root_path: {root_path}")
|
||||||
print(f"save_path: {save_path}")
|
# print(f"save_path: {save_path}")
|
||||||
print(f"visual_path: {visual_path}")
|
# print(f"visual_path: {visual_path}")
|
||||||
|
|
||||||
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
org_signal_root_path = root_path / "OrgBCG_Aligned"
|
||||||
psg_signal_root_path = root_path / "PSG_Aligned"
|
psg_signal_root_path = root_path / "PSG_Aligned"
|
||||||
|
|
||||||
# build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
build_HYS_dataset_segment(select_ids[0], show=False, draw_segment=True)
|
||||||
|
|
||||||
for samp_id in select_ids:
|
# for samp_id in select_ids:
|
||||||
print(f"Processing sample ID: {samp_id}")
|
# print(f"Processing sample ID: {samp_id}")
|
||||||
build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
# build_HYS_dataset_segment(samp_id, show=False, draw_segment=True)
|
||||||
|
|
||||||
|
# multiprocess_with_tqdm(args_list=select_ids, n_processes=16)
|
||||||
@ -19,8 +19,8 @@ resp:
|
|||||||
|
|
||||||
resp_filter:
|
resp_filter:
|
||||||
filter_type: bandpass
|
filter_type: bandpass
|
||||||
low_cut: 0.01
|
low_cut: 0.05
|
||||||
high_cut: 0.7
|
high_cut: 0.6
|
||||||
order: 3
|
order: 3
|
||||||
|
|
||||||
resp_low_amp:
|
resp_low_amp:
|
||||||
|
|||||||
@ -72,7 +72,6 @@ def create_psg_bcg_figure():
|
|||||||
axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
|
axes[psg_chn_name2ax["bcg"]].tick_params(axis='x', colors="white")
|
||||||
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
|
axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
|
||||||
|
|
||||||
|
|
||||||
axes[psg_chn_name2ax["Stage"]].grid(True)
|
axes[psg_chn_name2ax["Stage"]].grid(True)
|
||||||
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
# axes[7].xaxis.set_major_formatter(Params.FORMATTER)
|
||||||
|
|
||||||
@ -183,7 +182,8 @@ def score_mask2alpha(score_mask):
|
|||||||
return alpha_mask
|
return alpha_mask
|
||||||
|
|
||||||
|
|
||||||
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
|
def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
|
||||||
|
multi_p=None, multi_task_id=None):
|
||||||
if save_path is not None:
|
if save_path is not None:
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@ -197,7 +197,7 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
|
|||||||
# event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
|
# event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)
|
||||||
|
|
||||||
fig, axes = create_psg_bcg_figure()
|
fig, axes = create_psg_bcg_figure()
|
||||||
for segment_start, segment_end in tqdm(segment_list):
|
for i, (segment_start, segment_end) in enumerate(segment_list):
|
||||||
for ax in axes:
|
for ax in axes:
|
||||||
ax.cla()
|
ax.cla()
|
||||||
|
|
||||||
@ -213,17 +213,21 @@ def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list,
|
|||||||
psg_label, event_codes=[1, 2, 3, 4])
|
psg_label, event_codes=[1, 2, 3, 4])
|
||||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
|
plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
|
||||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
|
plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
|
||||||
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
|
event_mask, multi_labels="resp", event_codes=[1, 2, 3, 4],
|
||||||
|
ax2=axes[psg_chn_name2ax["resp_twinx"]])
|
||||||
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
|
plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
|
||||||
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])
|
event_mask, multi_labels="bcg", event_codes=[1, 2, 3, 4],
|
||||||
|
ax2=axes[psg_chn_name2ax["bcg_twinx"]])
|
||||||
|
|
||||||
|
|
||||||
if save_path is not None:
|
if save_path is not None:
|
||||||
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
|
fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
|
||||||
tqdm.write(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
|
||||||
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
# print(f"Saved figure to: {save_path / f'Segment_{segment_start}_{segment_end}.png'}")
|
||||||
|
|
||||||
|
if multi_p is not None:
|
||||||
|
multi_p[multi_task_id] = {"progress": i + 1, "total": len(segment_list), "desc": f"task_id:{multi_task_id} drawing {save_path.name}"}
|
||||||
|
|
||||||
|
|
||||||
def draw_resp_label(resp_data, resp_label, segment_list):
|
def draw_resp_label(resp_data, resp_label, segment_list):
|
||||||
for mask in resp_label.keys():
|
for mask in resp_label.keys():
|
||||||
if mask.startswith("Resp_"):
|
if mask.startswith("Resp_"):
|
||||||
|
|||||||
@ -52,7 +52,7 @@ def process_one_signal(samp_id, show=False):
|
|||||||
save_samp_path = save_path / f"{samp_id}"
|
save_samp_path = save_path / f"{samp_id}"
|
||||||
save_samp_path.mkdir(parents=True, exist_ok=True)
|
save_samp_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, verbose=True)
|
signal_data_raw, signal_length, signal_fs, signal_second = utils.read_signal_txt(signal_path, dtype=float, verbose=True)
|
||||||
|
|
||||||
signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)
|
signal_data, resp_data, resp_fs, bcg_data, bcg_fs = signal_method.signal_filter_split(conf=conf, signal_data_raw=signal_data_raw, signal_fs=signal_fs)
|
||||||
|
|
||||||
|
|||||||
@ -2,10 +2,11 @@ import numpy as np
|
|||||||
|
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
def signal_filter_split(conf, signal_data_raw, signal_fs):
|
def signal_filter_split(conf, signal_data_raw, signal_fs, verbose=True):
|
||||||
# 滤波
|
# 滤波
|
||||||
# 50Hz陷波滤波器
|
# 50Hz陷波滤波器
|
||||||
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
|
# signal_data = utils.butterworth(data=signal_data, _type="bandpass", low_cut=0.5, high_cut=45, order=10, sample_rate=signal_fs)
|
||||||
|
if verbose:
|
||||||
print("Applying 50Hz notch filter...")
|
print("Applying 50Hz notch filter...")
|
||||||
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
|
signal_data = utils.notch_filter(data=signal_data_raw, notch_freq=50.0, quality_factor=30.0, sample_rate=signal_fs)
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ def signal_filter_split(conf, signal_data_raw, signal_fs):
|
|||||||
low_cut=conf["resp_filter"]["low_cut"],
|
low_cut=conf["resp_filter"]["low_cut"],
|
||||||
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
|
high_cut=conf["resp_filter"]["high_cut"], order=conf["resp_filter"]["order"],
|
||||||
sample_rate=resp_fs)
|
sample_rate=resp_fs)
|
||||||
|
if verbose:
|
||||||
print("Begin plotting signal data...")
|
print("Begin plotting signal data...")
|
||||||
|
|
||||||
# fig = plt.figure(figsize=(12, 8))
|
# fig = plt.figure(figsize=(12, 8))
|
||||||
|
|||||||
@ -20,6 +20,7 @@ def read_signal_txt(path: Union[str, Path], dtype, verbose=True, is_peak=False):
|
|||||||
Read a txt file and return the first column as a numpy array.
|
Read a txt file and return the first column as a numpy array.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
:param is_peak:
|
||||||
:param path:
|
:param path:
|
||||||
:param verbose:
|
:param verbose:
|
||||||
:param dtype:
|
:param dtype:
|
||||||
@ -217,14 +218,15 @@ def read_mask_execl(path: Union[str, Path]):
|
|||||||
|
|
||||||
event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
|
event_list = {"RespAmpChangeSegment": event_mask_2_list(1 - event_mask["Resp_AmpChange_Label"]),
|
||||||
"BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
|
"BCGAmpChangeSegment": event_mask_2_list(1 - event_mask["BCG_AmpChange_Label"]),
|
||||||
"EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),}
|
"EnableSegment": event_mask_2_list(1 - event_mask["Disable_Label"]),
|
||||||
|
"DisableSegment": event_mask_2_list(event_mask["Disable_Label"])}
|
||||||
|
|
||||||
|
|
||||||
return event_mask, event_list
|
return event_mask, event_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
|
def read_psg_channel(path_str: Union[str, Path], channel_number: list[int], verbose=True):
|
||||||
"""
|
"""
|
||||||
读取PSG文件中特定通道的数据。
|
读取PSG文件中特定通道的数据。
|
||||||
|
|
||||||
@ -254,16 +256,16 @@ def read_psg_channel(path_str: Union[str, Path], channel_number: list[int]):
|
|||||||
|
|
||||||
if ch_id == 8:
|
if ch_id == 8:
|
||||||
# sleep stage 特例 读取为整数
|
# sleep stage 特例 读取为整数
|
||||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=True)
|
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=str, verbose=verbose)
|
||||||
# 转换为整数数组
|
# 转换为整数数组
|
||||||
for stage_str, stage_number in utils.Stage2N.items():
|
for stage_str, stage_number in utils.Stage2N.items():
|
||||||
np.place(ch_signal, ch_signal == stage_str, stage_number)
|
np.place(ch_signal, ch_signal == stage_str, stage_number)
|
||||||
ch_signal = ch_signal.astype(int)
|
ch_signal = ch_signal.astype(int)
|
||||||
elif ch_id == 1:
|
elif ch_id == 1:
|
||||||
# Rpeak 特例 读取为整数
|
# Rpeak 特例 读取为整数
|
||||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=True, is_peak=True)
|
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=int, verbose=verbose, is_peak=True)
|
||||||
else:
|
else:
|
||||||
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=True)
|
ch_signal, ch_length, ch_fs, ch_second = read_signal_txt(ch_path[0], dtype=float, verbose=verbose)
|
||||||
channel_data[ch_name] = {
|
channel_data[ch_name] = {
|
||||||
"name": ch_name,
|
"name": ch_name,
|
||||||
"path": ch_path[0],
|
"path": ch_path[0],
|
||||||
|
|||||||
@ -1,27 +1,58 @@
|
|||||||
|
|
||||||
|
def check_split(event_mask, current_start, window_sec, verbose=False):
|
||||||
|
# 检查当前窗口是否包含在禁用区间或低幅值区间内
|
||||||
|
resp_movement_mask = event_mask["Resp_Movement_Label"][current_start : current_start + window_sec]
|
||||||
|
resp_low_amp_mask = event_mask["Resp_LowAmp_Label"][current_start : current_start + window_sec]
|
||||||
|
|
||||||
|
# 体动与低幅值进行与计算
|
||||||
|
low_move_mask = resp_movement_mask | resp_low_amp_mask
|
||||||
|
if low_move_mask.sum() > 2/3 * window_sec:
|
||||||
|
if verbose:
|
||||||
|
print(f"{current_start}-{current_start + window_sec} rejected due to movement/low amplitude mask more than 2/3")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def resp_split(dataset_config, event_mask, event_list, verbose=False):
|
||||||
def resp_split(dataset_config, event_mask, event_list):
|
|
||||||
# 提取体动区间和呼吸低幅值区间
|
# 提取体动区间和呼吸低幅值区间
|
||||||
enable_list = event_list["EnableSegment"]
|
enable_list = event_list["EnableSegment"]
|
||||||
|
disable_list = event_list["DisableSegment"]
|
||||||
|
|
||||||
# 读取数据集配置
|
# 读取数据集配置
|
||||||
window_sec = dataset_config["window_sec"]
|
window_sec = dataset_config["window_sec"]
|
||||||
stride_sec = dataset_config["stride_sec"]
|
stride_sec = dataset_config["stride_sec"]
|
||||||
|
|
||||||
segment_list = []
|
segment_list = []
|
||||||
|
disable_segment_list = []
|
||||||
|
|
||||||
# 遍历每个enable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以enable_end为结尾截取一个窗口
|
# 遍历每个enable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以enable_end为结尾截取一个窗口
|
||||||
for enable_start, enable_end in enable_list:
|
for enable_start, enable_end in enable_list:
|
||||||
current_start = enable_start
|
current_start = enable_start
|
||||||
while current_start + window_sec <= enable_end:
|
while current_start + window_sec <= enable_end:
|
||||||
|
if check_split(event_mask, current_start, window_sec, verbose):
|
||||||
segment_list.append((current_start, current_start + window_sec))
|
segment_list.append((current_start, current_start + window_sec))
|
||||||
|
else:
|
||||||
|
disable_segment_list.append((current_start, current_start + window_sec))
|
||||||
current_start += stride_sec
|
current_start += stride_sec
|
||||||
# 检查最后一个窗口是否需要添加
|
# 检查最后一个窗口是否需要添加
|
||||||
if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec):
|
if (enable_end - current_start >= stride_sec / 2) and (enable_end - current_start >= window_sec):
|
||||||
|
if check_split(event_mask, enable_end - window_sec, window_sec, verbose):
|
||||||
segment_list.append((enable_end - window_sec, enable_end))
|
segment_list.append((enable_end - window_sec, enable_end))
|
||||||
|
else:
|
||||||
|
disable_segment_list.append((enable_end - window_sec, enable_end))
|
||||||
|
|
||||||
return segment_list
|
# 遍历每个disable区间, 如果最后一个窗口不足stride的1/2,则舍弃,否则以disable_end为结尾截取一个窗口
|
||||||
|
for disable_start, disable_end in disable_list:
|
||||||
|
current_start = disable_start
|
||||||
|
while current_start + window_sec <= disable_end:
|
||||||
|
disable_segment_list.append((current_start, current_start + window_sec))
|
||||||
|
current_start += stride_sec
|
||||||
|
# 检查最后一个窗口是否需要添加
|
||||||
|
if (disable_end - current_start >= stride_sec / 2) and (disable_end - current_start >= window_sec):
|
||||||
|
disable_segment_list.append((disable_end - window_sec, disable_end))
|
||||||
|
|
||||||
|
|
||||||
|
return segment_list, disable_segment_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user