0917backup
parent 17f8963e0f
commit 9fe80b02a8

6  .gitignore  vendored
@@ -158,5 +158,9 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+.idea/
 
+Data/*
+Output/*
+SADetectModel/*.pt
88  ApneaDetection.py  Normal file
@@ -0,0 +1,88 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ApneaDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
import argparse
from pathlib import Path

import numpy as np
import pandas as pd

from utils.SignalPreprocess import XinXiaoPreprocess
from utils.ModelDetection import SA_Detect
from utils.ResultSummary import AnalyseSegment


def is_valid_file(par, arg):
    if not Path(arg).is_file():
        par.error(f'Expected a model file, not a directory: {arg}')
    else:
        return arg


def read_auto(file_path):
    if file_path.suffix == '.txt':
        # Read the txt with pandas read_csv
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    elif file_path.suffix == '.npy':
        return np.load(str(file_path))
    else:
        raise ValueError('Unsupported file type; write a custom loader for this format')


def main(opt):
    # Load the dataset
    data_path = Path(opt.data_path)
    all_file_path = []
    print(data_path)
    if data_path.is_file():
        all_file_path.append(data_path)
    elif data_path.is_dir():
        all_file_path = list(data_path.glob('*.txt')) + list(data_path.glob('*.npy'))
    else:
        raise FileNotFoundError('Data path does not exist')

    if not all_file_path:
        raise FileNotFoundError('No files ending in .txt or .npy were found')

    for one_file in all_file_path:
        data = read_auto(one_file)

        data = XinXiaoPreprocess(data, opt.hz)
        model_path = f'SADetectModel/SAmodel{opt.model}.pt'
        if not Path(model_path).exists():
            raise FileNotFoundError('Model file does not exist')
        else:
            print(f'Running prediction with model {model_path}')
        result = SA_Detect(model_path, data, opt.batch_size)
        # result.to_csv(Path(opt.output) / (one_file.stem + "segment.csv"), index=False)
        if opt.output:
            save_path = Path(opt.output) / (one_file.stem + ".csv")
        else:
            save_path = None

        AnalyseSegment(result, save_path=save_path, true_sleep_time=opt.TST)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
    parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data', help='file or directory of data to test')
    parser.add_argument('--hz', type=int, nargs='?', default=1000, help='signal sampling rate')
    parser.add_argument('-m', '--model', nargs='?', type=int, default=0, choices=[0, 1, 3, 4], help='which model to use')
    parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int, help="number of segments predicted per batch")
    parser.add_argument('-o', '--output', nargs='?', default='./Output', help='output directory')
    parser.add_argument('-t', "--TST", nargs="?", default=0, type=int, help="total sleep time in minutes (manual input only supported when testing a single file)")
    option = parser.parse_args()

    # Manual overrides: these take precedence over defaults and command-line input
    option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
    # option.batch_size = 1
    # option.TST = 565
    # option.model = 0

    main(option)
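For reference, main() can also be driven without the CLI; a minimal sketch, assuming the repo layout above (the data path here is a placeholder, not a file from this repo):

# Hypothetical programmatic invocation of main(); paths are placeholders.
import argparse

opt = argparse.Namespace(data_path='./Data/example.npy', hz=1000, model=0,
                         batch_size=4096, output='./Output', TST=0)
main(opt)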
141  SADetectModel/SaDetectModel.py  Normal file
@@ -0,0 +1,141 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:SaDetectModel.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/16
"""
from torch import nn


class BasicBlock_1d(nn.Module):
    expansion = 1

    def __init__(self, input_channel, output_channel, stride=1):
        super(BasicBlock_1d, self).__init__()
        self.left = nn.Sequential(
            nn.Conv1d(in_channels=input_channel, out_channels=output_channel,
                      kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm1d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=output_channel, out_channels=output_channel,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(output_channel)
        )
        # Identity shortcut; replaced by a 1x1 convolution when the shape changes
        self.right = nn.Sequential()

        if stride != 1 or input_channel != self.expansion * output_channel:
            self.right = nn.Sequential(
                nn.Conv1d(in_channels=input_channel, out_channels=output_channel * self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(self.expansion * output_channel)
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.left(x)
        residual = self.right(x)
        out += residual
        out = self.relu(out)
        return out


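A quick shape check of the residual block (a minimal sketch; the batch size and length are illustrative):

import torch

blk = BasicBlock_1d(input_channel=64, output_channel=128, stride=2)
y = blk(torch.randn(2, 64, 100))      # (batch, channels, length)
print(y.shape)                        # torch.Size([2, 128, 50]): stride 2 halves the length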
class ResNet_1d(nn.Module):
    def __init__(self, block, number_block, num_classes=2):
        super(ResNet_1d, self).__init__()

        self.in_channel = 64

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=self.in_channel,
                               kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block=block, out_channel=64, num_block=number_block[0], stride=1)
        self.layer2 = self._make_layer(block=block, out_channel=128, num_block=number_block[1], stride=2)
        self.layer3 = self._make_layer(block=block, out_channel=256, num_block=number_block[2], stride=2)
        self.layer4 = self._make_layer(block=block, out_channel=512, num_block=number_block[3], stride=2)
        # self.layer5 = self._make_layer(block=block, out_channel=512, num_block=number_block[4], stride=2)
        self.pool2 = nn.AdaptiveAvgPool1d(1)

        self.features = nn.Sequential(
            # nn.Linear(in_features=1024, out_features=nc),
            nn.Flatten(),
            # nn.Linear(in_features=512 * 23, out_features=512),
            nn.Linear(in_features=512, out_features=num_classes)

            # nn.Softmax()
            # nn.Sigmoid()
        )
        # self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channel, num_block, stride):
        # Only the first block of a layer may downsample; the rest keep stride 1
        strides = [stride] + [1] * (num_block - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channel, out_channel, stride))
            self.in_channel = out_channel * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x = self.layer5(x)
        x = self.pool2(x)
        x = x.view(x.size(0), -1)

        x = self.features(x)
        return x


class ResNet18_LSTM_1d_v2(ResNet_1d):
    def __init__(self, block, number_block, num_classes, hidden_size, num_layers, bidirectional):
        super(ResNet18_LSTM_1d_v2, self).__init__(
            block=block,
            number_block=number_block,
            num_classes=num_classes
        )
        # self.pool3 = nn.MaxPool1d(4)
        self.lstm = nn.LSTM(input_size=512,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            batch_first=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # (batch, 512, time) -> (batch, time, 512) for the batch-first LSTM
        x = x.transpose(2, 1)
        x, (h_n1, c_n1) = self.lstm(x)
        # Keep only the hidden state of the last timestep
        x = x[:, -1, :]
        x = self.features(x)
        return x


def ResNet18_v2_LSTM():
    return ResNet18_LSTM_1d_v2(BasicBlock_1d, [2, 2, 2, 2],
                               num_classes=2, hidden_size=512, num_layers=2, bidirectional=False)


if __name__ == '__main__':
    # from torchinfo import summary
    # resnet = ResNet18_v2_LSTM().cuda()
    # summary(resnet, (4, 1, 300))
    pass
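As a sanity check of the full network, a minimal sketch assuming the (4, 1, 300) input shape implied by the commented-out summary call (60 s windows at 5 Hz):

import torch

model = ResNet18_v2_LSTM()
x = torch.randn(4, 1, 300)            # (batch, channel, samples)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                   # torch.Size([4, 2]): one score per class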
63  utils/ModelDetection.py  Normal file
@@ -0,0 +1,63 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ModelDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from SADetectModel.SaDetectModel import ResNet18_v2_LSTM

gpu = torch.cuda.is_available()


class TestApneaDataset(Dataset):
    def __init__(self, datasets):
        super(TestApneaDataset, self).__init__()
        self.datasets = datasets
        # One start position per second: the signal is 5 Hz and each window covers 60 s
        self.labels = np.arange(0, len(self.datasets) // 5 - 60, dtype=int)
        # print(len(datasets))
        # print(self.labels)

    def __getitem__(self, index):
        SP = self.labels[index]
        segment = self.datasets[SP * 5:(SP + 60) * 5].copy().reshape((1, -1))
        return segment, SP

    def __len__(self):
        return len(self.labels)


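For intuition about the windowing, a small sketch on synthetic data (the one-hour length is assumed for illustration):

import numpy as np

data = np.zeros(18000)                # one hour at 5 Hz
ds = TestApneaDataset(data)
seg, sp = ds[0]
print(len(ds))                        # 3540: one 60 s window per second of signal
print(seg.shape)                      # (1, 300): 60 s * 5 Hz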
def SA_Detect(model_path, data, batch_size):
    model = ResNet18_v2_LSTM()
    # map_location keeps CPU-only machines working with GPU-trained weights
    model.load_state_dict(torch.load(model_path, map_location=None if gpu else 'cpu'))
    model = model.cuda() if gpu else model
    model.eval()

    test_dataset = TestApneaDataset(datasets=data)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=False, num_workers=0)

    temp_result = []

    for resp, SP in tqdm(test_loader, total=len(test_loader)):
        resp = resp.float().cuda() if gpu else resp.float()
        with torch.no_grad():
            out = model(resp)
            # Probability of the positive (apnea) class
            out = torch.softmax(out, dim=1)[:, 1]

        temp = np.stack((SP.cpu().numpy(), out.cpu().numpy()), axis=1)
        temp_result.append(temp)

    df_segment = pd.DataFrame(data=np.concatenate(temp_result, axis=0), columns=["SP", "pred"])
    df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))
    df_segment["SP"] = df_segment["SP"].astype(int)
    df_segment = df_segment.sort_values(by="SP", ascending=True)

    return df_segment
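A hypothetical call, assuming the weights file and a preprocessed 5 Hz signal exist:

# df = SA_Detect('SADetectModel/SAmodel0.pt', data_5hz, batch_size=4096)
# df has one row per 1 s start position: "SP" is the second index and "pred"
# the softmax probability of the apnea class, rounded to three decimals.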
66  utils/ResultSummary.py  Normal file
@@ -0,0 +1,66 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ResultSummary.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
import numpy as np
import pandas as pd


def AnalyseSegment(df_segment, thresh=0.5, thresh_event_interval=2, thresh_event_length=8,
                   save_path=None, true_sleep_time=0):
    df_segment["thresh_Pred"] = df_segment["pred"].apply(lambda x: 1 if x > thresh else 0)
    thresh_Pred = df_segment["thresh_Pred"].values
    thresh_Pred2 = thresh_Pred.copy()

    # Expand: bridge gaps of at most thresh_event_interval between positive segments
    indices = np.where(thresh_Pred == 1)[0]
    for i in range(len(indices) - 1):
        if indices[i + 1] - indices[i] <= thresh_event_interval:
            thresh_Pred2[indices[i]:indices[i + 1]] = 1
    # Event detection:
    # a run of 1s at least thresh_event_length long counts as an event; record its start and end
    diffs = np.diff(thresh_Pred2, prepend=0, append=0)
    start_indices = np.where(diffs == 1)[0]
    end_indices = np.where(diffs == -1)[0]
    event_indices = np.where(end_indices - start_indices >= thresh_event_length)[0]
    result = [(start_indices[i], end_indices[i]) for i in event_indices]

    df_event = pd.DataFrame(result, columns=["start", "end"])
    df_event["length"] = df_event["end"] - df_event["start"]
    # Shift start/end back by 29 s (each prediction window spans 60 s, so this
    # roughly re-references events to the window centre)
    df_event["start"] = df_event["start"] - 29
    df_event["end"] = df_event["end"] - 29

    if save_path is not None:
        df_event.to_csv(save_path, index=False)

    # Grade severity from the AHI
    if true_sleep_time == 0:
        # One segment per second, so len(df_segment) is the recording length in seconds
        record_length = len(df_segment) / 60 / 60
        AHI = len(result) / record_length
    else:
        # true_sleep_time is in minutes; convert to hours
        record_length = true_sleep_time / 60
        AHI = len(result) / record_length

    if AHI < 5:
        severity = "Healthy"
    elif AHI < 15:
        severity = "Mild"
    elif AHI < 30:
        severity = "Moderate"
    else:
        severity = "Severe"

    SA_HYP_count = len(result)
    print("patient: ", save_path.stem if save_path is not None else "unknown")
    print("Event number: ", SA_HYP_count)
    print("Record length (hours): ", round(record_length, 2))
    print("AHI: ", AHI)
    print("Severity: ", severity)
    print("=====================================")

    return SA_HYP_count, AHI, severity
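A tiny worked example of the bridge-then-extract step (synthetic values; thresholds shrunk to interval 2 and length 4 for readability):

import numpy as np

p = np.array([0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0])
idx = np.where(p == 1)[0]
for i in range(len(idx) - 1):
    if idx[i + 1] - idx[i] <= 2:      # bridge gaps of at most 2
        p[idx[i]:idx[i + 1]] = 1
d = np.diff(p, prepend=0, append=0)
starts, ends = np.where(d == 1)[0], np.where(d == -1)[0]
events = [(int(s), int(e)) for s, e in zip(starts, ends) if e - s >= 4]
print(events)                         # [(1, 7)]: only the bridged 6-long run qualifies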
52  utils/SignalPreprocess.py  Normal file
@@ -0,0 +1,52 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:SignalPreprocess.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""

# The input is a one-dimensional whole-night signal

import numpy as np
from scipy import signal


def XinXiaoPreprocess(data=None, signal_frequency=1000, low_cut=0.01, high_cut=0.7, order=4):
    """
    Preprocessing steps:
    - 4th-order Butterworth band-pass filter to extract the main respiratory component
    - subtract a 4-second moving average to remove the baseline
    - Z-score standardisation, clipped to ±3
    - downsample to 5 Hz

    :param data: one-dimensional whole-night signal
    :param signal_frequency: sampling rate in Hz
    :param low_cut: band-pass low cutoff in Hz
    :param high_cut: band-pass high cutoff in Hz
    :param order: Butterworth filter order

    :return: preprocessed signal at 5 Hz
    """
    sample_rate = signal_frequency
    low = low_cut / (sample_rate * 0.5)
    high = high_cut / (sample_rate * 0.5)
    sos = signal.butter(N=order, Wn=[low, high], btype="bandpass", output='sos')
    data = signal.sosfilt(sos, data)
    # Downsample to 100 Hz (assumes a 1000 Hz input)
    data = data[::10]

    # Subtract a 4 s moving average (400 samples at 100 Hz) to remove the baseline
    data -= np.convolve(data, np.ones(400) / 400, mode='same')
    data = (data - data.mean()) / data.std()
    data[data > 3] = 3
    data[data < -3] = -3
    # Downsample to 5 Hz
    data = data[::20]
    return data
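A minimal end-to-end sketch on synthetic data (the 0.3 Hz sinusoid stands in for respiration; values are illustrative only):

import numpy as np

fs = 1000
t = np.arange(60 * fs) / fs                        # one minute at 1000 Hz
raw = np.sin(2 * np.pi * 0.3 * t) + 0.05 * np.random.randn(len(t))
out = XinXiaoPreprocess(raw, signal_frequency=fs)
print(out.shape)                                   # (300,): 60 s at 5 Hz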