478 lines
21 KiB
Python
478 lines
21 KiB
Python
|
#!/usr/bin/python
|
|||
|
# -*- coding: UTF-8 -*-
|
|||
|
"""
|
|||
|
@author:Marques
|
|||
|
@file:test_analysis.py
|
|||
|
@email:admin@marques22.com
|
|||
|
@email:2021022362@m.scnu.edu.cn
|
|||
|
@time:2022/02/21
|
|||
|
"""
|
|||
|
import logging
|
|||
|
import os
|
|||
|
import sys
|
|||
|
|
|||
|
import pandas as pd
|
|||
|
import torch.cuda
|
|||
|
import numpy as np
|
|||
|
import yaml
|
|||
|
from matplotlib import pyplot as plt
|
|||
|
from tqdm import tqdm
|
|||
|
from pathlib import Path
|
|||
|
from torch.nn import functional as F
|
|||
|
from torch.utils.data import DataLoader
|
|||
|
from load_dataset import TestApneaDataset2, read_dataset
|
|||
|
from utils.Draw_ConfusionMatrix import draw_confusionMatrix
|
|||
|
from torch import nn
|
|||
|
from utils.calc_metrics import CALC_METRICS
|
|||
|
from my_augment import my_augment, my_segment_augment
|
|||
|
from model.Hybrid_Net018 import HYBRIDNET018
|
|||
|
|
|||
|
# Use SimHei so matplotlib can render the Chinese labels used in the plots
plt.rcParams['font.sans-serif'] = ['SimHei']

# Root folder holding one sub-folder per cross-validation experiment
exam_path = Path("./output/")

# Confidence threshold: a prediction above this counts as positive
thresh = 0.5
# Minimum gap (in segments) bridged when merging adjacent predicted events
thresh_event_interval = 0
# Minimum run length (in segments) for a positive run to count as an event
thresh_event_length = 2

# Label values strictly above this are treated as positive when thresholding
event_thresh = 1

# Per-sample severity lookup loaded from the clinical spreadsheet
severity_path = Path(r"/home/marques/code/marques/apnea/dataset/loc_first_csa.xlsx")
severity_label = {"all": "none"}
severity_df = pd.read_excel(severity_path)
for one_data in severity_df.index:
    one_data = severity_df.loc[one_data]
    # Spreadsheet columns (Chinese): sample number -> severity grade
    severity_label[str(one_data["数据编号"])] = one_data["程度"]

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# True when a GPU is available; used to decide .cuda() transfers below
gpu = torch.cuda.is_available()

num_classes = 1
calc_metrics = CALC_METRICS(num_classes)

# Pre-load the shared dataset once, using the training-time augmentation
with open("./settings.yaml") as f:
    hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
data_path = hyp["Path"]["dataset"]
read_dataset(data_path, augment=my_augment)
del hyp

# Placeholders re-assigned per experiment by set_environment()
# (the newest folder is used by default)
all_output_path, output_path, segments_results_save_path, events_results_save_path, = [None, ] * 4
# NOTE(review): this rebinds the imported `my_augment` to None after its only
# use above — presumably deliberate reuse of the name, but worth confirming.
my_augment, model_path, label_path, data_path, model, model_name = [None, ] * 6
train_set, test_set = None, None
loss_func = nn.CrossEntropyLoss()

# Per-segment result columns and per-sample metric columns
columns = ["sampNo", "segmentNo", "label_type", "new_label", "SP", "EP", "pred"]
columns2 = ["sampNo", "severity", "origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN",
            "acc", "recall", "spec", "pre", "NPV", "F1score", "support"]

# Console + file logging; silence matplotlib's own chatty loggers
logging.getLogger('matplotlib.font_manager').disabled = True
logging.getLogger('matplotlib.ticker').disabled = True
logger = logging.getLogger()
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
logger.addHandler(ch)

# Start every run with a fresh test.log
if (exam_path / "test.log").exists():
    (exam_path / "test.log").unlink()
