"""PPG-based heart-rate and SpO2 estimation helpers.

Three independent estimators are provided:

* ``calculate_hr_spo2_zhihu`` -- FFT ratio-of-ratios estimate over sliding
  windows (port of a MATLAB reference implementation).
* ``ppg2spo2_pipeline`` -- per-second time-domain estimate (peak counting for
  heart rate, min/max ratio-of-ratios for SpO2).
* ``func_1`` -- per-sample SpO2 from moving-window AC/DC ratios.
"""
import numpy as np
import pandas as pd
from scipy import signal

try:
    from matplotlib import pyplot as plt
except ImportError:  # plotting is optional; the numerical routines do not need it
    plt = None


def _require_pyplot():
    """Return ``matplotlib.pyplot`` or raise a clear error when it is not installed."""
    if plt is None:
        raise ImportError("matplotlib is required for plotting; install it or pass plot=False")
    return plt


def calculate_hr_spo2_zhihu(X, fs, FFT_size=512):
    """Estimate heart rate and SpO2 from a two-channel PPG recording via the FFT.

    Windows of ``FFT_size`` samples start every ``fs`` samples (one second).
    In each window, the strongest spectral peak in 0.5-2.5 Hz (30-150 bpm) is
    located separately for RED and IR. The IR peak frequency gives the heart
    rate. The peak magnitude divided by the DC bin (bin 0) gives each channel's
    AC/DC ratio, and ``SpO2 = 104 - 28 * R`` where
    ``R = (AC/DC)_red / (AC/DC)_ir``. The first and last windows are discarded
    (when more than two exist), as in the MATLAB original, and the remaining
    per-window values are averaged.

    References:
        https://zhuanlan.zhihu.com/p/658858641
        https://github.com/thinkng/ppgprocessing/blob/master/HR_SpO2_Estimation.m

    Parameters
    ----------
    X : array_like, shape (N, 2)
        Column 0 is the RED channel, column 1 is the IR channel.
    fs : int
        Sampling rate in Hz; also used as the window step in samples.
    FFT_size : int, optional
        Number of FFT points per window (a power of two is recommended).
        Windows shorter than this are zero-padded by the FFT.

    Returns
    -------
    (Heart_Rate, SpO2_Level)
        Mean heart rate in bpm and mean SpO2 in percent, both rounded.

    Raises
    ------
    ValueError
        If ``X`` is not shaped (N, 2), if the recording is too short to hold a
        single window (fewer than ``6 * fs`` samples), or if no FFT bin falls
        in the 0.5-2.5 Hz search band.
    """
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] != 2:
        raise ValueError("X 必须是 (样本数, 2) 的数组,列0=RED,列1=IR")

    step = fs  # the MATLAB code starts windows at n*fs, i.e. one window per second
    n_windows = int((len(X) / (2 * fs)) - 2)  # same window count as the MATLAB original
    if n_windows < 1:
        raise ValueError(
            f"signal too short: need at least {6 * fs} samples at fs={fs}, got {len(X)}"
        )

    freqs = np.fft.rfftfreq(FFT_size, d=1.0 / fs)
    # Search 0.5-2.5 Hz (30-150 bpm). The MATLAB code used fixed bins 6:12,
    # whose meaning depended on FFT_size/fs; a frequency band is portable.
    band = np.flatnonzero((freqs >= 0.5) & (freqs <= 2.5))
    if band.size == 0:
        raise ValueError(
            f"FFT_size={FFT_size} at fs={fs} leaves no bin in the 0.5-2.5 Hz band"
        )

    heart_rates = np.empty(n_windows)
    spo2_values = np.empty(n_windows)
    for n in range(n_windows):
        window = X[n * step:n * step + FFT_size]
        red_spec = np.abs(np.fft.rfft(window[:, 0], n=FFT_size))
        ir_spec = np.abs(np.fft.rfft(window[:, 1], n=FFT_size))

        pk_red = band[np.argmax(red_spec[band])]
        pk_ir = band[np.argmax(ir_spec[band])]

        heart_rates[n] = freqs[pk_ir] * 60

        # AC = spectral peak, DC = bin 0; the epsilon guards a zero DC bin.
        r_red = red_spec[pk_red] / (red_spec[0] + 1e-12)
        r_ir = ir_spec[pk_ir] / (ir_spec[0] + 1e-12)
        spo2_values[n] = 104 - 28 * (r_red / r_ir)

    # Drop the first and last windows when possible (same as the MATLAB original).
    if n_windows > 2:
        heart_rates = heart_rates[1:-1]
        spo2_values = spo2_values[1:-1]

    Heart_Rate = round(np.mean(heart_rates))
    SpO2_Level = round(np.mean(spo2_values))
    return Heart_Rate, SpO2_Level


