from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
from tqdm.rich import tqdm

import utils

# TODO: add a `with_prediction` parameter.

# Index of each channel's Axes in the list returned by create_psg_bcg_figure().
# "resp_twinx" / "bcg_twinx" are secondary y-axes created with twinx() on the
# "resp" / "bcg" rows.
psg_chn_name2ax = {
    "SpO2": 0,
    "Flow T": 1,
    "Flow P": 2,
    "Effort Tho": 3,
    "Effort Abd": 4,
    "HR": 5,
    "resp": 6,
    "bcg": 7,
    "Stage": 8,
    "resp_twinx": 9,
    "bcg_twinx": 10,
}

# Index of each channel's Axes in the list returned by create_resp_figure().
resp_chn_name2ax = {
    "resp": 0,
    "bcg": 1,
    "resp_twinx": 2,
    "bcg_twinx": 3,
}

# Number of stacked (non-twin) rows in the PSG/BCG figure.
_PSG_N_ROWS = 9


def _style_psg_axes(axes):
    """(Re-)apply grid lines and whitened x tick labels to the stacked PSG/BCG rows.

    Axes.cla() resets grid/tick styling, so this is called both when the figure
    is created and after every clear. Only the bottom "Stage" row keeps visible
    x tick labels; twin axes are left untouched.
    """
    stage_idx = psg_chn_name2ax["Stage"]
    for i in range(_PSG_N_ROWS):
        axes[i].grid(True)
        if i != stage_idx:
            axes[i].tick_params(axis='x', colors="white")


def _style_resp_axes(axes):
    """(Re-)apply grid lines and whitened x tick labels to both rows of the resp figure."""
    for key in ("resp", "bcg"):
        ax = axes[resp_chn_name2ax[key]]
        ax.grid(True)
        ax.tick_params(axis='x', colors="white")


def create_psg_bcg_figure():
    """Create the stacked PSG + BCG figure.

    Returns:
        (fig, axes): ``axes`` is indexed via ``psg_chn_name2ax``. It holds nine
        stacked rows (SpO2 ... Stage) followed by twin y-axes for the "resp"
        and "bcg" rows.
    """
    fig = plt.figure(figsize=(12, 8), dpi=200)
    gs = GridSpec(_PSG_N_ROWS, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)

    axes = [fig.add_subplot(gs[i]) for i in range(_PSG_N_ROWS)]
    _style_psg_axes(axes)
    axes[psg_chn_name2ax["SpO2"]].set_ylim((85, 100))

    # Twins are appended after the stacked rows so their positions match
    # psg_chn_name2ax["resp_twinx"] (9) and psg_chn_name2ax["bcg_twinx"] (10).
    axes.append(axes[psg_chn_name2ax["resp"]].twinx())
    axes.append(axes[psg_chn_name2ax["bcg"]].twinx())
    return fig, axes


def create_resp_figure():
    """Create the two-row resp/BCG figure.

    Returns:
        (fig, axes): ``axes`` is indexed via ``resp_chn_name2ax``. It holds the
        "resp" and "bcg" rows followed by a twin y-axis for each. The twins are
        needed by plt_signal_label_on_ax's multi-label mode.
    """
    fig = plt.figure(figsize=(12, 6), dpi=100)
    gs = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)

    axes = [fig.add_subplot(gs[i]) for i in range(2)]
    _style_resp_axes(axes)
    axes.append(axes[resp_chn_name2ax["resp"]].twinx())
    axes.append(axes[resp_chn_name2ax["bcg"]].twinx())
    return fig, axes


def _segment_mask(mask, segment_start, segment_end, fs, n_samples):
    """Slice ``mask[segment_start:segment_end]`` and repeat each entry ``fs`` times.

    This lines the mask up sample-for-sample with a signal slice of length
    ``n_samples``. The result is trimmed to that length in case the signal
    slice was shorter (segment near the end of the recording).
    """
    return np.asarray(mask[segment_start:segment_end]).repeat(fs)[:n_samples]


