sleep_apnea_hybrid/exam/023/test_save_result.py
2022-10-14 22:33:34 +08:00

479 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:test_save_result.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/02/21
"""
import logging
import os
import sys
import pandas as pd
import torch.cuda
import numpy as np
import yaml
from matplotlib import pyplot as plt
from tqdm import tqdm
from pathlib import Path
from torch.nn import functional as F
from torch.utils.data import DataLoader
from load_dataset import TestApneaDataset2, read_dataset
from utils.Draw_ConfusionMatrix import draw_confusionMatrix
from torch import nn
from utils.calc_metrics import CALC_METRICS
from my_augment import my_augment, my_segment_augment
from model.Hybrid_Net003 import HYBRIDNET003
# --- global configuration for this evaluation script ---
plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese characters in matplotlib figures

# root folder holding one sub-folder per trained experiment
exam_path = Path("./output/")

# confidence threshold: a segment is predicted positive when pred > thresh
thresh = 0.5
# maximum gap (in segments) bridged when merging nearby positive predictions
thresh_event_interval = 2
# minimum run length (in segments) for a positive run to count as an event
thresh_event_length = 8
# labels strictly greater than this value are treated as positive
event_thresh = 1

# map sample number -> severity grade, read from the annotation spreadsheet
severity_path = Path(r"/home/marques/code/marques/apnea/dataset/loc_first_csa.xlsx")
severity_label = {"all": "none"}
severity_df = pd.read_excel(severity_path)
for one_data in severity_df.index:
    one_data = severity_df.loc[one_data]
    # spreadsheet columns: "数据编号" = sample number, "程度" = severity grade
    severity_label[str(one_data["数据编号"])] = one_data["程度"]

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
gpu = torch.cuda.is_available()

num_classes = 1
calc_metrics = CALC_METRICS(num_classes)

# warm the dataset cache once; per-experiment settings are re-read in set_environment()
with open("./settings.yaml") as f:
    hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
data_path = hyp["Path"]["dataset"]
read_dataset(data_path, augment=my_augment)
del hyp

# per-experiment state, (re)filled by set_environment(); newest folder is used by default.
# NOTE(review): this rebinds the imported `my_augment` to None after its last use above —
# confirm nothing later relies on the imported function.
all_output_path, output_path, segments_results_save_path, events_results_save_path, = [None, ] * 4
my_augment, model_path, label_path, data_path, model, model_name = [None, ] * 6
train_set, test_set = None, None

loss_func = nn.CrossEntropyLoss()
# per-segment result columns and per-sample metric columns
columns = ["sampNo", "segmentNo", "label_type", "new_label", "SP", "EP", "pred"]
columns2 = ["sampNo", "severity", "origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN",
            "acc", "recall", "spec", "pre", "NPV", "F1score", "support"]

# console + file logging; the log file is recreated fresh for every run
logging.getLogger('matplotlib.font_manager').disabled = True
logging.getLogger('matplotlib.ticker').disabled = True
logger = logging.getLogger()
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
logger.addHandler(ch)
if (exam_path / "test.log").exists():
    (exam_path / "test.log").unlink()
fh = logging.FileHandler(exam_path / "test.log", mode='a')
fh.setLevel(logging.INFO)
fh.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(fh)
logger.info("------------------------------------")
def set_environment(i):
    """Load the i-th experiment folder: its settings, result dirs and best model.

    Populates the module-level globals used by the analysis functions below.

    :param i: index into ``all_output_path`` (one entry per KFold experiment)
    """
    global output_path, segments_results_save_path, events_results_save_path, model_path, label_path, data_path, \
        model, model_name, train_set, test_set
    output_path = all_output_path[i]
    logger.info(output_path)
    segments_results_save_path = (output_path / "segments_results")
    segments_results_save_path.mkdir(exist_ok=True)
    events_results_save_path = (output_path / "events_results")
    events_results_save_path.mkdir(exist_ok=True)
    # load this experiment's configuration
    with open(output_path / "settings.yaml") as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
    data_path = hyp["Path"]["dataset"]
    label_path = hyp["Path"]["label"]
    train_set = hyp["train_set"]
    test_set = hyp["test_set"]
    model_path = output_path / "weights" / "best.pt"
    # NOTE(review): eval() instantiates a class named in settings.yaml — only safe
    # while that file is trusted; consider an explicit name->class mapping.
    model = eval(hyp["model_name"])()
    model_name = hyp["model_name"]
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()
def test_and_analysis_and_visual(dataset_type):
    """Run inference on one dataset split, then produce segment- and event-level
    metrics, confusion matrices, and a re-analysis with low-quality segments removed.

    :param dataset_type: "test" or "all_test"; anything else logs an error and
        proceeds with ``sampNo=None``.
    """
    # NOTE(review): "test" selects train_set and "all_test" selects test_set —
    # this looks swapped; confirm against TestApneaDataset2's semantics.
    if dataset_type == "test":
        sampNo = train_set
    elif dataset_type == "all_test":
        sampNo = test_set
    else:
        sampNo = None
        logger.info("出错了")

    exam_name = Path("./").absolute().name
    test_dataset = TestApneaDataset2(data_path, label_path, select_sampno=sampNo, dataset_type=dataset_type,
                                     segment_augment=my_segment_augment)
    test_loader = DataLoader(test_dataset, batch_size=128, pin_memory=True, num_workers=0)

    test_loss = 0.0
    df_segment = pd.DataFrame(columns=columns)
    for one in tqdm(test_loader, total=len(test_loader)):
        resp, stft, labels = one[:3]
        other_info = one[3:]  # presumably sampNo, segmentNo, label_type, new_label, SP, EP — matches `columns`
        resp = resp.float().cuda() if gpu else resp.float()
        stft = stft.float().cuda() if gpu else stft.float()
        labels = labels.cuda() if gpu else labels
        with torch.no_grad():
            out = model(resp, stft)
            loss = loss_func(out, labels)
        test_loss += loss.item()

        labels = torch.unsqueeze(labels, dim=1)
        out = F.softmax(out, dim=1)
        out = torch.unsqueeze(out[:, 1], dim=1)  # keep only the positive-class probability
        calc_metrics.update(out.cpu(), labels.cpu())

        # one row per segment: metadata columns plus the positive-class score
        one2 = np.array([i.cpu().numpy() for i in (other_info + [out.squeeze()])])
        one2 = one2.transpose((1, 0))
        df = pd.DataFrame(data=one2, columns=columns)
        df_segment = df_segment.append(df, ignore_index=True)

    test_loss /= len(test_loader)
    calc_metrics.compute()
    logger.info(f"EXAM_NAME: {exam_name} SampNO: {sampNo}")
    logger.info(calc_metrics.get_matrix(loss=test_loss, epoch=0, epoch_type="test"))
    calc_metrics.reset()

    # binarise labels and predictions with the module-level thresholds
    df_segment["thresh_label"] = 1 * (df_segment["label_type"] > event_thresh).copy()
    df_segment["thresh_Pred"] = 1 * (df_segment["pred"] > thresh).copy()
    df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))

    # segment-level analysis + one confusion matrix per sample
    df_segment_metrics = analysis_results(df_segment, segments_results_save_path, dataset_type)
    confusionMatrix(df_segment_metrics, segments_results_save_path, dataset_type)

    # event-level analysis: for the inner test each segment number is one event,
    # whole-night splits are rebuilt inside segment_to_event()
    df_all_event = segment_to_event(df_segment, dataset_type)
    df_event_metrics = analysis_results(df_all_event, events_results_save_path, dataset_type, is_event=True)
    confusionMatrix(df_event_metrics, events_results_save_path, dataset_type)

    # drop low-quality segments (label_type 2/3 with new_label == 2) and redo both analyses
    df_bad_segment = df_segment[
        (df_segment["label_type"].isin([2, 3])) & (df_segment["new_label"] == 2)]
    df_select_segment = df_segment.drop(df_bad_segment.index)
    analysis_results(df_select_segment, segments_results_save_path / "remove_2", dataset_type)
    df_select_event = segment_to_event(df_select_segment, dataset_type)
    analysis_results(df_select_event, events_results_save_path / "remove_2", dataset_type, is_event=True)
def analysis_results(df_result, base_path, dataset_type, is_event=False):
    """Compute per-sample and overall confusion statistics and write them to CSV.

    Row 0 of the returned frame accumulates totals over all samples; one row per
    sample follows, and (for ``dataset_type == "test"``) one aggregate row per
    severity grade is appended at the end.

    :param df_result: per-segment (or per-event) results with ``thresh_label`` /
        ``thresh_Pred`` columns already binarised
    :param base_path: directory that receives the per-sample and summary CSVs
    :param dataset_type: split name, also used as sub-directory name
    :param is_event: switches the output file names between segment and event results
    :return: the numeric metrics DataFrame, or None when ``df_result`` is empty
    """
    if df_result.empty:
        logger.info(f"{base_path} {dataset_type} is_empty")
        return None
    (base_path / dataset_type).mkdir(exist_ok=True, parents=True)
    all_sampNo = df_result["sampNo"].unique()
    df_metrics = pd.DataFrame(columns=columns2)
    # row 0 accumulates the totals over all samples
    df_metrics.loc[0] = 0
    df_metrics.loc[0, "sampNo"] = dataset_type

    count_cols = ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]

    for index, sampNo in enumerate(all_sampNo):
        df = df_result[df_result["sampNo"] == sampNo]
        df.to_csv(
            base_path / dataset_type /
            f"{int(sampNo)}_{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_result.csv",
            index=False)

        row = index + 1
        df_metrics.loc[row] = np.NAN
        df_metrics.loc[row, "sampNo"] = str(int(sampNo))
        df_metrics.loc[row, "support"] = df.shape[0]
        df_metrics.loc[row, "severity"] = severity_label[str(int(sampNo))]
        df_metrics.loc[row, "origin_P"] = df[df["thresh_label"] == 1].shape[0]
        df_metrics.loc[row, "origin_N"] = df[df["thresh_label"] == 0].shape[0]
        df_metrics.loc[row, "pred_P"] = df[df["thresh_Pred"] == 1].shape[0]
        df_metrics.loc[row, "pred_N"] = df[df["thresh_Pred"] == 0].shape[0]
        df_metrics.loc[row, "T"] = df[df["thresh_Pred"] == df["thresh_label"]].shape[0]
        df_metrics.loc[row, "F"] = df[df["thresh_Pred"] != df["thresh_label"]].shape[0]
        df_metrics.loc[row, "TP"] = df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[row, "FP"] = df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[row, "TN"] = df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]
        df_metrics.loc[row, "FN"] = df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]

        # accumulate raw counts into the totals row before zeros are blanked out below
        for col in count_cols + ["support"]:
            df_metrics.loc[0, col] += df_metrics.loc[row, col]

        # turn zero counts into NaN so the ratios below yield NaN instead of raising
        for col in count_cols:
            df_metrics.loc[row, col] = df_metrics.loc[row, col] if df_metrics.loc[row, col] != 0 else np.NAN

        df_metrics.loc[row, "acc"] = df_metrics.loc[row, "T"] / df_metrics.loc[row, "support"]
        df_metrics.loc[row, "recall"] = df_metrics.loc[row, "TP"] / df_metrics.loc[row, "origin_P"]
        df_metrics.loc[row, "spec"] = df_metrics.loc[row, "TN"] / df_metrics.loc[row, "origin_N"]
        df_metrics.loc[row, "pre"] = df_metrics.loc[row, "TP"] / df_metrics.loc[row, "pred_P"]
        df_metrics.loc[row, "NPV"] = df_metrics.loc[row, "TN"] / df_metrics.loc[row, "pred_N"]
        df_metrics.loc[row, "F1score"] = 2 * df_metrics.loc[row, "recall"] * df_metrics.loc[row, "pre"] / (
                df_metrics.loc[row, "recall"] + df_metrics.loc[row, "pre"])

        # NaN back to 0 and round for the report
        for col in count_cols + ["acc", "recall", "spec", "pre", "NPV", "F1score"]:
            df_metrics.loc[row, col] = 0 if pd.isna(df_metrics.loc[row, col]) else df_metrics.loc[row, col]
            df_metrics.loc[row, col] = round(df_metrics.loc[row, col], 3)

    # same NaN-guarded ratio computation for the totals row
    for col in count_cols:
        df_metrics.loc[0, col] = df_metrics.loc[0, col] if df_metrics.loc[0, col] != 0 else np.NAN
    df_metrics.loc[0, "acc"] = df_metrics.loc[0, "T"] / df_metrics.loc[0, "support"]
    df_metrics.loc[0, "recall"] = df_metrics.loc[0, "TP"] / df_metrics.loc[0, "origin_P"]
    df_metrics.loc[0, "spec"] = df_metrics.loc[0, "TN"] / df_metrics.loc[0, "origin_N"]
    df_metrics.loc[0, "pre"] = df_metrics.loc[0, "TP"] / df_metrics.loc[0, "pred_P"]
    df_metrics.loc[0, "NPV"] = df_metrics.loc[0, "TN"] / df_metrics.loc[0, "pred_N"]
    df_metrics.loc[0, "F1score"] = 2 * df_metrics.loc[0, "recall"] * df_metrics.loc[0, "pre"] / (
            df_metrics.loc[0, "recall"] + df_metrics.loc[0, "pre"])
    for col in ["TP", "TN", "FP", "FN", "acc", "recall", "spec", "pre", "NPV", "F1score"]:
        df_metrics.loc[0, col] = 0 if pd.isna(df_metrics.loc[0, col]) else df_metrics.loc[0, col]
        df_metrics.loc[0, col] = round(df_metrics.loc[0, col], 3)

    # for the inner test additionally aggregate the per-sample rows by severity grade
    if dataset_type == "test":
        all_severity = ["正常", "轻度", "中度", "重度"]
        for index, severity in enumerate(all_severity):
            df_event = df_metrics[df_metrics["severity"] == severity]
            df_temp = pd.DataFrame(columns=columns2)
            df_temp.loc[0] = 0
            df_temp.loc[0, "sampNo"] = severity
            df_temp.loc[0, "severity"] = str(index + 1)
            for col in count_cols + ["support"]:
                df_temp.loc[0, col] += df_event[col].sum()
            for col in count_cols:
                df_temp.loc[0, col] = df_temp.loc[0, col] if df_temp.loc[0, col] != 0 else np.NAN
            df_temp.loc[0, "acc"] = df_temp.loc[0, "T"] / df_temp.loc[0, "support"]
            df_temp.loc[0, "recall"] = df_temp.loc[0, "TP"] / df_temp.loc[0, "origin_P"]
            df_temp.loc[0, "spec"] = df_temp.loc[0, "TN"] / df_temp.loc[0, "origin_N"]
            df_temp.loc[0, "pre"] = df_temp.loc[0, "TP"] / df_temp.loc[0, "pred_P"]
            df_temp.loc[0, "NPV"] = df_temp.loc[0, "TN"] / df_temp.loc[0, "pred_N"]
            df_temp.loc[0, "F1score"] = 2 * df_temp.loc[0, "recall"] * df_temp.loc[0, "pre"] / (
                    df_temp.loc[0, "recall"] + df_temp.loc[0, "pre"])
            for col in count_cols + ["acc", "recall", "spec", "pre", "NPV", "F1score"]:
                df_temp.loc[0, col] = 0 if pd.isna(df_temp.loc[0, col]) else df_temp.loc[0, col]
                df_temp.loc[0, col] = round(df_temp.loc[0, col], 3)
            df_metrics = df_metrics.append(df_temp, ignore_index=True)

    # save a stringified, severity-sorted copy; return the numeric frame
    df_backup = df_metrics
    df_metrics = df_metrics.astype("str")
    df_metrics = df_metrics.sort_values("severity")
    df_metrics.to_csv(base_path / dataset_type /
                      f"{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_all_metrics.csv",
                      index=False, encoding="gbk")
    return df_backup
def confusionMatrix(df_analysis, base_path, dataset_type):
    """Draw one confusion-matrix image per row of the metrics table.

    :param df_analysis: metrics frame from analysis_results() (may be None/empty)
    :param base_path: output directory root
    :param dataset_type: split name, used as sub-directory name
    """
    # logger.info() takes a single message (extra args are %-style parameters);
    # the original passed three positional args, which breaks log formatting.
    if df_analysis is None:
        logger.info(f"{base_path} {dataset_type} is None")
        return
    if df_analysis.empty:
        logger.info(f"{base_path} {dataset_type} is_empty")
        return
    classes = ["normal", "SA"]
    (base_path / dataset_type / "confusionMatrix").mkdir(exist_ok=True, parents=True)
    for one_samp in df_analysis.index:
        one_samp = df_analysis.loc[one_samp]
        # rows = true class, cols = predicted class: [[TN, FP], [FN, TP]]
        cm = np.array([[one_samp["TN"], one_samp["FP"]], [one_samp["FN"], one_samp["TP"]]])
        draw_confusionMatrix(cm, classes=classes, title=str(one_samp["severity"]) + " " + one_samp["sampNo"],
                             save_path=base_path / dataset_type / "confusionMatrix" / f"{one_samp['sampNo']}.jpg")
def _dilate_predictions(pred_arr):
    """Extend every positive prediction forward by up to thresh_event_interval - 1
    positions so nearby positives merge into one run; returns a new array."""
    dilated = pred_arr.copy()
    for pos, pred in enumerate(pred_arr):
        if pred == 0:
            continue
        for interval in range(1, thresh_event_interval):
            if pos + interval < pred_arr.size:
                dilated[pos + interval] = 1
    return dilated


def _run_lengths(arr):
    """Run-length encode *arr*: return (run start indices, run lengths,
    run value * run length) — so zero runs score 0 and one-runs score their length."""
    boundaries = np.concatenate(([True], arr[:-1] != arr[1:], [True]))
    starts = np.where(boundaries)[0]
    lengths = np.diff(starts)
    values = arr[boundaries[:-1]] * lengths
    return starts, lengths, values


def segment_to_event(df_segment, dataset_type):
    """Convert per-segment predictions into per-event rows.

    For ``dataset_type == "test"`` every segmentNo is one candidate event; for
    whole-night splits events are rebuilt from runs of positive predictions
    (dilated by thresh_event_interval, kept when longer than thresh_event_length).

    :param df_segment: per-segment results with "pred", "label_type", etc.
    :param dataset_type: "test" for the inner test, anything else for whole-night data
    :return: event-level DataFrame with binarised thresh_label / thresh_Pred columns
    """
    df_all_event = pd.DataFrame(columns=columns)
    all_sampNo = df_segment["sampNo"].unique()

    if dataset_type == "test":
        for index, sampNo in enumerate(all_sampNo):
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)
            all_segments_no = df["segmentNo"].unique()
            for index_se, segment_No in enumerate(all_segments_no):
                df_temp = df[df["segmentNo"] == segment_No].copy()
                # NOTE(review): SP is read from the first row's "EP" column — looks
                # intentional for this label layout but confirm against the dataset.
                SP = df_temp.iloc[0]["EP"]
                EP = df_temp.iloc[-1]["EP"] + 1
                df_event.loc[index_se] = [int(sampNo), segment_No, df_temp.iloc[0]["label_type"],
                                          df_temp.iloc[0]["new_label"], SP, EP, 0]
                dilated = _dilate_predictions(df_temp["thresh_Pred"].values)
                _, _, value_ar = _run_lengths(dilated)
                # the event is predicted positive if any positive run is long enough
                for run_score in value_ar:
                    if run_score > thresh_event_length:
                        df_event.loc[index_se, "pred"] = 1
            df_all_event = df_all_event.append(df_event, ignore_index=True)
    else:
        for index, sampNo in enumerate(all_sampNo):
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)
            dilated = _dilate_predictions(df["thresh_Pred"].values)
            index_ar, count_ar, value_ar = _run_lengths(dilated)
            for value_index, value in enumerate(value_ar):
                SP = index_ar[value_index]
                EP = index_ar[value_index] + count_ar[value_index]
                if value > thresh_event_length:
                    # long enough positive run: emit one event covering [SP, EP)
                    # (SP // 30 assumes 30 segments per block — TODO confirm)
                    label_type = df["label_type"][SP:EP].max()
                    new_label = df["new_label"][SP:EP].max()
                    df_event = df_event.append(pd.DataFrame([[int(sampNo), SP // 30, label_type, new_label,
                                                              SP, EP, dilated[SP]]], columns=columns),
                                               ignore_index=True)
                else:
                    # too short: clear the surviving prediction for this span
                    df["thresh_Pred"][SP:EP] = 0
            # emit one negative event per segment block with no surviving prediction
            for segment_no in df["segmentNo"].unique():
                df_temp = df[df["segmentNo"] == segment_no]
                if df_temp["thresh_Pred"].sum() > 0:
                    continue
                df_event = df_event.append(pd.DataFrame(
                    [[int(sampNo), segment_no, df_temp["label_type"].max(), df_temp["new_label"].max(),
                      segment_no * 30, (segment_no + 1) * 30, 0]], columns=columns),
                    ignore_index=True)
            df_all_event = df_all_event.append(df_event, ignore_index=True)

    df_temp = df_all_event.loc[:, ["label_type", "pred"]]
    df_all_event["thresh_label"] = 1 * (df_temp["label_type"] > event_thresh)
    df_all_event["thresh_Pred"] = 1 * (df_temp["pred"] > thresh)
    return df_all_event
# Results are saved per sampNo and visualised without overlap by the helpers above.
# For the inner test, rows that disagree with the labels are saved separately.
# import shap
# explainer = shap.TreeExplainer()
# shap_values = explainer.shap_values()
if __name__ == '__main__':
    # iterate over every cross-validation fold found under ./output/
    all_output_path = list(exam_path.rglob("KFold_*"))
    for exam_index, test_exam_path in enumerate(all_output_path):
        set_environment(exam_index)
        test_and_analysis_and_visual(dataset_type="test")
        test_and_analysis_and_visual(dataset_type="all_test")