fh = logging.FileHandler(exam_path / "test.log", mode='a')
fh.setLevel(logging.INFO)
fh.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(fh)
logger.info("------------------------------------")
|||
|
|
|||
|
def set_environment(i):
    """Load the i-th experiment's config and checkpoint into module globals.

    Creates the ``segments_results`` / ``events_results`` output directories,
    reads the experiment's ``settings.yaml`` (dataset/label paths and the
    train/test sample splits), instantiates the model class named there and
    loads its best checkpoint onto the GPU in eval mode.
    """
    global output_path, segments_results_save_path, events_results_save_path, model_path, label_path, data_path, \
        model, model_name, train_set, test_set

    output_path = all_output_path[i]
    logger.info(output_path)
    segments_results_save_path = (output_path / "segments_results")
    segments_results_save_path.mkdir(exist_ok=True)
    events_results_save_path = (output_path / "events_results")
    events_results_save_path.mkdir(exist_ok=True)

    # Load this experiment's configuration
    with open(output_path / "settings.yaml") as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
    data_path = hyp["Path"]["dataset"]
    label_path = hyp["Path"]["label"]
    train_set = hyp["train_set"]
    test_set = hyp["test_set"]

    model_path = output_path / "weights" / "best.pt"
    # SECURITY NOTE(review): eval() on a config-supplied class name executes
    # arbitrary code if settings.yaml is untrusted — a name->class lookup
    # table would be safer.
    model = eval(hyp["model_name"])()
    model_name = hyp["model_name"]
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()
|
|||
|
|
|||
|
|
|||
|
def test_and_analysis_and_visual(dataset_type):
    """Run inference for one dataset split, then do segment- and event-level analysis.

    dataset_type: "test" evaluates the experiment's train_set samples,
    "all_test" the test_set samples; anything else is logged and aborted.
    Writes per-segment and per-event CSVs and confusion-matrix images via
    analysis_results()/confusionMatrix(), both for all segments and with
    poor-quality segments (label_type in {2,3} and new_label == 2) removed.
    """
    if dataset_type == "test":
        sampNo = train_set
    elif dataset_type == "all_test":
        sampNo = test_set
    else:
        sampNo = None
        logger.info("出错了")
        # FIX: previously fell through with sampNo=None and still built the
        # dataset; abort on an unknown dataset_type instead.
        return

    exam_name = Path("./").absolute().name

    test_dataset = TestApneaDataset2(data_path, label_path, select_sampno=sampNo, dataset_type=dataset_type,
                                     segment_augment=my_segment_augment)
    test_loader = DataLoader(test_dataset, batch_size=128, pin_memory=True, num_workers=0)

    test_loss = 0.0

    # One row per segment: ["sampNo", "segmentNo", "label_type", "new_label", "SP", "EP", "pred"]
    df_segment = pd.DataFrame(columns=columns)

    for one in tqdm(test_loader, total=len(test_loader)):
        resp, labels = one[:2]
        other_info = one[2:]
        resp = resp.float().cuda() if gpu else resp.float()
        labels = labels.cuda() if gpu else labels
        with torch.no_grad():
            out = model(resp)

        loss = loss_func(out, labels)

        test_loss += loss.item()

        labels = torch.unsqueeze(labels, dim=1)
        # Keep only the positive-class probability as the segment score
        out = F.softmax(out, dim=1)
        out = torch.unsqueeze(out[:, 1], dim=1)

        calc_metrics.update(out.cpu(), labels.cpu())
        # one[0] = list(one[0].cpu().numpy())
        # one[1] = list(one[1].cpu().numpy())
        # one = one[1:]
        # out = out.view(1, -1).cpu().numpy().tolist()
        # one += out
        # result_record += [i for i in list(np.array(one, dtype=object).transpose(1, 0))]

        # Stack the metadata tensors plus the score into per-segment rows
        one2 = np.array([i.cpu().numpy() for i in (other_info + [out.squeeze()])])
        one2 = one2.transpose((1, 0))
        df = pd.DataFrame(data=one2, columns=columns)
        df_segment = df_segment.append(df, ignore_index=True)

    test_loss /= len(test_loader)
    calc_metrics.compute()
    logger.info(f"EXAM_NAME: {exam_name} SampNO: {sampNo}")
    logger.info(calc_metrics.get_matrix(loss=test_loss, epoch=0, epoch_type="test"))
    calc_metrics.reset()

    # Binarize labels/predictions at the module-level thresholds
    df_segment["thresh_label"] = 1 * (df_segment["label_type"] > event_thresh).copy()
    df_segment["thresh_Pred"] = 1 * (df_segment["pred"] > thresh).copy()
    df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))

    # Segment-level analysis
    df_segment_metrics = analysis_results(df_segment, segments_results_save_path, dataset_type)

    # Draw confusion matrices (one per sample)
    confusionMatrix(df_segment_metrics, segments_results_save_path, dataset_type)

    # Event-level analysis: for the inner test each segment number is one
    # event; for whole-night data runs of positives are merged instead
    df_all_event = segment_to_event(df_segment, dataset_type)
    df_event_metrics = analysis_results(df_all_event, events_results_save_path, dataset_type, is_event=True)
    confusionMatrix(df_event_metrics, events_results_save_path, dataset_type)

    # Repeat with poor-quality segments removed
    df_bad_segment = df_segment[
        (df_segment["label_type"].isin([2, 3])) & (df_segment["new_label"] == 2)]
    df_select_segment = df_segment.drop(df_bad_segment.index)
    df_select_segment_metrics = analysis_results(df_select_segment, segments_results_save_path / "remove_2",
                                                 dataset_type)
    df_select_event = segment_to_event(df_select_segment, dataset_type)
    df_event_metrics = analysis_results(df_select_event, events_results_save_path / "remove_2", dataset_type,
                                        is_event=True)
