from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.colors import PowerNorm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
import utils

# TODO(review): original note (translated) reads "add a with_prediction parameter";
# no such parameter exists yet.

# Row index of each channel in the figure returned by create_psg_bcg_figure().
psg_chn_name2ax = {
    "SpO2": 0,
    "Flow T": 1,
    "Flow P": 2,
    "Effort Tho": 3,
    "Effort Abd": 4,
    "HR": 5,
    "resp": 6,
    "bcg": 7,
    "Stage": 8,
}

# Row index of each channel in the figure returned by create_resp_figure().
resp_chn_name2ax = {
    "resp": 0,
    "bcg": 1,
}

# Event codes highlighted by the draw_* entry points.
_EVENT_CODES = [1, 2, 3, 4]

# Signal-quality masks drawn on the twin y-axis: (key suffix, legend label, colour, offset).
# The full key is "<prefix>_<suffix>", e.g. "Resp_LowAmp_Label" or "BCG_LowAmp_Label".
# Each mask is multiplied by its negative offset so the three traces sit on separate
# levels below zero (presumably the masks are 0/1 -- TODO confirm).
_QUALITY_MASKS = (
    ("LowAmp_Label", "Low Amplitude Mask", "blue", -1),
    ("Movement_Label", "Movement Mask", "orange", -2),
    ("AmpChange_Label", "Amplitude Change Mask", "green", -3),
)

# Sleep-apnea score -> alpha used by score_mask2alpha(); any other value maps to 0.
_SCORE_ALPHA = {1: 0.9, 2: 0.6, 3: 0.1}


def create_psg_bcg_figure():
    """Create the 9-row stacked figure used for PSG + BCG review.

    Rows are laid out in the order of ``psg_chn_name2ax`` and packed with no
    vertical gap. Every row gets a grid; all rows except the bottom (Stage) row
    have their x ticks painted white so only the bottom row shows them.

    Returns:
        (fig, axes): the Figure and a list of 9 Axes, top to bottom.
    """
    fig = plt.figure(figsize=(12, 8), dpi=100)
    gs = GridSpec(9, 1, height_ratios=[1, 1, 1, 1, 1, 1, 3, 2, 1])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(gs[i]) for i in range(9)]

    for ax in axes:
        ax.grid(True)
    # "Hide" x ticks on all but the bottom row by colouring them white.
    for ax in axes[:-1]:
        ax.tick_params(axis='x', colors="white")
    axes[0].set_ylim((85, 100))
    return fig, axes


def create_resp_figure():
    """Create the 2-row stacked figure (BCG-derived respiration over raw BCG).

    Both rows get a grid and white x ticks, and are packed with no vertical gap.

    Returns:
        (fig, axes): the Figure and a list of 2 Axes, top to bottom.
    """
    fig = plt.figure(figsize=(12, 6), dpi=100)
    gs = GridSpec(2, 1, height_ratios=[3, 2])
    fig.subplots_adjust(top=0.98, bottom=0.05, right=0.98, left=0.1, hspace=0, wspace=0)
    axes = [fig.add_subplot(gs[i]) for i in range(2)]

    for ax in axes:
        ax.grid(True)
        ax.tick_params(axis='x', colors="white")
    return fig, axes


def _slice_segment(data, fs, segment_start, segment_end):
    """Cut the ``[segment_start, segment_end)`` window, given in seconds, out of a sampled signal.

    Args:
        data: Sample array of the signal.
        fs: Sampling rate of ``data`` (samples per second).
        segment_start, segment_end: Window bounds in seconds.

    Returns:
        (segment, time_axis): the samples in the window and their timestamps in
        seconds. Sample ``i`` sits at ``segment_start + i / fs``. The time axis is
        sized from the samples actually available, so a window that runs past the
        end of ``data`` still gives matching lengths.
    """
    start_idx = int(round(segment_start * fs))
    end_idx = int(round(segment_end * fs))
    segment = data[start_idx:end_idx]
    time_axis = segment_start + np.arange(len(segment)) / fs
    return segment, time_axis


def _plot_quality_masks(ax, time_axis, event_mask, prefix, expand):
    """Draw the three ``<prefix>_*_Label`` quality masks on ``ax`` at offsets -1, -2 and -3.

    ``expand`` converts a per-second label array into per-sample values aligned
    with ``time_axis``.
    """
    for suffix, label, color, offset in _QUALITY_MASKS:
        mask = expand(event_mask[f"{prefix}_{suffix}"])
        ax.plot(time_axis, mask * offset, color=color, alpha=0.8, label=label)


