0915CXH_DL_SA/utils/ModelDetection.py
2023-09-17 00:46:14 +08:00

64 lines
1.9 KiB
Python

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ModelDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from SADetectModel.SaDetectModel import ResNet18_v2_LSTM
# True when a CUDA device is available; SA_Detect uses this flag to decide
# whether the model and input batches are moved to the GPU.
gpu = torch.cuda.is_available()
class TestApneaDataset(Dataset):
    """Sliding-window dataset over a signal, used for inference.

    The signal is addressed in steps of ``SAMPLES_PER_STEP`` samples.  Each
    item is a window covering ``WINDOW_STEPS`` consecutive steps, and
    successive windows start one step apart.
    NOTE(review): presumably one step is one second of a 5 Hz signal --
    confirm against the caller.

    ``datasets`` is expected to support ``len()``, slicing and
    ``.copy().reshape(...)`` (e.g. a NumPy array).

    Attributes:
        datasets: The signal handed to the constructor.
        labels: Integer start positions (in steps) of every window.
    """

    SAMPLES_PER_STEP = 5   # samples covered by one start-position step
    WINDOW_STEPS = 60      # window length, in steps

    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets
        total_steps = len(self.datasets) // self.SAMPLES_PER_STEP
        # Every start position 0 .. total_steps - WINDOW_STEPS (inclusive)
        # yields a full-length window.  np.arange keeps the stride at exactly
        # one step; the previous np.linspace call with a mismatched sample
        # count produced fractional steps that truncated to unevenly spaced
        # integers (skipping windows) and raised on short inputs.
        window_count = max(total_steps - self.WINDOW_STEPS + 1, 0)
        self.labels = np.arange(window_count, dtype=int)

    def __getitem__(self, index):
        """Return ``(segment, SP)`` for the window at position ``index``.

        ``segment`` is a copy of the window reshaped to ``(1, -1)`` and
        ``SP`` is the window's start position in steps.
        """
        SP = self.labels[index]
        first_sample = SP * self.SAMPLES_PER_STEP
        last_sample = (SP + self.WINDOW_STEPS) * self.SAMPLES_PER_STEP
        segment = self.datasets[first_sample:last_sample].copy().reshape((1, -1))
        return segment, SP

    def __len__(self):
        """Return the number of windows."""
        return len(self.labels)
def SA_Detect(model_path, data, batch_size):
    """Run the ``ResNet18_v2_LSTM`` classifier over every window of ``data``.

    Args:
        model_path: Path to a saved ``state_dict`` for ``ResNet18_v2_LSTM``.
        data: Signal handed to ``TestApneaDataset``, which cuts it into windows.
        batch_size: Number of windows per forward pass.

    Returns:
        A ``pandas.DataFrame`` with an integer column ``"SP"`` (window start
        position as produced by ``TestApneaDataset``) and a column ``"pred"``
        (softmax output for class index 1, rounded to 3 decimals), sorted by
        ``"SP"`` ascending.  The frame is empty when ``data`` yields no
        windows.
    """
    device = torch.device("cuda" if gpu else "cpu")

    model = ResNet18_v2_LSTM()
    # map_location="cpu" lets a checkpoint that was saved from a GPU be loaded
    # on a CPU-only host; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    test_dataset = TestApneaDataset(datasets=data)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=False, num_workers=0)

    batch_results = []
    with torch.no_grad():
        for resp, SP in tqdm(test_loader, total=len(test_loader)):
            out = model(resp.float().to(device))
            # Keep only the softmax output of class index 1.
            prob = torch.softmax(out, dim=1)[:, 1]
            batch_results.append(
                np.stack((SP.cpu().numpy(), prob.cpu().numpy()), axis=1)
            )

    if not batch_results:
        # No windows were produced; np.concatenate would raise on an empty
        # list, so return an empty frame with the usual columns instead.
        return pd.DataFrame({"SP": pd.Series(dtype=int), "pred": pd.Series(dtype=float)})

    df_segment = pd.DataFrame(data=np.concatenate(batch_results, axis=0), columns=["SP", "pred"])
    df_segment["pred"] = df_segment["pred"].apply(lambda x: round(x, 3))
    # Stacking with the float predictions turned SP into floats; restore ints.
    df_segment["SP"] = df_segment["SP"].astype(int)
    df_segment = df_segment.sort_values(by="SP", ascending=True)
    return df_segment