|
|||
|
|
|||
|
|
|||
|
def analysis_results(df_result, base_path, dataset_type, is_event=False):
    """Write per-sample result CSVs and build a confusion/derived-metric table.

    Row 0 of the returned DataFrame aggregates all samples; one row per
    sampNo follows, and for the inner "test" split one extra summary row per
    severity grade is appended.  The table is also written (stringified,
    sorted by severity, GBK-encoded) to an *_all_metrics.csv file.
    Returns None when *df_result* is empty.
    """
    if df_result.empty:
        # FIX: logger.info(msg, *args) %-formats msg; the old call passed
        # extra args with no placeholders, mangling the log record.
        logger.info("%s %s is_empty", base_path, dataset_type)
        return None

    (base_path / dataset_type).mkdir(exist_ok=True, parents=True)

    all_sampNo = df_result["sampNo"].unique()
    df_metrics = pd.DataFrame(columns=columns2)
    df_metrics.loc[0] = 0
    # NOTE(review): chained indexing (.loc[i][col] = ...) throughout this
    # function can write to a copy on newer pandas; it relies on legacy
    # row-view behaviour — confirm against the pinned pandas version.
    df_metrics.loc[0]["sampNo"] = dataset_type

    for index, sampNo in enumerate(all_sampNo):
        df = df_result[df_result["sampNo"] == sampNo]
        df.to_csv(
            base_path / dataset_type /
            f"{int(sampNo)}_{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_result.csv",
            index=False)

        df_metrics.loc[index + 1] = np.NAN
        df_metrics.loc[index + 1]["sampNo"] = str(int(sampNo))
        df_metrics.loc[index + 1]["support"] = df.shape[0]
        df_metrics.loc[index + 1]["severity"] = severity_label[str(int(sampNo))]

        # if dataset_type == "independence_test" or dataset_type == "train_all_test":
        #     continue
        # else:
        # Raw confusion counts for this sample
        df_metrics.loc[index + 1]["origin_P"] = df[df["thresh_label"] == 1].shape[0]
        df_metrics.loc[index + 1]["origin_N"] = df[df["thresh_label"] == 0].shape[0]
        df_metrics.loc[index + 1]["pred_P"] = df[df["thresh_Pred"] == 1].shape[0]
        df_metrics.loc[index + 1]["pred_N"] = df[df["thresh_Pred"] == 0].shape[0]
        df_metrics.loc[index + 1]["T"] = df[df["thresh_Pred"] == df["thresh_label"]].shape[0]
        df_metrics.loc[index + 1]["F"] = df[df["thresh_Pred"] != df["thresh_label"]].shape[0]
        df_metrics.loc[index + 1]["TP"] = \
            df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[index + 1]["FP"] = \
            df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[index + 1]["TN"] = \
            df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]
        df_metrics.loc[index + 1]["FN"] = \
            df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]

        # Accumulate the all-sample totals in row 0 (before NaN substitution)
        df_metrics.loc[0]["origin_P"] += df_metrics.loc[index + 1]["origin_P"]
        df_metrics.loc[0]["origin_N"] += df_metrics.loc[index + 1]["origin_N"]
        df_metrics.loc[0]["pred_P"] += df_metrics.loc[index + 1]["pred_P"]
        df_metrics.loc[0]["pred_N"] += df_metrics.loc[index + 1]["pred_N"]
        df_metrics.loc[0]["T"] += df_metrics.loc[index + 1]["T"]
        df_metrics.loc[0]["F"] += df_metrics.loc[index + 1]["F"]
        df_metrics.loc[0]["TP"] += df_metrics.loc[index + 1]["TP"]
        df_metrics.loc[0]["FP"] += df_metrics.loc[index + 1]["FP"]
        df_metrics.loc[0]["TN"] += df_metrics.loc[index + 1]["TN"]
        df_metrics.loc[0]["FN"] += df_metrics.loc[index + 1]["FN"]
        df_metrics.loc[0]["support"] += df_metrics.loc[index + 1]["support"]

        # Zero counts become NaN so the ratios below yield NaN, not ZeroDivisionError
        for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]:
            df_metrics.loc[index + 1][col] = df_metrics.loc[index + 1][col] if df_metrics.loc[index + 1][
                col] != 0 else np.NAN

        df_metrics.loc[index + 1]["acc"] = df_metrics.iloc[index + 1]["T"] / df_metrics.iloc[index + 1]["support"]
        df_metrics.loc[index + 1]["recall"] = df_metrics.iloc[index + 1]["TP"] / df_metrics.iloc[index + 1]["origin_P"]
        df_metrics.loc[index + 1]["spec"] = df_metrics.iloc[index + 1]["TN"] / df_metrics.iloc[index + 1]["origin_N"]
        df_metrics.loc[index + 1]["pre"] = df_metrics.iloc[index + 1]["TP"] / df_metrics.iloc[index + 1]["pred_P"]
        df_metrics.loc[index + 1]["NPV"] = df_metrics.iloc[index + 1]["TN"] / df_metrics.iloc[index + 1]["pred_N"]
        df_metrics.loc[index + 1]["F1score"] = 2 * df_metrics.iloc[index + 1]["recall"] * df_metrics.iloc[index + 1][
            "pre"] / (df_metrics.iloc[index + 1]["recall"] + df_metrics.iloc[index + 1]["pre"])
        # NaNs back to 0, then round for reporting
        for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN", "acc", "recall",
                    "spec", "pre", "NPV", "F1score"]:
            df_metrics.loc[index + 1][col] = 0 if pd.isna(df_metrics.loc[index + 1][col]) else \
                df_metrics.loc[index + 1][col]
            df_metrics.loc[index + 1][col] = round(df_metrics.loc[index + 1][col], 3)

    # if dataset_type == "independence_test" or dataset_type == "train_all_test":
    #     return None
    # Same NaN-guarded ratios for the all-sample row 0
    for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]:
        df_metrics.loc[0][col] = df_metrics.loc[0][col] if df_metrics.loc[0][col] != 0 else np.NAN

    df_metrics.loc[0]["acc"] = df_metrics.iloc[0]["T"] / df_metrics.iloc[0]["support"]
    df_metrics.loc[0]["recall"] = df_metrics.iloc[0]["TP"] / df_metrics.iloc[0]["origin_P"]
    df_metrics.loc[0]["spec"] = df_metrics.iloc[0]["TN"] / df_metrics.iloc[0]["origin_N"]
    df_metrics.loc[0]["pre"] = df_metrics.iloc[0]["TP"] / df_metrics.iloc[0]["pred_P"]
    df_metrics.loc[0]["NPV"] = df_metrics.iloc[0]["TN"] / df_metrics.iloc[0]["pred_N"]
    df_metrics.loc[0]["F1score"] = 2 * df_metrics.iloc[0]["recall"] * df_metrics.iloc[0]["pre"] / (
            df_metrics.iloc[0]["recall"] + df_metrics.iloc[0]["pre"])
    for col in ["TP", "TN", "FP", "FN", "acc", "recall", "spec", "pre", "NPV", "F1score"]:
        df_metrics.loc[0][col] = 0 if pd.isna(df_metrics.loc[0][col]) else df_metrics.loc[0][col]
        df_metrics.loc[0][col] = round(df_metrics.loc[0][col], 3)

    # For the inner test split, append one summary row per severity grade
    # (labels are the Chinese grades: normal / mild / moderate / severe)
    if dataset_type == "test":
        all_severity = ["正常", "轻度", "中度", "重度"]
        for index, severity in enumerate(all_severity):
            df_event = df_metrics[df_metrics["severity"] == severity]
            df_temp = pd.DataFrame(columns=columns2)
            df_temp.loc[0] = 0
            df_temp.loc[0]["sampNo"] = severity
            df_temp.loc[0]["severity"] = str(index + 1)

            df_temp.loc[0]["origin_P"] += df_event["origin_P"].sum()
            df_temp.loc[0]["origin_N"] += df_event["origin_N"].sum()
            df_temp.loc[0]["pred_P"] += df_event["pred_P"].sum()
            df_temp.loc[0]["pred_N"] += df_event["pred_N"].sum()
            df_temp.loc[0]["T"] += df_event["T"].sum()
            df_temp.loc[0]["F"] += df_event["F"].sum()
            df_temp.loc[0]["TP"] += df_event["TP"].sum()
            df_temp.loc[0]["FP"] += df_event["FP"].sum()
            df_temp.loc[0]["TN"] += df_event["TN"].sum()
            df_temp.loc[0]["FN"] += df_event["FN"].sum()
            df_temp.loc[0]["support"] += df_event["support"].sum()

            for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]:
                df_temp.loc[0][col] = df_temp.loc[0][col] if df_temp.loc[0][col] != 0 else np.NAN

            df_temp.loc[0]["acc"] = df_temp.iloc[0]["T"] / df_temp.iloc[0]["support"]
            df_temp.loc[0]["recall"] = df_temp.iloc[0]["TP"] / df_temp.iloc[0]["origin_P"]
            df_temp.loc[0]["spec"] = df_temp.iloc[0]["TN"] / df_temp.iloc[0]["origin_N"]
            df_temp.loc[0]["pre"] = df_temp.iloc[0]["TP"] / df_temp.iloc[0]["pred_P"]
            df_temp.loc[0]["NPV"] = df_temp.iloc[0]["TN"] / df_temp.iloc[0]["pred_N"]
            df_temp.loc[0]["F1score"] = 2 * df_temp.iloc[0]["recall"] * df_temp.iloc[0]["pre"] / (
                    df_temp.iloc[0]["recall"] + df_temp.iloc[0]["pre"])

            for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN", "acc", "recall",
                        "spec", "pre", "NPV", "F1score"]:
                df_temp.loc[0][col] = 0 if pd.isna(df_temp.loc[0][col]) else df_temp.loc[0][col]
                df_temp.loc[0][col] = round(df_temp.loc[0][col], 3)

            df_metrics = df_metrics.append(df_temp, ignore_index=True)

    # Return the numeric frame; the CSV gets a stringified, severity-sorted copy
    df_backup = df_metrics
    df_metrics = df_metrics.astype("str")
    df_metrics = df_metrics.sort_values("severity")
    df_metrics.to_csv(base_path / dataset_type /
                      f"{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_all_metrics.csv",
                      index=False, encoding="gbk")

    return df_backup