def plt_signal_label_on_ax(ax: Axes, signal_data, segment_start, segment_end, event_mask=None,
                           event_codes: list[int] = None, multi_labels=None):
    """Plot one signal window on ``ax`` and optionally overlay event/quality labels.

    Args:
        ax: Target axes.
        signal_data: Dict with ``"data"`` (samples), ``"fs"`` (sampling rate) and
            ``"name"`` (legend label).
        segment_start, segment_end: Window bounds in seconds.
        event_mask: Optional labels. Label arrays are sliced with
            ``[segment_start:segment_end]`` and repeated ``fs`` times, so they are
            expected to hold one value per second. With ``multi_labels=None`` this
            is a single array of event codes. With ``"resp"``/``"bcg"`` it is a
            dict of such arrays.
        event_codes: Event codes to highlight. Labels are drawn only when both
            ``event_mask`` and ``event_codes`` are given.
        multi_labels:
            - ``None``: colour the signal itself wherever ``event_mask`` equals a code.
            - ``"resp"``: draw the ``Resp_*`` quality masks plus the ``SA_Label``
              events, with height taken from ``SA_Score_Alpha``, on a twin y-axis.
            - ``"bcg"``: draw the ``BCG_*`` quality masks on a twin y-axis.
    """
    signal_fs = signal_data["fs"]
    segment, time_axis = _slice_segment(signal_data["data"], signal_fs, segment_start, segment_end)
    ax.plot(time_axis, segment, color='black', label=signal_data["name"])

    if event_mask is not None and event_codes is not None:
        n_samples = len(segment)

        def expand(label):
            # Per-second labels -> per-sample labels aligned with `segment`.
            return label[segment_start:segment_end].repeat(signal_fs)[:n_samples]

        if multi_labels is None:
            codes = expand(event_mask)
            for event_code in event_codes:
                # NaN outside the event so only the event span is drawn. Unlike
                # "multiply then replace zeros", this keeps genuine 0-valued samples.
                y = np.where(codes == event_code, segment, np.nan)
                ax.plot(time_axis, y, color=utils.ColorCycle[event_code])
        elif multi_labels == "resp":
            ax.set_ylim(-6, 6)
            # Second y-axis for the label traces.
            ax2 = ax.twinx()
            _plot_quality_masks(ax2, time_axis, event_mask, "Resp", expand)
            sa_label = expand(event_mask["SA_Label"])
            sa_alpha = expand(event_mask["SA_Score_Alpha"])
            for event_code in event_codes:
                y = ((sa_label == event_code) * sa_alpha).astype(float)
                # Zero height (no event, or a score mapped to alpha 0) is not drawn.
                y[y == 0] = np.nan
                ax2.plot(time_axis, y, color=utils.ColorCycle[event_code])
            ax2.set_ylim(-4, 5)
        elif multi_labels == "bcg":
            # Second y-axis for the label traces.
            ax2 = ax.twinx()
            _plot_quality_masks(ax2, time_axis, event_mask, "BCG", expand)
            ax2.set_ylim(-4, 4)

    ax.set_ylabel("Amplitude")
    ax.legend(loc=1)


def plt_stage_on_ax(ax, stage_data, segment_start, segment_end):
    """Plot the sleep-stage trace for the window ``[segment_start, segment_end)`` seconds.

    ``stage_data`` is a dict with ``"data"``, ``"fs"`` and ``"name"``. The y axis
    shows codes 1-5 labelled N3, N2, N1, REM and Awake.
    """
    segment, time_axis = _slice_segment(stage_data["data"], stage_data["fs"], segment_start, segment_end)
    ax.plot(time_axis, segment, color='black', label=stage_data["name"])
    ax.set_ylim(0, 6)
    ax.set_yticks([1, 2, 3, 4, 5])
    ax.set_yticklabels(["N3", "N2", "N1", "REM", "Awake"])
    ax.set_ylabel("Stage")
    ax.legend(loc=1)


def plt_spo2_on_ax(ax: Axes, spo2_data, segment_start, segment_end):
    """Plot SpO2 for the window ``[segment_start, segment_end)`` seconds.

    ``spo2_data`` is a dict with ``"data"``, ``"fs"`` and ``"name"``. The y range
    is 85-100 unless the window dips below 85. In that case the lower bound
    becomes the window minimum minus 5.
    """
    segment, time_axis = _slice_segment(spo2_data["data"], spo2_data["fs"], segment_start, segment_end)
    ax.plot(time_axis, segment, color='black', label=spo2_data["name"])

    # Ignore NaN samples: a single gap would otherwise make min() NaN and
    # suppress the downward widening.
    valid = np.asarray(segment, dtype=float)
    valid = valid[np.isfinite(valid)]
    if valid.size and valid.min() < 85:
        ax.set_ylim((valid.min() - 5, 100))
    else:
        ax.set_ylim((85, 100))
    ax.set_ylabel("SpO2 (%)")
    ax.legend(loc=1)


def score_mask2alpha(score_mask):
    """Map a score mask to per-element alpha values.

    Scores 1, 2 and 3 map to 0.9, 0.6 and 0.1. Every other value, including
    scores <= 0, maps to 0.

    Returns:
        A float array with the same shape as ``score_mask``.
    """
    score_mask = np.asarray(score_mask)
    alpha_mask = np.zeros(score_mask.shape, dtype=float)
    for score, alpha in _SCORE_ALPHA.items():
        alpha_mask[score_mask == score] = alpha
    return alpha_mask


