import gc
from pathlib import Path
from typing import Iterable, Optional

from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
from tqdm.rich import tqdm

import utils

# TODO(review): original note reads "add with_prediction parameter"; no such
# parameter exists in this module yet.

# Axis index of each panel in the list returned by create_psg_bcg_figure().
# Indices 9 and 10 are the twin (right-hand y) axes of the resp / bcg panels.
psg_bcg_chn_name2ax = {
    "SpO2": 0,
    "Flow T": 1,
    "Flow P": 2,
    "Effort Tho": 3,
    "Effort Abd": 4,
    "HR": 5,
    "resp": 6,
    "bcg": 7,
    "Stage": 8,
    "resp_twinx": 9,
    "bcg_twinx": 10,
}

# Axis index of each panel in the list returned by create_psg_figure().
psg_chn_name2ax = {
    "SpO2": 0,
    "Flow T": 1,
    "Flow P": 2,
    "Effort Tho": 3,
    "Effort Abd": 4,
    "Effort": 5,
    "HR": 6,
    "RRI": 7,
    "Stage": 8,
}

# Axis index of each panel in the list returned by create_resp_figure().
resp_chn_name2ax = {
    "resp": 0,
    "bcg": 1,
}

# Event codes highlighted on labelled channels by the draw_* functions.
_EVENT_CODES = (1, 2, 3, 4)

# Default SpO2 y-range; plt_spo2_on_ax lowers the bottom when the data dips below it.
_SPO2_YLIM = (85, 100)

# (key suffix, colour, legend label) for the signal-quality masks drawn on a twin axis.
# Mask i (1-based) is drawn at height -i so the three masks stack without overlapping.
_QUALITY_MASKS = (
    ("LowAmp_Label", 'blue', 'Low Amplitude Mask'),
    ("Movement_Label", 'orange', 'Movement Mask'),
    ("AmpChange_Label", 'green', 'Amplitude Change Mask'),
)


def _new_stacked_figure(figsize, dpi, height_ratios):
    """Create a figure with one gridded axis per entry of height_ratios, stacked with no vertical gap."""
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = GridSpec(len(height_ratios), 1, height_ratios=height_ratios)
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(gs[i]) for i in range(len(height_ratios))]
    for ax in axes:
        ax.grid(True)
    return fig, axes


def _whiten_xticks(axes):
    """Paint x tick labels white so only the panels left untouched show time labels."""
    for ax in axes:
        ax.tick_params(axis='x', colors="white")


def create_psg_bcg_figure():
    """Build the 9-panel PSG + BCG figure.

    Returns:
        (fig, axes) where axes is indexed by psg_bcg_chn_name2ax: nine stacked
        panels followed by the twin axes of the "resp" and "bcg" panels.
    """
    name2ax = psg_bcg_chn_name2ax
    fig, axes = _new_stacked_figure((12, 8), 200, [1, 1, 1, 1, 1, 1, 3, 2, 1])
    axes[name2ax["SpO2"]].set_ylim(_SPO2_YLIM)
    # Every panel except the bottom "Stage" panel hides its x tick labels.
    _whiten_xticks(axes[name2ax[name]] for name in
                   ("SpO2", "Flow T", "Flow P", "Effort Tho", "Effort Abd", "HR", "resp", "bcg"))
    # Appended after the nine stacked panels, in this order, so they land at the
    # "resp_twinx" (9) and "bcg_twinx" (10) indices.
    axes.append(axes[name2ax["resp"]].twinx())
    axes.append(axes[name2ax["bcg"]].twinx())
    return fig, axes


def create_psg_figure():
    """Build the 9-panel PSG-only figure; axes are indexed by psg_chn_name2ax."""
    name2ax = psg_chn_name2ax
    fig, axes = _new_stacked_figure((12, 8), 200, [1, 1, 1, 1, 1, 1, 1, 1, 1])
    axes[name2ax["SpO2"]].set_ylim(_SPO2_YLIM)
    _whiten_xticks(axes[name2ax[name]] for name in
                   ("SpO2", "Flow T", "Flow P", "Effort Tho", "Effort Abd", "Effort", "HR", "RRI"))
    return fig, axes


def create_resp_figure():
    """Build the 2-panel resp/BCG figure; axes are indexed by resp_chn_name2ax."""
    fig, axes = _new_stacked_figure((12, 6), 100, [3, 2])
    _whiten_xticks(axes[resp_chn_name2ax[name]] for name in ("resp", "bcg"))
    return fig, axes


def _segment_samples(signal_data, segment_start, segment_end):
    """Return (time_axis, samples) for the [segment_start, segment_end) window, in seconds.

    signal_data is a dict with "data" (sample array) and "fs" (samples per
    second, used as a slice multiplier so it is assumed to be an integer).
    Sample k of the window sits at segment_start + k / fs. The axis is built
    from the actual slice length, so a window running past the end of the
    recording stays aligned instead of raising a length mismatch in ax.plot.
    """
    fs = signal_data["fs"]
    samples = np.asarray(signal_data["data"][segment_start * fs:segment_end * fs])
    time_axis = segment_start + np.arange(len(samples)) / fs
    return time_axis, samples