def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None, ax2: Axes = None):
    """Plot one signal channel for a segment, with optional event/mask overlays.

    Args:
        ax: Axes the signal is drawn on.
        signal_data: dict with "fs" (used as an integer slice/repeat factor),
            "data" (sample array) and "name" (legend label).
        segment_start, segment_end: segment bounds in units of ``fs`` samples
            (seconds if ``fs`` is in Hz). Samples
            ``[segment_start * fs, segment_end * fs)`` are drawn.
        event_mask: when ``multi_labels`` is None, a label array with one entry
            per ``fs`` samples. When ``multi_labels`` is "resp"/"bcg", a dict of
            such arrays (keys "<Resp|BCG>_LowAmp_Label", "..._Movement_Label",
            "..._AmpChange_Label"; "resp" also reads "SA_Label" and
            "SA_Score_Alpha").
        event_codes: label values to highlight, each in ``utils.ColorCycle[code]``.
            Overlays are drawn only if both this and ``event_mask`` are given.
        multi_labels: None, "resp" or "bcg". Selects which overlay layout is drawn.
        ax2: secondary Axes (e.g. ``ax.twinx()``) for the mask traces. It is
            cleared and redrawn. Required in "resp"/"bcg" mode.

    Raises:
        ValueError: if "resp"/"bcg" overlays are requested without ``ax2``.
    """
    fs = signal_data["fs"]
    segment = signal_data["data"][segment_start * fs:segment_end * fs]
    n_samples = len(segment)
    # Sample i of the slice sits at segment_start + i / fs. Building the axis
    # from the slice length keeps x and y the same size even at the recording end.
    time_axis = segment_start + np.arange(n_samples) / fs
    ax.plot(time_axis, segment, color='black', label=signal_data["name"])

    def mask_of(arr):
        return _segment_mask(arr, segment_start, segment_end, fs, n_samples)

    def overlay(selected, code):
        # NaN outside the event so only the event stretch is drawn. Unlike
        # multiplying by the mask, this keeps genuine 0-valued samples.
        ax.plot(time_axis, np.where(selected, segment, np.nan), color=utils.ColorCycle[code])

    if event_mask is None or event_codes is None:
        pass
    elif multi_labels is None:
        labels = mask_of(event_mask)
        for code in event_codes:
            overlay(labels == code, code)
    elif multi_labels in ("resp", "bcg"):
        if ax2 is None:
            raise ValueError(f"multi_labels={multi_labels!r} requires a secondary Axes via ax2")
        prefix = "Resp" if multi_labels == "resp" else "BCG"
        if multi_labels == "resp":
            ax.set_ylim(-6, 6)

        # Set up the secondary y-axis. The quality masks are drawn as negative
        # steps at distinct levels (-1, -2, -3) so they do not overlap.
        ax2.cla()
        ax2.plot(time_axis, mask_of(event_mask[f"{prefix}_LowAmp_Label"]) * -1, color='blue',
                 alpha=0.8, label='Low Amplitude Mask')
        ax2.plot(time_axis, mask_of(event_mask[f"{prefix}_Movement_Label"]) * -2, color='orange',
                 alpha=0.8, label='Movement Mask')
        ax2.plot(time_axis, mask_of(event_mask[f"{prefix}_AmpChange_Label"]) * -3, color='green',
                 alpha=0.8, label='Amplitude Change Mask')

        if multi_labels == "resp":
            sa_labels = mask_of(event_mask["SA_Label"])
            for code in event_codes:
                overlay(sa_labels == code, code)
            if event_codes:
                # The score trace does not depend on the event code, so draw it once.
                ax2.plot(time_axis, mask_of(event_mask["SA_Score_Alpha"]), color="orange")
            ax2.set_ylim(-4, 5)
        else:
            ax2.set_ylim(-4, 4)

    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)