def draw_psg_bcg_label(psg_data, psg_label, bcg_data, event_mask, segment_list):
    """Show one 9-row PSG + BCG figure per ``(segment_start, segment_end)`` in ``segment_list``.

    Mutates ``event_mask`` in place:
        - every ``Resp_*`` / ``BCG_*`` entry is passed through
          ``utils.none_to_nan_mask(..., 0)``;
        - an ``SA_Score_Alpha`` entry is added, derived from ``SA_Score``.

    Args:
        psg_data: Dict of PSG channels keyed "SpO2", "5_class", "Flow T", "Flow P",
            "Effort Tho", "Effort Abd" and "HR".
        psg_label: Per-second event-code array used to colour the flow/effort channels.
        bcg_data: Dict with "resp_signal" and "bcg_signal" channels.
        event_mask: Dict of BCG label arrays (see ``plt_signal_label_on_ax``).
        segment_list: Iterable of ``(start, end)`` windows in seconds.
    """
    for key in event_mask:
        # Bug fix: this used endswith("BCG_"), which no "BCG_*" key can satisfy.
        if key.startswith(("Resp_", "BCG_")):
            event_mask[key] = utils.none_to_nan_mask(event_mask[key], 0)
    event_mask["SA_Score_Alpha"] = utils.none_to_nan_mask(score_mask2alpha(event_mask["SA_Score"]), 0)

    flow_effort_channels = ("Flow T", "Flow P", "Effort Tho", "Effort Abd")
    for segment_start, segment_end in segment_list:
        print(f"Drawing segment: {segment_start} to {segment_end} seconds")
        # Use a fresh figure per segment. ax.cla() would not remove the twin axes
        # created by ax.twinx(), so stale mask traces would pile up. Also, a figure
        # whose blocking plt.show() window was closed is not shown again.
        fig, axes = create_psg_bcg_figure()

        plt_spo2_on_ax(axes[psg_chn_name2ax["SpO2"]], psg_data["SpO2"], segment_start, segment_end)
        plt_stage_on_ax(axes[psg_chn_name2ax["Stage"]], psg_data["5_class"], segment_start, segment_end)
        for name in flow_effort_channels:
            plt_signal_label_on_ax(axes[psg_chn_name2ax[name]], psg_data[name], segment_start, segment_end,
                                   psg_label, event_codes=_EVENT_CODES)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["HR"]], psg_data["HR"], segment_start, segment_end)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["resp"]], bcg_data["resp_signal"], segment_start, segment_end,
                               event_mask, multi_labels="resp", event_codes=_EVENT_CODES)
        plt_signal_label_on_ax(axes[psg_chn_name2ax["bcg"]], bcg_data["bcg_signal"], segment_start, segment_end,
                               event_mask, multi_labels="bcg", event_codes=_EVENT_CODES)

        plt.show()
        plt.close(fig)
        print(f"Finished drawing segment: {segment_start} to {segment_end} seconds")


def draw_resp_label(resp_data, resp_label, segment_list):
    """Show one 2-row respiration/BCG figure per ``(segment_start, segment_end)`` in ``segment_list``.

    Mutates ``resp_label`` in place:
        - every ``Resp_*`` entry is passed through ``utils.none_to_nan_mask(..., 0)``;
        - a ``Resp_Score_Alpha`` entry is added, derived from ``Resp_Score``.

    NOTE(review): the "resp" overlay reads ``SA_Label`` / ``SA_Score_Alpha`` and the
    "bcg" overlay reads ``BCG_*_Label`` keys. Confirm that ``resp_label`` provides
    them, because this function only derives ``Resp_Score_Alpha``.
    """
    for key in resp_label:
        if key.startswith("Resp_"):
            resp_label[key] = utils.none_to_nan_mask(resp_label[key], 0)
    # Bug fix: this used to read the never-created key "Resp_Label_Alpha" (KeyError).
    resp_label["Resp_Score_Alpha"] = utils.none_to_nan_mask(score_mask2alpha(resp_label["Resp_Score"]), 0)

    for segment_start, segment_end in segment_list:
        # Fresh figure per segment; see draw_psg_bcg_label for the rationale.
        fig, axes = create_resp_figure()
        plt_signal_label_on_ax(axes[resp_chn_name2ax["resp"]], resp_data["resp_signal"], segment_start, segment_end,
                               resp_label, multi_labels="resp", event_codes=_EVENT_CODES)
        plt_signal_label_on_ax(axes[resp_chn_name2ax["bcg"]], resp_data["bcg_signal"], segment_start, segment_end,
                               resp_label, multi_labels="bcg", event_codes=_EVENT_CODES)
        plt.show()
        plt.close(fig)