|
|||
|
|
|||
|
|
|||
|
def confusionMatrix(df_analysis, base_path, dataset_type):
    """Draw one normal-vs-SA confusion-matrix image per row of *df_analysis*.

    Each row must carry TN/FP/FN/TP counts plus "severity" and "sampNo";
    images are written to <base_path>/<dataset_type>/confusionMatrix/.
    Logs and returns silently when *df_analysis* is None or empty.
    """
    if df_analysis is None:
        # FIX: logger.info(msg, *args) %-formats msg; the old calls passed
        # extra args with no placeholders, mangling the log record.
        logger.info("%s %s is None", base_path, dataset_type)
        return

    if df_analysis.empty:
        logger.info("%s %s is_empty", base_path, dataset_type)
        return
    classes = ["normal", "SA"]
    (base_path / dataset_type / "confusionMatrix").mkdir(exist_ok=True, parents=True)
    for one_samp in df_analysis.index:
        one_samp = df_analysis.loc[one_samp]
        # Layout: [[TN, FP], [FN, TP]] as expected by draw_confusionMatrix
        cm = np.array([[one_samp["TN"], one_samp["FP"]], [one_samp["FN"], one_samp["TP"]]])
        draw_confusionMatrix(cm, classes=classes, title=str(one_samp["severity"]) + " " + one_samp["sampNo"],
                             save_path=base_path / dataset_type / "confusionMatrix" / f"{one_samp['sampNo']}.jpg")