def _culculate_spo2(ir_list_data, red_list_data):
    """Estimate SpO2 (%) from one window of RAW (un-detrended) IR and RED samples.

    DC is the window minimum and AC is the peak-to-peak amplitude. The ratio of
    ratios ``R2 = (AC_red / DC_red) / (AC_ir / DC_ir)`` is mapped through the
    empirical quadratic ``-45.060*R2**2 + 30.354*R2 + 94.845``. Results outside
    [0, 100] are reported as 0.

    The misspelled name is kept so existing callers keep working.
    """
    ir_dc = min(ir_list_data)
    red_dc = min(red_list_data)
    ir_ac = max(ir_list_data) - ir_dc
    red_ac = max(red_list_data) - red_dc

    denominator = ir_ac * red_dc
    if denominator < 1:  # flat IR channel or zero DC: avoid dividing by ~0
        denominator = 1
    R2 = (red_ac * ir_dc) / denominator

    SPO2 = -45.060 * R2 * R2 + 30.354 * R2 + 94.845
    if SPO2 > 100 or SPO2 < 0:
        SPO2 = 0
    return SPO2


def _culculate_HR(ir_list_data_filtered, data_list_time, distance=10):
    """Estimate heart rate (bpm) by counting peaks in a filtered IR window.

    Parameters
    ----------
    ir_list_data_filtered : array_like
        Filtered IR samples of the window.
    data_list_time : sequence of float
        Sample timestamps in milliseconds; only the first and last are used.
    distance : int, optional
        Minimum number of samples between detected peaks (``find_peaks``).

    Returns
    -------
    float
        ``peaks / duration_seconds * 60``, or 0.0 when the time span is not
        positive (e.g. a single-sample window).
    """
    peaks, _ = signal.find_peaks(ir_list_data_filtered, distance=distance)
    duration_ms = data_list_time[-1] - data_list_time[0]
    if duration_ms <= 0:
        return 0.0
    return len(peaks) / (duration_ms / 1000) * 60


def process_signal(signal_segment, fs, highpass=True, cutoff=1.0, order=8):
    """Optionally high-pass filter a signal, then remove its linear trend.

    Parameters
    ----------
    signal_segment : array_like
        Samples to process; filtering and detrending run along axis 0.
    fs : float
        Sampling rate in Hz.
    highpass : bool, optional
        Apply a zero-phase Butterworth high-pass before detrending.
    cutoff : float, optional
        High-pass cutoff in Hz (default 1 Hz, the historical value).
    order : int, optional
        Butterworth filter order (default 8, the historical value).

    Returns
    -------
    ndarray
        The filtered (if requested) and linearly detrended signal.
    """
    data = signal_segment
    if highpass:
        # Second-order sections stay numerically stable at high orders,
        # unlike the (b, a) transfer-function form.
        sos = signal.butter(N=order, Wn=cutoff, btype="highpass", fs=fs, output="sos")
        data = signal.sosfiltfilt(sos, data, axis=0)
    return signal.detrend(data, axis=0, type='linear', bp=0, overwrite_data=False)


def _plot_red_ir(pyplot, timestamp, red, ir, labels, title, ylabel):
    """Draw one RED/IR pair of traces on the current axes with title and legend."""
    red_label, ir_label = labels
    pyplot.plot(timestamp, red, label=red_label, color='red', alpha=0.5)
    pyplot.plot(timestamp, ir, label=ir_label, color='blue', alpha=0.5)
    pyplot.title(title)
    pyplot.xlabel("seconds")
    pyplot.ylabel(ylabel)
    pyplot.legend()


def _plot_raw_and_filtered(red, ir, red_filtered, ir_filtered, fs):
    """Show raw and filtered RED/IR traces on two time-aligned subplots."""
    pyplot = _require_pyplot()
    timestamp = np.arange(len(red)) / fs

    pyplot.figure(figsize=(10, 5))
    ax1 = pyplot.subplot(211)
    _plot_red_ir(pyplot, timestamp, red, ir,
                 ('Red Signal', 'IR Signal'), 'Raw PPG Signals', 'Amplitude')
    pyplot.subplot(212, sharex=ax1)
    _plot_red_ir(pyplot, timestamp, red_filtered, ir_filtered,
                 ('Filtered Red Signal', 'Filtered IR Signal'),
                 'Filtered PPG Signals', 'Amplitude')
    pyplot.show()


