0917backup

This commit is contained in:
andrew 2023-09-17 00:46:14 +08:00
parent 17f8963e0f
commit 9fe80b02a8
6 changed files with 415 additions and 1 deletions

6
.gitignore vendored
View File

@ -158,5 +158,9 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
Data/*
Output/*
SADetectModel/*.pt

88
ApneaDetection.py Normal file
View File

@ -0,0 +1,88 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ApneaDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
from pathlib import Path
import numpy as np
import pandas as pd
from utils.SignalPreprocess import XinXiaoPreprocess
from utils.ModelDetection import SA_Detect
from utils.ResultSummary import AnalyseSegment
import argparse
def is_valid_file(par, arg):
if not Path(arg).is_file():
par.error(f'只能输入模型文件而不是文件夹:{arg}')
else:
return arg
def read_auto(file_path):
if file_path.suffix == '.txt':
# 使用pandas read csv读取txt
return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
elif file_path.suffix == '.npy':
return np.load(file_path.__str__())
else:
raise ValueError('这个文件类型好像不支持需要自己写读取程序')
def main(opt):
# 读取数据集
data_path = Path(opt.data_path)
all_file_path = []
print(data_path)
if data_path.is_file():
all_file_path.append(data_path)
elif data_path.is_dir():
all_file_path = list(data_path.glob('*.txt')) + list(data_path.glob('*.npy'))
else:
raise FileNotFoundError('没有找到以txt或npy结尾的文件')
if not all_file_path:
raise FileNotFoundError('没有找到以txt或npy结尾的文件')
for one_file in all_file_path:
data = read_auto(one_file)
data = XinXiaoPreprocess(data, opt.hz)
model_path = f'SADetectModel/SAmodel{opt.model}.pt'
if not Path(model_path).exists():
raise FileNotFoundError('模型文件不存在')
else:
print(f'正在使用模型{model_path}进行预测')
result = SA_Detect(model_path, data, opt.batch_size)
# result.to_csv(Path(opt.output) / (one_file.stem + "segment.csv"), index=False)
if opt.output:
save_path = Path(opt.output) / (one_file.stem + ".csv")
else:
save_path = None
AnalyseSegment(result, save_path=save_path, true_sleep_time=opt.TST)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data', help='待测试数据文件路径或文件夹')
parser.add_argument('--hz', type=int, nargs='?', default=1000, help='信号采样率')
parser.add_argument('-m', '--model', nargs='?', default=0, choices=[0, 1, 3, 4], help='选择一个模型路径')
parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int, help="模型每次预测片段数量")
parser.add_argument('-o', '--output', nargs='?', default='./Output', help='输出文件夹')
parser.add_argument('-t', "--TST", nargs="?", default=0, type=int, help="睡眠总时长(仅单文件测试是支持手动输入)")
option = parser.parse_args()
# 手动输入,优先级高于默认值和命令行输入
option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
# option.batch_size = 1
# option.TST = 565
# option.model = 0
main(option)

View File

@ -0,0 +1,141 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:SaDetectModel.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/16
"""
from torch import nn
class BasicBlock_1d(nn.Module):
expansion = 1
def __init__(self, input_channel, output_channel, stride=1):
super(BasicBlock_1d, self).__init__()
self.left = nn.Sequential(
nn.Conv1d(in_channels=input_channel, out_channels=output_channel,
kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm1d(output_channel),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=output_channel, out_channels=output_channel,
kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm1d(output_channel)
)
self.right = nn.Sequential()
if stride != 1 or input_channel != self.expansion * output_channel:
self.right = nn.Sequential(
nn.Conv1d(in_channels=input_channel, out_channels=output_channel * self.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm1d(self.expansion * output_channel)
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.left(x)
residual = self.right(x)
out += residual
out = self.relu(out)
return out
class ResNet_1d(nn.Module):
def __init__(self, block, number_block, num_classes=2):
super(ResNet_1d, self).__init__()
self.in_channel = 64
self.conv1 = nn.Conv1d(in_channels=1, out_channels=self.in_channel,
kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm1d(64)
self.relu = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block=block, out_channel=64, num_block=number_block[0], stride=1)
self.layer2 = self._make_layer(block=block, out_channel=128, num_block=number_block[1], stride=2)
self.layer3 = self._make_layer(block=block, out_channel=256, num_block=number_block[2], stride=2)
self.layer4 = self._make_layer(block=block, out_channel=512, num_block=number_block[3], stride=2)
# self.layer5 = self._make_layer(block=block, out_channel=512, num_block=number_block[4], stride=2)
self.pool2 = nn.AdaptiveAvgPool1d(1)
self.features = nn.Sequential(
# nn.Linear(in_features=1024, out_features=nc),
nn.Flatten(),
# nn.Linear(in_features=512 * 23, out_features=512),
nn.Linear(in_features=512, out_features=num_classes)
# nn.Softmax()
# nn.Sigmoid()
)
# self.linear = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channel, num_block, stride):
strides = [stride] + [1] * (num_block - 1)
layers = []
for stride in strides:
layers.append(block(self.in_channel, out_channel, stride))
self.in_channel = out_channel * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pool1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# x = self.layer5(x)
x = self.pool2(x)
x = x.view(x.size(0), -1)
x = self.features(x)
return x
class ResNet18_LSTM_1d_v2(ResNet_1d):
def __init__(self, block, number_block, num_classes, hidden_size, num_layers, bidirectional):
super(ResNet18_LSTM_1d_v2, self).__init__(
block=block,
number_block=number_block,
num_classes=num_classes
)
# self.pool3 = nn.MaxPool1d(4)
self.lstm = nn.LSTM(input_size=512,
hidden_size=hidden_size,
num_layers=num_layers,
bidirectional=bidirectional,
batch_first=True)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.pool1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.transpose(2, 1)
x, (h_n1, c_n1) = self.lstm(x)
x = x[:, -1, :]
x = self.features(x)
return x
def ResNet18_v2_LSTM():
return ResNet18_LSTM_1d_v2(BasicBlock_1d, [2, 2, 2, 2],
num_classes=2, hidden_size=512, num_layers=2, bidirectional=False)
if __name__ == '__main__':
# from torchinfo import summary
# resnet = ResNet18_v2_LSTM().cuda()
# summary(resnet, (4, 1, 300))
pass

63
utils/ModelDetection.py Normal file
View File

@ -0,0 +1,63 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ModelDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from SADetectModel.SaDetectModel import ResNet18_v2_LSTM
gpu = torch.cuda.is_available()
class TestApneaDataset(Dataset):
def __init__(self, datasets):
super(TestApneaDataset, self).__init__()
self.datasets = datasets
self.labels = np.linspace(0, len(self.datasets) // 5 - 60, len(self.datasets) // 5 - 60 - 1).astype(int)
# print(len(datasets))
# print(self.labels)
def __getitem__(self, index):
SP = self.labels[index]
segment = self.datasets[SP * 5:(SP + 60) * 5].copy().reshape((1, -1))
return segment, SP
def __len__(self):
return len(self.labels)
def SA_Detect(model_path, data, batch_size):
model = ResNet18_v2_LSTM()
model.load_state_dict(torch.load(model_path))
model = model.cuda() if gpu else model
model.eval()
test_dataset = TestApneaDataset(datasets=data)
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=False, num_workers=0)
temp_result = []
for resp, SP in tqdm(test_loader, total=len(test_loader)):
resp = resp.float().cuda() if gpu else resp.float()
with torch.no_grad():
out = model(resp)
out = torch.softmax(out, dim=1)[:, 1]
temp = np.stack((SP.cpu().numpy(), out.cpu().numpy()), axis=1)
temp_result.append(temp)
df_segment = pd.DataFrame(data=np.concatenate(temp_result, axis=0), columns=["SP", "pred"])
df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))
df_segment["SP"] = df_segment["SP"].astype(int)
df_segment = df_segment.sort_values(by="SP", ascending=True)
return df_segment

66
utils/ResultSummary.py Normal file
View File

@ -0,0 +1,66 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ResultSummary.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
import numpy as np
import pandas as pd
def AnalyseSegment(df_segment, thresh=0.5, thresh_event_interval=2, thresh_event_length=8,
save_path=None, true_sleep_time=0):
df_segment["thresh_Pred"] = df_segment["pred"].apply(lambda x: 1 if x > thresh else 0)
thresh_Pred = df_segment["thresh_Pred"].values
thresh_Pred2 = thresh_Pred.copy()
# 扩充
indices = np.where(thresh_Pred == 1)[0]
for i in range(len(indices) - 1):
if indices[i + 1] - indices[i] <= thresh_event_interval:
thresh_Pred2[indices[i]:indices[i + 1]] = 1
# 事件判断
# 如果连续的1的长度大于阈值则判断为事件记录起止位置
diffs = np.diff(thresh_Pred2, prepend=0, append=0)
start_indices = np.where(diffs == 1)[0]
end_indices = np.where(diffs == -1)[0]
event_indices = np.where(end_indices - start_indices >= thresh_event_length)[0]
result = [(start_indices[i], end_indices[i]) for i in event_indices]
df_event = pd.DataFrame(result, columns=["start", "end"])
df_event["length"] = df_event["end"] - df_event["start"]
df_event["start"] = df_event["start"] - 29
df_event["end"] = df_event["end"] - 29
if save_path is not None:
df_event.to_csv(save_path, index=False)
# 根据AHI评估严重程度
if true_sleep_time is 0:
record_length = len(df_segment) / 60 / 60
AHI = len(result) / record_length
else:
record_length = true_sleep_time / 60
AHI = len(result) / record_length
if AHI < 5:
severity = "Healthy"
elif AHI < 15:
severity = "Mild"
elif AHI < 30:
severity = "Moderate"
else:
severity = "Severe"
SA_HYP_count = len(result)
print("patient: ", save_path.stem)
print("Event number: ", SA_HYP_count)
print("Record length (hours): ", round(record_length))
print("AHI: ", AHI)
print("Severity: ", severity)
print("=====================================")
return SA_HYP_count, AHI, severity

52
utils/SignalPreprocess.py Normal file
View File

@ -0,0 +1,52 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:SignalPreprocess.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
# 输入信号为一维的的整晚信号
import numpy as np
from scipy import signal
def XinXiaoPreprocess(data=None, signal_frequency=1000, low_cut=0.01, high_cut=0.7, order=4):
"""
# 预处理操作
# 提取呼吸
# 四阶巴特沃斯滤波器 提取呼吸信号主成分
# 四秒移动平均滤波器 平滑呼吸信号
# Z-Score 标准化
# 降采样至5Hz
:param data:
:param signal_frequency:
:param low_cut:
:param high_cut:
:param order:
:return:
"""
order = order
low_cut = low_cut
high_cut = high_cut
sample_rate = signal_frequency
low = low_cut / (sample_rate * 0.5)
high = high_cut / (sample_rate * 0.5)
sos = signal.butter(N=order, Wn=[low, high], btype="bandpass", output='sos')
data = signal.sosfilt(sos, data)
# 比对采样率 将信号采样率调整至100Hz
data= data[::10]
data -= np.convolve(data, np.ones(400) / 400, mode='same')
data = (data - data.mean()) / data.std()
data[data > 3] = 3
data[data < -3] = -3
# 降采样至5Hz
data = data[::20]
return data