diff --git a/.gitignore b/.gitignore
index 5d381cc..f43a180 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,5 +158,9 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
+
+Data/*
+Output/*
+SADetectModel/*.pt
 
diff --git a/ApneaDetection.py b/ApneaDetection.py
new file mode 100644
index 0000000..14ff905
--- /dev/null
+++ b/ApneaDetection.py
@@ -0,0 +1,88 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:andrew
+@file:ApneaDetection.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2023/09/15
+"""
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from utils.SignalPreprocess import XinXiaoPreprocess
+from utils.ModelDetection import SA_Detect
+from utils.ResultSummary import AnalyseSegment
+import argparse
+
+
+def is_valid_file(par, arg):
+    if not Path(arg).is_file():
+        par.error(f'A model file is required, not a directory: {arg}')
+    else:
+        return arg
+
+
+def read_auto(file_path):
+    if file_path.suffix == '.txt':
+        # read the txt file with pandas read_csv
+        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
+    elif file_path.suffix == '.npy':
+        return np.load(str(file_path))
+    else:
+        raise ValueError('Unsupported file type; please write a custom reader for it')
+
+
+def main(opt):
+    # load the data
+    data_path = Path(opt.data_path)
+    all_file_path = []
+    print(data_path)
+    if data_path.is_file():
+        all_file_path.append(data_path)
+    elif data_path.is_dir():
+        all_file_path = list(data_path.glob('*.txt')) + list(data_path.glob('*.npy'))
+    else:
+        raise FileNotFoundError('No file ending in .txt or .npy was found')
+
+    if not all_file_path:
+        raise FileNotFoundError('No file ending in .txt or .npy was found')
+
+    for one_file in all_file_path:
+        data = read_auto(one_file)
+
+        data = XinXiaoPreprocess(data, opt.hz)
+        model_path = f'SADetectModel/SAmodel{opt.model}.pt'
+        if not Path(model_path).exists():
+            raise FileNotFoundError(f'Model file not found: {model_path}')
+        else:
+            print(f'Running prediction with model {model_path}')
+        result = SA_Detect(model_path, data, opt.batch_size)
+        # result.to_csv(Path(opt.output) / (one_file.stem + "segment.csv"), index=False)
+        if opt.output:
+            save_path = Path(opt.output) / (one_file.stem + ".csv")
+        else:
+            save_path = None
+
+        AnalyseSegment(result, save_path=save_path, true_sleep_time=opt.TST)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
+    parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data', help='path to the data file or folder to analyse')
+    parser.add_argument('--hz', type=int, nargs='?', default=1000, help='signal sampling rate')
+    parser.add_argument('-m', '--model', nargs='?', type=int, default=0, choices=[0, 1, 3, 4], help='which pretrained model to use')
+    parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int, help="number of segments predicted per batch")
+    parser.add_argument('-o', '--output', nargs='?', default='./Output', help='output folder')
+    parser.add_argument('-t', "--TST", nargs="?", default=0, type=int, help="total sleep time in minutes (manual input is only meaningful for single-file runs)")
+    option = parser.parse_args()
+
+    # Manual overrides; these take priority over the defaults and the command line
+    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
+    # option.batch_size = 1
+    # option.TST = 565
+    # option.model = 0
+
+    main(option)
diff --git a/SADetectModel/SaDetectModel.py b/SADetectModel/SaDetectModel.py
new file mode 100644
index 0000000..43e8282
--- /dev/null
+++ b/SADetectModel/SaDetectModel.py
@@ -0,0 +1,141 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:andrew
+@file:SaDetectModel.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2023/09/16
+"""
+from torch import nn
+
+
+class BasicBlock_1d(nn.Module):
+    expansion = 1
+
+    def __init__(self, input_channel, output_channel, stride=1):
+        super(BasicBlock_1d, self).__init__()
+        self.left = nn.Sequential(
+            nn.Conv1d(in_channels=input_channel, out_channels=output_channel,
+                      kernel_size=3, stride=stride, padding=1, bias=False),
+            nn.BatchNorm1d(output_channel),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(in_channels=output_channel, out_channels=output_channel,
+                      kernel_size=3, stride=1, padding=1, bias=False),
+            nn.BatchNorm1d(output_channel)
+        )
+        self.right = nn.Sequential()
+
+        if stride != 1 or input_channel != self.expansion * output_channel:
+            self.right = nn.Sequential(
+                nn.Conv1d(in_channels=input_channel, out_channels=output_channel * self.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm1d(self.expansion * output_channel)
+            )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        out = self.left(x)
+        residual = self.right(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ResNet_1d(nn.Module):
+    def __init__(self, block, number_block, num_classes=2):
+        super(ResNet_1d, self).__init__()
+
+        self.in_channel = 64
+
+        self.conv1 = nn.Conv1d(in_channels=1, out_channels=self.in_channel,
+                               kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm1d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+        self.layer1 = self._make_layer(block=block, out_channel=64, num_block=number_block[0], stride=1)
+        self.layer2 = self._make_layer(block=block, out_channel=128, num_block=number_block[1], stride=2)
+        self.layer3 = self._make_layer(block=block, out_channel=256, num_block=number_block[2], stride=2)
+        self.layer4 = self._make_layer(block=block, out_channel=512, num_block=number_block[3], stride=2)
+        # self.layer5 = self._make_layer(block=block, out_channel=512, num_block=number_block[4], stride=2)
+        self.pool2 = nn.AdaptiveAvgPool1d(1)
+
+        self.features = nn.Sequential(
+            # nn.Linear(in_features=1024, out_features=nc),
+            nn.Flatten(),
+            # nn.Linear(in_features=512 * 23, out_features=512),
+            nn.Linear(in_features=512, out_features=num_classes)
+
+            # nn.Softmax()
+            # nn.Sigmoid()
+        )
+        # self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, out_channel, num_block, stride):
+        strides = [stride] + [1] * (num_block - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_channel, out_channel, stride))
+            self.in_channel = out_channel * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.pool1(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        # x = self.layer5(x)
+        x = self.pool2(x)
+        x = x.view(x.size(0), -1)
+
+        x = self.features(x)
+        return x
+
+
+class ResNet18_LSTM_1d_v2(ResNet_1d):
+    def __init__(self, block, number_block, num_classes, hidden_size, num_layers, bidirectional):
+        super(ResNet18_LSTM_1d_v2, self).__init__(
+            block=block,
+            number_block=number_block,
+            num_classes=num_classes
+        )
+        # self.pool3 = nn.MaxPool1d(4)
+        self.lstm = nn.LSTM(input_size=512,
+                            hidden_size=hidden_size,
+                            num_layers=num_layers,
+                            bidirectional=bidirectional,
+                            batch_first=True)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.pool1(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = x.transpose(2, 1)
+        x, (h_n1, c_n1) = self.lstm(x)
+        x = x[:, -1, :]
+        x = self.features(x)
+        return x
+
+
+def ResNet18_v2_LSTM():
+    return ResNet18_LSTM_1d_v2(BasicBlock_1d, [2, 2, 2, 2],
+                               num_classes=2, hidden_size=512, num_layers=2, bidirectional=False)
+
+
+if __name__ == '__main__':
+    # from torchinfo import summary
+    # resnet = ResNet18_v2_LSTM().cuda()
+    # summary(resnet, (4, 1, 300))
+    pass
+
diff --git a/utils/ModelDetection.py b/utils/ModelDetection.py
new file mode 100644
index 0000000..1451afe
--- /dev/null
+++ b/utils/ModelDetection.py
@@ -0,0 +1,63 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:andrew
+@file:ModelDetection.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2023/09/15
+"""
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+from SADetectModel.SaDetectModel import ResNet18_v2_LSTM
+
+gpu = torch.cuda.is_available()
+
+
+class TestApneaDataset(Dataset):
+    def __init__(self, datasets):
+        super(TestApneaDataset, self).__init__()
+        self.datasets = datasets
+        self.labels = np.linspace(0, len(self.datasets) // 5 - 60, len(self.datasets) // 5 - 60 - 1).astype(int)
+        # print(len(datasets))
+        # print(self.labels)
+
+    def __getitem__(self, index):
+        SP = self.labels[index]
+        segment = self.datasets[SP * 5:(SP + 60) * 5].copy().reshape((1, -1))
+        return segment, SP
+
+    def __len__(self):
+        return len(self.labels)
+
+
+def SA_Detect(model_path, data, batch_size):
+    model = ResNet18_v2_LSTM()
+    model.load_state_dict(torch.load(model_path, map_location='cuda' if gpu else 'cpu'))
+    model = model.cuda() if gpu else model
+    model.eval()
+
+    test_dataset = TestApneaDataset(datasets=data)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=False, num_workers=0)
+
+    temp_result = []
+
+    for resp, SP in tqdm(test_loader, total=len(test_loader)):
+        resp = resp.float().cuda() if gpu else resp.float()
+        with torch.no_grad():
+            out = model(resp)
+            out = torch.softmax(out, dim=1)[:, 1]
+
+        temp = np.stack((SP.cpu().numpy(), out.cpu().numpy()), axis=1)
+        temp_result.append(temp)
+
+    df_segment = pd.DataFrame(data=np.concatenate(temp_result, axis=0), columns=["SP", "pred"])
+    df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))
+    df_segment["SP"] = df_segment["SP"].astype(int)
+    df_segment = df_segment.sort_values(by="SP", ascending=True)
+
+    return df_segment
diff --git a/utils/ResultSummary.py b/utils/ResultSummary.py
new file mode 100644
index 0000000..b9ad675
--- /dev/null
+++ b/utils/ResultSummary.py
@@ -0,0 +1,66 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:andrew
+@file:ResultSummary.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2023/09/15
+"""
+import numpy as np
+import pandas as pd
+
+
+def AnalyseSegment(df_segment, thresh=0.5, thresh_event_interval=2, thresh_event_length=8,
+                   save_path=None, true_sleep_time=0):
+    df_segment["thresh_Pred"] = df_segment["pred"].apply(lambda x: 1 if x > thresh else 0)
+    thresh_Pred = df_segment["thresh_Pred"].values
+    thresh_Pred2 = thresh_Pred.copy()
+
+    # merge positive predictions separated by short gaps
+    indices = np.where(thresh_Pred == 1)[0]
+    for i in range(len(indices) - 1):
+        if indices[i + 1] - indices[i] <= thresh_event_interval:
+            thresh_Pred2[indices[i]:indices[i + 1]] = 1
+    # event detection
+    # a run of 1s longer than the threshold is counted as an event; record its start and end positions
+    diffs = np.diff(thresh_Pred2, prepend=0, append=0)
+    start_indices = np.where(diffs == 1)[0]
+    end_indices = np.where(diffs == -1)[0]
+    event_indices = np.where(end_indices - start_indices >= thresh_event_length)[0]
+    result = [(start_indices[i], end_indices[i]) for i in event_indices]
+
+    df_event = pd.DataFrame(result, columns=["start", "end"])
+    df_event["length"] = df_event["end"] - df_event["start"]
+    df_event["start"] = df_event["start"] - 29
+    df_event["end"] = df_event["end"] - 29
+
+    if save_path is not None:
+        df_event.to_csv(save_path, index=False)
+
+    # grade the severity from the AHI
+    if true_sleep_time == 0:
+        record_length = len(df_segment) / 60 / 60
+        AHI = len(result) / record_length
+    else:
+        record_length = true_sleep_time / 60
+        AHI = len(result) / record_length
+
+    if AHI < 5:
+        severity = "Healthy"
+    elif AHI < 15:
+        severity = "Mild"
+    elif AHI < 30:
+        severity = "Moderate"
+    else:
+        severity = "Severe"
+
+    SA_HYP_count = len(result)
+    print("patient: ", save_path.stem if save_path is not None else "N/A")
+    print("Event number: ", SA_HYP_count)
+    print("Record length (hours): ", round(record_length, 2))
+    print("AHI: ", AHI)
+    print("Severity: ", severity)
+    print("=====================================")
+
+    return SA_HYP_count, AHI, severity
diff --git a/utils/SignalPreprocess.py b/utils/SignalPreprocess.py
new file mode 100644
index 0000000..0984206
--- /dev/null
+++ b/utils/SignalPreprocess.py
@@ -0,0 +1,49 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:andrew
+@file:SignalPreprocess.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2023/09/15
+"""
+
+# The input is a one-dimensional whole-night signal
+
+import numpy as np
+from scipy import signal
+
+
+def XinXiaoPreprocess(data=None, signal_frequency=1000, low_cut=0.01, high_cut=0.7, order=4):
+    """
+    # Preprocessing steps
+    # extract the respiratory component
+    # 4th-order Butterworth band-pass filter to keep the main respiratory component
+    # 4-second moving-average filter to smooth the respiratory signal
+    # Z-score normalisation
+    # downsample to 5 Hz
+
+    :param data: one-dimensional whole-night signal
+    :param signal_frequency: sampling rate of the input signal in Hz
+    :param low_cut: band-pass low cut-off frequency in Hz
+    :param high_cut: band-pass high cut-off frequency in Hz
+    :param order: Butterworth filter order
+
+    :return: preprocessed respiratory signal at 5 Hz
+    """
+    sample_rate = signal_frequency
+    low = low_cut / (sample_rate * 0.5)
+    high = high_cut / (sample_rate * 0.5)
+    sos = signal.butter(N=order, Wn=[low, high], btype="bandpass", output='sos')
+    data = signal.sosfilt(sos, data)
+    # resample the signal to 100 Hz (assumes the sampling rate is a multiple of 100)
+    data = data[::int(sample_rate // 100)]
+
+    data -= np.convolve(data, np.ones(400) / 400, mode='same')
+    data = (data - data.mean()) / data.std()
+    data[data > 3] = 3
+    data[data < -3] = -3
+    # downsample to 5 Hz
+
+    data = data[::20]
+    return data
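
For reference, a minimal sketch of driving the new modules directly from Python instead of through ApneaDetection.py (the input file name is a placeholder; everything else uses the functions added in this diff):

import numpy as np

from utils.SignalPreprocess import XinXiaoPreprocess
from utils.ModelDetection import SA_Detect
from utils.ResultSummary import AnalyseSegment

raw = np.load("Data/example_night.npy")                # hypothetical whole-night signal sampled at 1000 Hz
resp = XinXiaoPreprocess(raw, signal_frequency=1000)   # band-pass, smooth, z-score, downsample to 5 Hz
segments = SA_Detect("SADetectModel/SAmodel0.pt", resp, batch_size=4096)  # apnea probability per 60 s window
count, ahi, severity = AnalyseSegment(segments, save_path=None, true_sleep_time=0)  # event count, AHI, severity grade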