|
|||
|
|
|||
|
|
|||
|
def segment_to_event(df_segment, dataset_type):
    """Aggregate per-segment predictions into per-event rows.

    For the inner "test" split each segmentNo is one candidate event and its
    prediction is set to 1 when any positive run within it is long enough.
    For other (whole-night) splits, runs of consecutive positive segments are
    run-length encoded and each sufficiently long run becomes one event row.
    Returns a DataFrame in the module-level ``columns`` layout, plus
    thresholded label/prediction columns.
    """
    df_all_event = pd.DataFrame(columns=columns)
    all_sampNo = df_segment["sampNo"].unique()

    if dataset_type == "test":
        for index, sampNo in enumerate(all_sampNo):
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)
            all_segments_no = df["segmentNo"].unique()

            for index_se, segment_No in enumerate(all_segments_no):
                df_temp = df[df["segmentNo"] == segment_No].copy()
                SP = df_temp.iloc[0]["EP"]
                EP = df_temp.iloc[-1]["EP"] + 1
                # One event row per segment number, prediction initialised to 0
                df_event.loc[index_se] = [int(sampNo), segment_No, df_temp.iloc[0]["label_type"],
                                          df_temp.iloc[0]["new_label"], SP, EP, 0]

                thresh_Pred = df_temp["thresh_Pred"].values
                thresh_Pred2 = thresh_Pred.copy()

                # Dilate positives forward to bridge short gaps
                # (no-op while thresh_event_interval <= 1)
                for index_pred, pred in enumerate(thresh_Pred):
                    if pred == 0:
                        continue

                    for interval in range(1, thresh_event_interval):
                        if pred == 1 and index_pred + interval < thresh_Pred.size:
                            thresh_Pred2[index_pred + interval] = 1
                        else:
                            continue

                # Run-length encode the dilated mask; mark this event positive
                # if any positive run exceeds the minimum event length
                same_ar = np.concatenate(([True], thresh_Pred2[:-1] != thresh_Pred2[1:], [True]))
                index_ar = np.where(same_ar)[0]
                count_ar = np.diff(index_ar)
                value_ar = thresh_Pred2[same_ar[:-1]] * count_ar
                for i in value_ar:
                    if i > thresh_event_length:
                        # NOTE(review): chained indexing (.iloc[i][col] = ...)
                        # may write to a copy on newer pandas — confirm this
                        # assignment actually lands on the pinned version.
                        df_event.iloc[index_se]["pred"] = 1

            # df_event.to_csv(events_results / dataset_type / f"{int(sampNo)}_event_results.csv", index=False,
            #                 encoding="gbk")
            df_all_event = df_all_event.append(df_event, ignore_index=True)
    else:
        for index, sampNo in enumerate(all_sampNo):
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)
            thresh_Pred = df["thresh_Pred"].values
            thresh_Pred2 = thresh_Pred.copy()
            # Dilate positives forward (no-op while thresh_event_interval <= 1)
            for index_pred, pred in enumerate(thresh_Pred):
                if pred == 0:
                    continue

                for interval in range(1, thresh_event_interval):
                    if pred == 1 and index_pred + interval < thresh_Pred.size:
                        thresh_Pred2[index_pred + interval] = 1
                    else:
                        continue

            # Run-length encode the whole night's prediction mask
            same_ar = np.concatenate(([True], thresh_Pred2[:-1] != thresh_Pred2[1:], [True]))
            index_ar = np.where(same_ar)[0]
            count_ar = np.diff(index_ar)
            value_ar = thresh_Pred2[same_ar[:-1]] * count_ar

            for value_index, value in enumerate(value_ar):
                SP = index_ar[value_index]
                EP = index_ar[value_index] + count_ar[value_index]
                # TP, FP: keep runs that are long enough as event rows
                if value > thresh_event_length:

                    # label_type = 1 if thresh_Pred2[SP:EP].sum() > 0 else 0
                    label_type = df["label_type"][SP:EP].max()
                    new_label = df["new_label"][SP:EP].max()
                    df_event = df_event.append(pd.DataFrame([[int(sampNo), SP // 30, label_type, new_label,
                                                              SP, EP, thresh_Pred2[SP]]], columns=columns),
                                               ignore_index=True)
                # if value > 30:
                #     logger.info([int(sampNo), SP // 30, label_type, new_label, SP, EP, thresh_Pred2[SP]])
                # Run too short: clear it from the prediction mask
                else:
                    df["thresh_Pred"][SP:EP] = 0

            # Negative-sample bookkeeping (disabled)
            # for segment_no in df["segmentNo"].unique():
            #     df_temp = df[df["segmentNo"] == segment_no]
            #     if df_temp["thresh_Pred"].sum() > 0:
            #         continue
            #
            #     df_event = df_event.append(pd.DataFrame(
            #         [[int(sampNo), segment_no, df_temp["label_type"].max(), df_temp["new_label"].max(), segment_no * 30,
            #           (segment_no + 1) * 30, 0]], columns=columns),
            #         ignore_index=True)

            df_all_event = df_all_event.append(df_event, ignore_index=True)

    # Threshold the aggregated event labels/predictions for downstream metrics
    df_temp = df_all_event.loc[:, ["label_type", "pred"]]
    df_all_event["thresh_label"] = 1 * (df_temp["label_type"] > event_thresh)
    df_all_event["thresh_Pred"] = 1 * (df_temp["pred"] > thresh)
    return df_all_event
|
|||
|
|
|||
|
|
|||
|
# Save results per sampNo and visualize them without overlap
# inner_test

# Save the samples that disagree with the labels separately per sampNo,
# and visualize them without overlap

# import shap
# explainer = shap.TreeExplainer()
# shap_values = explainer.shap_values()
|
|||
|
if __name__ == '__main__':
    # Evaluate every KFold_* experiment folder in turn: load its weights,
    # then run both the inner-test and the full-test analysis passes.
    all_output_path = list(exam_path.rglob("KFold_*"))
    for fold_idx in range(len(all_output_path)):
        set_environment(fold_idx)
        for ds_type in ("test", "all_test"):
            test_and_analysis_and_visual(dataset_type=ds_type)
|