def _label_to_samples(label, segment_start, segment_end, fs, n_samples):
    """Upsample a one-value-per-second label to n_samples points of sample resolution.

    Labels are indexed in seconds (signals are indexed in seconds * fs), so each
    label value is repeated fs times. The result is truncated to n_samples, or
    padded with NaN (never equal to any event code) if the label runs out first.
    """
    per_sample = np.asarray(label[segment_start:segment_end]).repeat(fs)[:n_samples]
    if len(per_sample) < n_samples:
        padded = np.full(n_samples, np.nan)
        padded[:len(per_sample)] = per_sample
        per_sample = padded
    return per_sample


def _plot_quality_masks(ax2, event_mask, prefix, time_axis, expand):
    """Clear ax2 and draw the three `{prefix}_*_Label` quality masks at heights -1, -2, -3."""
    ax2.cla()
    for level, (suffix, color, label) in enumerate(_QUALITY_MASKS, start=1):
        ax2.plot(time_axis, expand(event_mask[f"{prefix}_{suffix}"]) * -level,
                 color=color, alpha=0.8, label=label)


def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: Optional[Iterable[int]] = None, multi_labels=None,
                           ax2: Optional[Axes] = None):
    """Plot one signal window on `ax`, optionally overlaying event annotations.

    Args:
        ax: Target axis.
        signal_data: Dict with "data", "fs" and "name" (used as the legend label).
        segment_start, segment_end: Window bounds in seconds.
        event_mask: With multi_labels=None, a one-value-per-second label array.
            With multi_labels="resp"/"bcg", a dict of such arrays (keys
            "Resp_*"/"BCG_*" quality masks, plus "SA_Label"/"SA_Score" for "resp").
        event_codes: Label values to highlight; nothing is overlaid when None.
        multi_labels: None, "resp" or "bcg" — selects how event_mask is drawn.
        ax2: Twin axis that receives the quality masks; required for "resp"/"bcg".

    Raises:
        ValueError: multi_labels is "resp"/"bcg" with annotations to draw but ax2 is None.
    """
    signal_fs = signal_data["fs"]
    time_axis, samples = _segment_samples(signal_data, segment_start, segment_end)
    ax.plot(time_axis, samples, color='black', label=signal_data["name"])

    def expand(label):
        return _label_to_samples(label, segment_start, segment_end, signal_fs, len(samples))

    def highlight(per_sample_label):
        # np.where keeps genuine zero-valued samples inside an event; the old
        # multiply-then-replace-zeros approach erased them.
        for event_code in event_codes:
            y = np.where(per_sample_label == event_code, samples, np.nan)
            ax.plot(time_axis, y, color=utils.ColorCycle[event_code])

    if event_mask is not None and event_codes is not None:
        if multi_labels is None:
            highlight(expand(event_mask))
        elif multi_labels in ("resp", "bcg"):
            if ax2 is None:
                raise ValueError(f"ax2 is required when multi_labels={multi_labels!r}")
            if multi_labels == "resp":
                ax.set_ylim(-6, 6)
                # Secondary y axis: quality masks plus the SA score trace.
                _plot_quality_masks(ax2, event_mask, "Resp", time_axis, expand)
                highlight(expand(event_mask["SA_Label"]))
                # Loop-invariant: plotted once rather than once per event code.
                ax2.plot(time_axis, expand(event_mask["SA_Score"]), color="orange")
                ax2.set_ylim(-4, 5)
            else:
                # Secondary y axis: quality masks only.
                _plot_quality_masks(ax2, event_mask, "BCG", time_axis, expand)
                ax2.set_ylim(-4, 4)

    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)