def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage hypnogram for stage samples [segment_start, segment_end).

    Values 1-5 are labelled N3, N2, N1, REM, Awake.
    """
    fs = stage_data["fs"]
    # NOTE(review): unlike plt_signal_label_on_ax, the segment bounds index the
    # stage array directly. The x-axis only lines up with the other rows if
    # fs == 1. Confirm with the caller.
    segment = stage_data["data"][segment_start:segment_end]
    time_axis = (segment_start + np.arange(len(segment))) / fs
    ax.plot(time_axis, segment, color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)


def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
    """Plot SpO2 for samples [segment_start, segment_end).

    The y-range is 85-100 %. It is extended downwards when the segment dips below 85.
    """
    fs = spo2_data["fs"]
    # NOTE(review): same indexing convention as plt_stage_on_ax (aligned only if fs == 1).
    segment = spo2_data["data"][segment_start:segment_end]
    time_axis = (segment_start + np.arange(len(segment))) / fs
    ax.plot(time_axis, segment, color='black', label=spo2_data["name"])

    # Ignore NaN dropouts when finding the minimum. A plain .min() returns NaN
    # if any sample is NaN, which would hide real desaturations below 85.
    values = np.asarray(segment, dtype=float)
    values = values[np.isfinite(values)]
    lowest = values.min() if values.size else np.nan
    if lowest < 85:
        ax.set_ylim((lowest - 5, 100))
    else:
        ax.set_ylim((85, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)


def score_mask2alpha(score_mask):
    """Map score values to plotting alphas.

    The mapping is 1 -> 0.9, 2 -> 0.6, 3 -> 0.1, and anything else (including
    0, negatives and NaN) -> 0.0. Returns a new float array of the same shape.
    """
    score_mask = np.asarray(score_mask)
    alpha_mask = np.zeros(score_mask.shape, dtype=float)
    for score, alpha in ((1, 0.9), (2, 0.6), (3, 0.1)):
        alpha_mask[score_mask == score] = alpha
    return alpha_mask


def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list, save_path=None):
    """Render the PSG + BCG figure for each segment, optionally saving each as a PNG.

    Args:
        psg_data: dict of channel dicts ("SpO2", "5_class", "Flow T", "Flow P",
            "Effort Tho", "Effort Abd", "HR"), each with "data", "fs", "name".
        psg_label: PSG event label array; codes 1-4 are highlighted on the
            airflow/effort rows.
        bcg_data: dict with "resp_signal" and "bcg_signal" channel dicts.
        event_mask: dict of masks read by plt_signal_label_on_ax. It is
            MODIFIED IN PLACE: every "Resp_*"/"BCG_*" entry is passed through
            utils.none_to_nan_mask, and "SA_Score_Alpha" is added.
        segment_list: iterable of (segment_start, segment_end) pairs.
        save_path: optional pathlib.Path-like directory. It is created if
            missing, and one "Segment_<start>_<end>.png" is written per segment.
    """
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    for name in list(event_mask):
        if name.startswith(("Resp_", "BCG_")):
            event_mask[name] = utils.none_to_nan_mask(event_mask[name], 0)
    event_mask["SA_Score_Alpha"] = score_mask2alpha(event_mask["SA_Score"])
    event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(event_mask["SA_Score_Alpha"], 0)

    fig, axes = create_psg_bcg_figure()
    psg_event_codes = [1, 2, 3, 4]

    for segment_start, segment_end in tqdm(segment_list):
        for ax in axes:
            ax.cla()
        # cla() drops the grid and tick styling set in create_psg_bcg_figure.
        _style_psg_axes(axes)

        plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        for channel in ("Flow T", "Flow P", "Effort Tho", "Effort Abd"):
            plt_signal_label_on_ax(axes[psg_chn_name2ax[channel]], psg_data[channel],
                                   segment_start, segment_end, psg_label, event_codes=psg_event_codes)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"],
                               segment_start, segment_end, event_mask, multi_labels="resp",
                               event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["resp_twinx"]])
        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"],
                               segment_start, segment_end, event_mask, multi_labels="bcg",
                               event_codes=[1, 2, 3, 4], ax2=axes[psg_chn_name2ax["bcg_twinx"]])

        if save_path is not None:
            fig_path = save_path / f"Segment_{segment_start}_{segment_end}.png"
            fig.savefig(fig_path)
            tqdm.write(f"Saved figure to: {fig_path}")


def draw_resp_label(resp_data, resp_label, segment_list):
    """Draw the resp/BCG figure for each segment and show it with plt.show().

    Args:
        resp_data: dict with "resp_signal" and "bcg_signal" channel dicts.
        resp_label: dict of masks. It is MODIFIED IN PLACE: "Resp_*" entries
            are passed through utils.none_to_nan_mask, and "Resp_Score_Alpha"
            is added.
        segment_list: iterable of (segment_start, segment_end) pairs.
    """
    for name in list(resp_label):
        if name.startswith("Resp_"):
            resp_label[name] = utils.none_to_nan_mask(resp_label[name], 0)
    resp_label["Resp_Score_Alpha"] = score_mask2alpha(resp_label["Resp_Score"])
    # Normalise the alpha mask just computed (mirrors draw_psg_bcg_label).
    resp_label["Resp_Score_Alpha"] = utils.none_to_nan_mask(resp_label["Resp_Score_Alpha"], 0)

    fig, axes = create_resp_figure()
    for segment_start, segment_end in segment_list:
        for ax in axes:
            ax.cla()
        _style_resp_axes(axes)
        # NOTE(review): the "resp" overlay reads resp_label["SA_Label"] and
        # ["SA_Score_Alpha"], and the "bcg" overlay reads "BCG_*" masks.
        # Confirm resp_label provides them.
        plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"],
                               segment_start, segment_end, resp_label, multi_labels="resp",
                               event_codes=[1, 2, 3, 4], ax2=axes[resp_chn_name2ax["resp_twinx"]])
        plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"],
                               segment_start, segment_end, resp_label, multi_labels="bcg",
                               event_codes=[1, 2, 3, 4], ax2=axes[resp_chn_name2ax["bcg_twinx"]])
    # NOTE(review): the figure is redrawn per segment but shown only once, so
    # only the last segment is visible.
    plt.show()