def ppg2spo2_pipeline(red, ir, fs=25, plot=True):
    """Estimate heart rate and SpO2 once per second using 1-second windows.

    Heart rate comes from peak counting on the high-passed, detrended IR
    signal. SpO2 comes from the min/max ratio of ratios on the RAW signals:
    the filtered signals are zero-mean, so their "DC" would be meaningless.

    Parameters
    ----------
    red, ir : array_like
        Raw RED and IR PPG samples (flattened to 1-D); must have equal length.
    fs : int, optional
        Sampling rate in Hz; also the window length in samples (1 s).
    plot : bool, optional
        When True (the historical behavior), show raw and filtered traces.

    Returns
    -------
    (bpm, spo2) : tuple of ndarray
        One heart-rate (bpm) and one SpO2 (%) value per complete second.

    Raises
    ------
    ValueError
        If ``red`` and ``ir`` differ in length.
    """
    red = np.asarray(red).reshape(-1)
    ir = np.asarray(ir).reshape(-1)
    if len(red) != len(ir):
        raise ValueError("红光和红外光信号长度必须相同")

    ir_filtered = process_signal(ir, fs, highpass=True)

    bpm_per_second = []
    spo2_per_second = []
    # One window per COMPLETE second (the old `- 1` dropped the last one).
    for start in range(0, (len(red) // fs) * fs, fs):
        stop = start + fs
        spo2_per_second.append(_culculate_spo2(ir[start:stop], red[start:stop]))
        times_ms = np.arange(start, stop) * (1000 / fs)
        bpm_per_second.append(_culculate_HR(ir_filtered[start:stop], times_ms))

    if plot:
        red_filtered = process_signal(red, fs, highpass=False)
        _plot_raw_and_filtered(red, ir, red_filtered, ir_filtered, fs)

    return np.asarray(bpm_per_second), np.asarray(spo2_per_second)


def _moving_extrema(values, window):
    """Centered moving (max, min) of 1-D ``values``; edges use partial windows."""
    rolling = pd.Series(values).rolling(window, min_periods=1, center=True)
    return rolling.max().to_numpy(), rolling.min().to_numpy()


def _plot_ac_dc_analysis(red, ir, baseline_red, baseline_ir,
                         ac_dc_red, ac_dc_ir, spo2, fs):
    """Show raw/detrended traces and AC/DC ratios, then the SpO2 curve."""
    pyplot = _require_pyplot()
    timestamp = np.arange(len(red)) / fs

    pyplot.figure(figsize=(10, 5))
    ax1 = pyplot.subplot(311)
    _plot_red_ir(pyplot, timestamp, red, ir,
                 ('Red Signal', 'IR Signal'), 'Raw PPG Signals', 'Amplitude')
    pyplot.subplot(312, sharex=ax1)
    _plot_red_ir(pyplot, timestamp, red - baseline_red, ir - baseline_ir,
                 ('Detrended Red Signal', 'Detrended IR Signal'),
                 'Detrended PPG Signals', 'Amplitude')
    pyplot.subplot(313, sharex=ax1)
    _plot_red_ir(pyplot, timestamp, ac_dc_red, ac_dc_ir,
                 ('AC/DC Red', 'AC/DC IR'), 'AC/DC Ratios', 'Ratio')
    pyplot.show()

    pyplot.figure()
    pyplot.plot(timestamp, spo2)
    pyplot.xlabel("seconds")  # x values are timestamps, not sample indices
    pyplot.ylabel("SPO2")
    pyplot.title("SPO2")
    pyplot.show()


def func_1(red, ir, fs=25, plot=True):
    """Estimate SpO2 per sample from moving-window AC/DC ratios.

    For each channel, a centered moving window of ``fs`` samples gives max and
    min. The baseline is ``(max + min) / 2`` and ``AC/DC = (max - min) / baseline``.
    Then ``SPO2 = 110 - 25 * (AC/DC)_red / (AC/DC)_ir``.

    Parameters
    ----------
    red, ir : array_like
        Raw RED and IR PPG samples (flattened to 1-D); must have equal length.
    fs : int, optional
        Sampling rate in Hz; also the moving-window length in samples.
    plot : bool, optional
        When True (the historical behavior), show diagnostic plots.

    Returns
    -------
    ndarray, shape (N,)
        Per-sample SpO2 estimate.

    Raises
    ------
    ValueError
        If ``red`` and ``ir`` differ in length.
    """
    # Flatten to 1-D so arithmetic with the rolling results never broadcasts
    # (N,) against (N, 1) into an (N, N) matrix.
    red = np.asarray(red, dtype=float).reshape(-1)
    ir = np.asarray(ir, dtype=float).reshape(-1)
    if len(red) != len(ir):
        raise ValueError("红光和红外光信号长度必须相同")

    max_ir, min_ir = _moving_extrema(ir, fs)
    baseline_ir = (max_ir + min_ir) / 2
    ac_dc_ir = (max_ir - min_ir) / baseline_ir

    max_red, min_red = _moving_extrema(red, fs)
    baseline_red = (max_red + min_red) / 2
    ac_dc_red = (max_red - min_red) / baseline_red

    SPO2 = 110 - 25 * (ac_dc_red / ac_dc_ir)

    if plot:
        _plot_ac_dc_analysis(red, ir, baseline_red, baseline_ir,
                             ac_dc_red, ac_dc_ir, SPO2, fs)
    return SPO2