def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage trace for the window; y ticks 1..5 are labelled N3, N2, N1, REM, Awake."""
    time_axis, stages = _segment_samples(stage_data, segment_start, segment_end)
    ax.plot(time_axis, stages, color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)


def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
    """Plot SpO2 for the window.

    The y-range is fixed at 85-100 unless the data dips below 85, in which case
    the bottom becomes (lowest finite value - 5). NaN gaps are ignored when
    finding the minimum, and an empty window keeps the default range.
    """
    time_axis, spo2 = _segment_samples(spo2_data, segment_start, segment_end)
    ax.plot(time_axis, spo2, color='black', label=spo2_data["name"])
    finite = spo2[np.isfinite(spo2)]
    floor, ceiling = _SPO2_YLIM
    if finite.size and finite.min() < floor:
        ax.set_ylim((finite.min() - 5, ceiling))
    else:
        ax.set_ylim((floor, ceiling))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)


def score_mask2alpha(score_mask):
    """Map SA scores to alpha values: 1 -> 0.9, 2 -> 0.6, 3 -> 0.1, anything else (incl. <= 0, NaN) -> 0.

    Not called by the drawing functions in this module.
    """
    score_mask = np.asarray(score_mask)
    return np.select([score_mask == 1, score_mask == 2, score_mask == 3], [0.9, 0.6, 0.1], default=0.0)


def _draw_psg_channels(axes, name2ax, psg_data, psg_label, segment_start, segment_end,
                       labeled_channels, plain_channels):
    """Draw the PSG panels shared by both layouts for one window.

    Draws SpO2, stage ("5_class"), the event-highlighted channels and the
    unlabelled channels.
    """
    plt_spo2_on_ax(axes[name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
    plt_stage_on_ax(axes[name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
    for name in labeled_channels:
        plt_signal_label_on_ax(axes[name2ax[name]], psg_data[name], segment_start, segment_end,
                               psg_label, event_codes=_EVENT_CODES)
    for name in plain_channels:
        plt_signal_label_on_ax(axes[name2ax[name]], psg_data[name], segment_start, segment_end)


def _render_segments(fig, axes, segment_list, draw_segment, save_path, multi_p, multi_task_id):
    """Redraw `fig` for every (start, end) segment, optionally saving each frame and reporting progress.

    The figure is closed (and garbage collected) afterwards even if drawing fails.
    """
    segments = list(segment_list)
    target_name = save_path.name if save_path is not None else "(not saved)"
    try:
        for i, (segment_start, segment_end) in enumerate(segments):
            for ax in axes:
                ax.cla()
            draw_segment(segment_start, segment_end)
            if save_path is not None:
                fig.savefig(save_path / f"Segment_{segment_start}_{segment_end}.png")
            if multi_p is not None:
                multi_p[multi_task_id] = {"progress": i + 1, "total": len(segments),
                                          "desc": f"task_id:{multi_task_id} drawing {target_name}"}
    finally:
        plt.close(fig)
        plt.close('all')
        gc.collect()


def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None, verbose=True,
                       multi_p=None, multi_task_id=None):
    """Render one PSG + BCG figure per segment, optionally saving each as a PNG.

    Args:
        psg_data: Dict of PSG channels (each with "data", "fs", "name").
        psg_label: One-value-per-second event labels for the PSG channels.
        bcg_data: Dict with "resp_signal" and "bcg_signal" channels.
        event_mask: Dict of BCG/resp annotation arrays. It is copied, so the
            caller's dict is left unmodified.
        segment_list: Iterable of (start, end) windows in seconds.
        save_path: Directory (str or Path) for Segment_<start>_<end>.png files;
            nothing is saved when None.
        verbose: Accepted for API symmetry with draw_psg_label; currently unused.
        multi_p, multi_task_id: Optional shared progress dict and this task's key in it.
    """
    if save_path is not None:
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

    event_mask = dict(event_mask)
    for key in list(event_mask):
        if key.startswith(("Resp_", "BCG_")):
            event_mask[key] = utils.none_to_nan_mask(event_mask[key], 0)
    event_mask["SA_Score"] = utils.none_to_nan_mask(event_mask["SA_Score"], 0)

    fig, axes = create_psg_bcg_figure()
    name2ax = psg_bcg_chn_name2ax

    def draw_segment(segment_start, segment_end):
        _draw_psg_channels(axes, name2ax, psg_data, psg_label, segment_start, segment_end,
                           labeled_channels=("Flow T", "Flow P", "Effort Tho", "Effort Abd"),
                           plain_channels=("HR",))
        plt_signal_label_on_ax(axes[name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
                               event_mask, multi_labels="resp", event_codes=_EVENT_CODES,
                               ax2=axes[name2ax["resp_twinx"]])
        plt_signal_label_on_ax(axes[name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
                               event_mask, multi_labels="bcg", event_codes=_EVENT_CODES,
                               ax2=axes[name2ax["bcg_twinx"]])

    _render_segments(fig, axes, segment_list, draw_segment, save_path, multi_p, multi_task_id)


def draw_psg_label(psg_data, psg_label, segment_list, save_path=None, verbose=True, multi_p=None,
                   multi_task_id=None):
    """Render one PSG-only figure per segment, optionally saving each as a PNG.

    Args:
        psg_data: Dict of PSG channels (each with "data", "fs", "name").
        psg_label: One-value-per-second event labels for the PSG channels.
        segment_list: Iterable of (start, end) windows in seconds.
        save_path: Directory (str or Path) for Segment_<start>_<end>.png files;
            nothing is saved when None.
        verbose: When True and not running under a progress dict, print
            channel lengths and sampling rates before drawing.
        multi_p, multi_task_id: Optional shared progress dict and this task's key in it.
    """
    if save_path is not None:
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

    if verbose and multi_p is None:
        # Length and sampling rate of every channel, then of the label array.
        for chn_name, chn in psg_data.items():
            print(f"{chn_name} data length: {len(chn['data'])}, fs: {chn['fs']}")
        print(f"psg_label length: {len(psg_label)}")

    fig, axes = create_psg_figure()
    name2ax = psg_chn_name2ax

    def draw_segment(segment_start, segment_end):
        _draw_psg_channels(axes, name2ax, psg_data, psg_label, segment_start, segment_end,
                           labeled_channels=("Flow T", "Flow P", "Effort Tho", "Effort Abd", "Effort"),
                           plain_channels=("HR", "RRI"))

    _render_segments(fig, axes, segment_list, draw_segment, save_path, multi_p, multi_task_id)