478 lines
21 KiB
Python
478 lines
21 KiB
Python
#!/usr/bin/python
|
||
# -*- coding: UTF-8 -*-
|
||
"""
|
||
@author:Marques
|
||
@file:test_analysis.py
|
||
@email:admin@marques22.com
|
||
@email:2021022362@m.scnu.edu.cn
|
||
@time:2022/02/21
|
||
"""
|
||
import logging
|
||
import os
|
||
import sys
|
||
|
||
import pandas as pd
|
||
import torch.cuda
|
||
import numpy as np
|
||
import yaml
|
||
from matplotlib import pyplot as plt
|
||
from tqdm import tqdm
|
||
from pathlib import Path
|
||
from torch.nn import functional as F
|
||
from torch.utils.data import DataLoader
|
||
from load_dataset import TestApneaDataset2, read_dataset
|
||
from utils.Draw_ConfusionMatrix import draw_confusionMatrix
|
||
from torch import nn
|
||
from utils.calc_metrics import CALC_METRICS
|
||
from my_augment import my_augment, my_segment_augment
|
||
from model.Hybrid_Net018 import HYBRIDNET018
|
||
|
||
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei so Chinese plot labels render correctly

exam_path = Path("./output/")

# Confidence threshold: a segment counts as predicted positive when pred > thresh.
thresh = 0.5
# Minimum gap (in segments) bridged between adjacent positive predictions.
thresh_event_interval = 0
# Minimum run length (in segments) for a positive run to count as an event.
thresh_event_length = 2

# Label values strictly greater than this are treated as positive (apnea).
event_thresh = 1

# Per-sample severity lookup, keyed by sample number (column "数据编号" -> "程度").
severity_path = Path(r"/home/marques/code/marques/apnea/dataset/loc_first_csa.xlsx")
severity_label = {"all": "none"}
severity_df = pd.read_excel(severity_path)
for one_data in severity_df.index:
    one_data = severity_df.loc[one_data]
    severity_label[str(one_data["数据编号"])] = one_data["程度"]

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
gpu = torch.cuda.is_available()

num_classes = 1
calc_metrics = CALC_METRICS(num_classes)

with open("./settings.yaml") as f:
    hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
data_path = hyp["Path"]["dataset"]
read_dataset(data_path, augment=my_augment)
del hyp

# Globals filled in by set_environment(); the newest KFold folder is used by default.
# NOTE: the original also rebound the imported `my_augment` to None here, which
# shadowed the import for no benefit (nothing reads it afterwards) — removed.
all_output_path, output_path, segments_results_save_path, events_results_save_path, = [None, ] * 4
model_path, label_path, data_path, model, model_name = [None, ] * 5
train_set, test_set = None, None
loss_func = nn.CrossEntropyLoss()

columns = ["sampNo", "segmentNo", "label_type", "new_label", "SP", "EP", "pred"]
columns2 = ["sampNo", "severity", "origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN",
            "acc", "recall", "spec", "pre", "NPV", "F1score", "support"]

# Silence noisy matplotlib sub-loggers, then log to both console and a fresh file.
logging.getLogger('matplotlib.font_manager').disabled = True
logging.getLogger('matplotlib.ticker').disabled = True
logger = logging.getLogger()
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
logger.addHandler(ch)

if (exam_path / "test.log").exists():
    (exam_path / "test.log").unlink()
fh = logging.FileHandler(exam_path / "test.log", mode='a')
fh.setLevel(logging.INFO)
fh.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(fh)
logger.info("------------------------------------")
||
def set_environment(i):
    """Load the i-th KFold output directory: result dirs, saved config, and model.

    Side effects: populates the module-level globals consumed by
    test_and_analysis_and_visual() and leaves `model` on the GPU (when
    available) in eval mode.
    """
    global output_path, segments_results_save_path, events_results_save_path, model_path, label_path, data_path, \
        model, model_name, train_set, test_set

    output_path = all_output_path[i]
    logger.info(output_path)
    segments_results_save_path = (output_path / "segments_results")
    segments_results_save_path.mkdir(exist_ok=True)
    events_results_save_path = (output_path / "events_results")
    events_results_save_path.mkdir(exist_ok=True)

    # Load the training-time configuration saved alongside the weights.
    with open(output_path / "settings.yaml") as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
    data_path = hyp["Path"]["dataset"]
    label_path = hyp["Path"]["label"]
    train_set = hyp["train_set"]
    test_set = hyp["test_set"]

    model_path = output_path / "weights" / "best.pt"
    model_name = hyp["model_name"]
    # Resolve the class by name instead of eval() so arbitrary config strings
    # cannot execute code; the class must already be imported at module level.
    model = globals()[model_name]()
    # map_location lets the checkpoint load on CPU-only machines too.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    if gpu:
        model.cuda()
    model.eval()
||
def test_and_analysis_and_visual(dataset_type):
    """Run inference for one dataset split, then segment- and event-level analysis.

    dataset_type: "test" evaluates `train_set` samples, "all_test" the
    `test_set` samples; anything else logs an error and proceeds with
    sampNo=None (original behavior preserved).
    """
    if dataset_type == "test":
        sampNo = train_set
    elif dataset_type == "all_test":
        sampNo = test_set
    else:
        sampNo = None
        logger.info("出错了")

    exam_name = Path("./").absolute().name

    test_dataset = TestApneaDataset2(data_path, label_path, select_sampno=sampNo, dataset_type=dataset_type,
                                     segment_augment=my_segment_augment)
    test_loader = DataLoader(test_dataset, batch_size=128, pin_memory=True, num_workers=0)

    test_loss = 0.0
    df_segment = pd.DataFrame(columns=columns)

    for one in tqdm(test_loader, total=len(test_loader)):
        resp, labels = one[:2]
        other_info = one[2:]
        resp = resp.float().cuda() if gpu else resp.float()
        labels = labels.cuda() if gpu else labels
        with torch.no_grad():
            out = model(resp)

        loss = loss_func(out, labels)
        test_loss += loss.item()

        labels = torch.unsqueeze(labels, dim=1)
        out = F.softmax(out, dim=1)
        # Keep only the positive-class probability as the prediction score.
        out = torch.unsqueeze(out[:, 1], dim=1)

        calc_metrics.update(out.cpu(), labels.cpu())

        # Build (sampNo, segmentNo, label_type, new_label, SP, EP, pred) rows.
        one2 = np.array([i.cpu().numpy() for i in (other_info + [out.squeeze()])])
        one2 = one2.transpose((1, 0))
        df = pd.DataFrame(data=one2, columns=columns)
        # DataFrame.append was removed in pandas 2.0; concat is the supported API.
        df_segment = pd.concat([df_segment, df], ignore_index=True)

    test_loss /= len(test_loader)
    calc_metrics.compute()
    logger.info(f"EXAM_NAME: {exam_name} SampNO: {sampNo}")
    logger.info(calc_metrics.get_matrix(loss=test_loss, epoch=0, epoch_type="test"))
    calc_metrics.reset()

    df_segment["thresh_label"] = 1 * (df_segment["label_type"] > event_thresh).copy()
    df_segment["thresh_Pred"] = 1 * (df_segment["pred"] > thresh).copy()
    df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))

    # Segment-level analysis.
    df_segment_metrics = analysis_results(df_segment, segments_results_save_path, dataset_type)

    # Confusion matrix, one image per sample.
    confusionMatrix(df_segment_metrics, segments_results_save_path, dataset_type)

    # Event-level analysis: for the inner test every segment number is one
    # event; for whole-night independence tests events are rebuilt from runs
    # of consecutive positive segments.
    df_all_event = segment_to_event(df_segment, dataset_type)
    df_event_metrics = analysis_results(df_all_event, events_results_save_path, dataset_type, is_event=True)
    confusionMatrix(df_event_metrics, events_results_save_path, dataset_type)

    # Repeat after dropping poor-quality segments (label_type 2/3 with new_label == 2).
    df_bad_segment = df_segment[
        (df_segment["label_type"].isin([2, 3])) & (df_segment["new_label"] == 2)]
    df_select_segment = df_segment.drop(df_bad_segment.index)
    analysis_results(df_select_segment, segments_results_save_path / "remove_2", dataset_type)
    df_select_event = segment_to_event(df_select_segment, dataset_type)
    analysis_results(df_select_event, events_results_save_path / "remove_2", dataset_type, is_event=True)
||
_COUNT_COLS = ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]
_RATIO_COLS = ["acc", "recall", "spec", "pre", "NPV", "F1score"]


def _compute_row_metrics(df_metrics, row, final_cols):
    """In place: NaN-out zero counts, derive ratio metrics, finalize `final_cols`.

    Zero counts become NaN so the divisions below propagate NaN instead of
    raising; NaNs in `final_cols` are then mapped back to 0 and rounded to 3
    decimals, matching the original bookkeeping.
    """
    for col in _COUNT_COLS:
        if df_metrics.loc[row, col] == 0:
            df_metrics.loc[row, col] = np.nan

    df_metrics.loc[row, "acc"] = df_metrics.loc[row, "T"] / df_metrics.loc[row, "support"]
    df_metrics.loc[row, "recall"] = df_metrics.loc[row, "TP"] / df_metrics.loc[row, "origin_P"]
    df_metrics.loc[row, "spec"] = df_metrics.loc[row, "TN"] / df_metrics.loc[row, "origin_N"]
    df_metrics.loc[row, "pre"] = df_metrics.loc[row, "TP"] / df_metrics.loc[row, "pred_P"]
    df_metrics.loc[row, "NPV"] = df_metrics.loc[row, "TN"] / df_metrics.loc[row, "pred_N"]
    df_metrics.loc[row, "F1score"] = 2 * df_metrics.loc[row, "recall"] * df_metrics.loc[row, "pre"] / (
            df_metrics.loc[row, "recall"] + df_metrics.loc[row, "pre"])

    for col in final_cols:
        if pd.isna(df_metrics.loc[row, col]):
            df_metrics.loc[row, col] = 0
        df_metrics.loc[row, col] = round(df_metrics.loc[row, col], 3)


def analysis_results(df_result, base_path, dataset_type, is_event=False):
    """Write per-sample result CSVs plus a metrics summary CSV.

    Row 0 of the summary accumulates totals over all samples; one row follows
    per sample, and for dataset_type == "test" extra per-severity rows are
    appended. Returns the (un-stringified) metrics DataFrame, or None when
    df_result is empty.

    Note: uses `.loc[row, col]` scalar assignment throughout — the original
    chained `.loc[row][col] = x` writes into a temporary Series and is
    silently dropped on modern pandas.
    """
    if df_result.empty:
        logger.info(f"{base_path} {dataset_type} is_empty")
        return None

    (base_path / dataset_type).mkdir(exist_ok=True, parents=True)

    all_sampNo = df_result["sampNo"].unique()
    df_metrics = pd.DataFrame(columns=columns2)
    df_metrics.loc[0] = 0  # totals row
    df_metrics.loc[0, "sampNo"] = dataset_type

    for index, sampNo in enumerate(all_sampNo):
        row = index + 1
        df = df_result[df_result["sampNo"] == sampNo]
        df.to_csv(
            base_path / dataset_type /
            f"{int(sampNo)}_{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_result.csv",
            index=False)

        df_metrics.loc[row] = np.nan
        df_metrics.loc[row, "sampNo"] = str(int(sampNo))
        df_metrics.loc[row, "support"] = df.shape[0]
        df_metrics.loc[row, "severity"] = severity_label[str(int(sampNo))]

        df_metrics.loc[row, "origin_P"] = df[df["thresh_label"] == 1].shape[0]
        df_metrics.loc[row, "origin_N"] = df[df["thresh_label"] == 0].shape[0]
        df_metrics.loc[row, "pred_P"] = df[df["thresh_Pred"] == 1].shape[0]
        df_metrics.loc[row, "pred_N"] = df[df["thresh_Pred"] == 0].shape[0]
        df_metrics.loc[row, "T"] = df[df["thresh_Pred"] == df["thresh_label"]].shape[0]
        df_metrics.loc[row, "F"] = df[df["thresh_Pred"] != df["thresh_label"]].shape[0]
        df_metrics.loc[row, "TP"] = df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[row, "FP"] = df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[row, "TN"] = df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]
        df_metrics.loc[row, "FN"] = df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]

        # Accumulate this sample's counts into the totals row BEFORE the
        # zero->NaN conversion inside _compute_row_metrics rewrites them.
        for col in _COUNT_COLS + ["support"]:
            df_metrics.loc[0, col] += df_metrics.loc[row, col]

        _compute_row_metrics(df_metrics, row, _COUNT_COLS + _RATIO_COLS)

    # Totals row: original finalized only TP/TN/FP/FN plus the ratio columns.
    _compute_row_metrics(df_metrics, 0, ["TP", "TN", "FP", "FN"] + _RATIO_COLS)

    # For the inner test, additionally aggregate metrics per severity grade.
    if dataset_type == "test":
        all_severity = ["正常", "轻度", "中度", "重度"]
        for index, severity in enumerate(all_severity):
            df_event = df_metrics[df_metrics["severity"] == severity]
            df_temp = pd.DataFrame(columns=columns2)
            df_temp.loc[0] = 0
            df_temp.loc[0, "sampNo"] = severity
            df_temp.loc[0, "severity"] = str(index + 1)

            # .sum() skips the NaNs left by the per-row zero->NaN conversion.
            for col in _COUNT_COLS + ["support"]:
                df_temp.loc[0, col] += df_event[col].sum()

            _compute_row_metrics(df_temp, 0, _COUNT_COLS + _RATIO_COLS)

            df_metrics = pd.concat([df_metrics, df_temp], ignore_index=True)

    df_backup = df_metrics
    df_metrics = df_metrics.astype("str")
    df_metrics = df_metrics.sort_values("severity")
    df_metrics.to_csv(base_path / dataset_type /
                      f"{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_all_metrics.csv",
                      index=False, encoding="gbk")

    return df_backup
|
||
def confusionMatrix(df_analysis, base_path, dataset_type):
    """Draw and save a 2x2 confusion-matrix image for every row of df_analysis.

    df_analysis: metrics frame from analysis_results(); may be None or empty,
    in which case only a log line is emitted.
    """
    if df_analysis is None:
        # logger.info %-formats only its first argument; the original passed
        # extra positional args with no placeholders, which logging rejects.
        logger.info(f"{base_path} {dataset_type} is None")
        return

    if df_analysis.empty:
        logger.info(f"{base_path} {dataset_type} is_empty")
        return
    classes = ["normal", "SA"]
    (base_path / dataset_type / "confusionMatrix").mkdir(exist_ok=True, parents=True)
    for one_samp in df_analysis.index:
        one_samp = df_analysis.loc[one_samp]
        # Layout: rows = actual (N, P), columns = predicted (N, P).
        cm = np.array([[one_samp["TN"], one_samp["FP"]], [one_samp["FN"], one_samp["TP"]]])
        draw_confusionMatrix(cm, classes=classes, title=str(one_samp["severity"]) + " " + one_samp["sampNo"],
                             save_path=base_path / dataset_type / "confusionMatrix" / f"{one_samp['sampNo']}.jpg")
|
||
def _dilate_predictions(thresh_pred):
    """Return a copy of the 0/1 prediction array with each positive extended
    forward by up to thresh_event_interval - 1 positions, bridging short gaps.
    (With thresh_event_interval <= 1 this is a plain copy.)"""
    dilated = thresh_pred.copy()
    for pos, pred in enumerate(thresh_pred):
        if pred == 0:
            continue
        for interval in range(1, thresh_event_interval):
            if pos + interval < thresh_pred.size:
                dilated[pos + interval] = 1
    return dilated


def _run_lengths(mask):
    """Run-length encode a 0/1 array.

    Returns (starts, lengths, values) where `values[i]` is the run's value
    times its length — i.e. 0 for negative runs, the run length for positive
    runs.
    """
    boundaries = np.concatenate(([True], mask[:-1] != mask[1:], [True]))
    starts = np.where(boundaries)[0]
    lengths = np.diff(starts)
    values = mask[boundaries[:-1]] * lengths
    return starts, lengths, values


def segment_to_event(df_segment, dataset_type):
    """Collapse per-segment predictions into per-event rows.

    For dataset_type == "test" each segmentNo is one candidate event; otherwise
    (whole-night data) events are rebuilt from runs of consecutive positive
    segments. Returns a frame in the `columns` layout plus thresholded
    label/pred columns.
    """
    df_all_event = pd.DataFrame(columns=columns)
    all_sampNo = df_segment["sampNo"].unique()

    if dataset_type == "test":
        for sampNo in all_sampNo:
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)

            for index_se, segment_No in enumerate(df["segmentNo"].unique()):
                df_temp = df[df["segmentNo"] == segment_No].copy()
                # NOTE(review): SP is taken from the first row's EP column —
                # presumably EP acts as a per-row offset here; confirm upstream.
                SP = df_temp.iloc[0]["EP"]
                EP = df_temp.iloc[-1]["EP"] + 1
                df_event.loc[index_se] = [int(sampNo), segment_No, df_temp.iloc[0]["label_type"],
                                          df_temp.iloc[0]["new_label"], SP, EP, 0]

                dilated = _dilate_predictions(df_temp["thresh_Pred"].values)
                _, _, run_values = _run_lengths(dilated)
                # The event is predicted positive if any positive run exceeds
                # thresh_event_length. Scalar .loc assignment replaces the
                # original chained .iloc[...][...] write, which modern pandas
                # drops silently.
                for run_value in run_values:
                    if run_value > thresh_event_length:
                        df_event.loc[index_se, "pred"] = 1

            df_all_event = pd.concat([df_all_event, df_event], ignore_index=True)
    else:
        for sampNo in all_sampNo:
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)

            dilated = _dilate_predictions(df["thresh_Pred"].values)
            starts, lengths, run_values = _run_lengths(dilated)

            for run_index, run_value in enumerate(run_values):
                SP = starts[run_index]
                EP = starts[run_index] + lengths[run_index]
                # Positive run long enough -> record one event row.
                if run_value > thresh_event_length:
                    # SP/EP are positions within this sample, so slice
                    # positionally (.iloc), not by the frame's integer labels.
                    label_type = df["label_type"].iloc[SP:EP].max()
                    new_label = df["new_label"].iloc[SP:EP].max()
                    df_event = pd.concat(
                        [df_event,
                         pd.DataFrame([[int(sampNo), SP // 30, label_type, new_label,
                                        SP, EP, dilated[SP]]], columns=columns)],
                        ignore_index=True)
                # Run too short to be an event: clear those predictions.
                else:
                    df.iloc[SP:EP, df.columns.get_loc("thresh_Pred")] = 0

            df_all_event = pd.concat([df_all_event, df_event], ignore_index=True)

    df_temp = df_all_event.loc[:, ["label_type", "pred"]]
    df_all_event["thresh_label"] = 1 * (df_temp["label_type"] > event_thresh)
    df_all_event["thresh_Pred"] = 1 * (df_temp["pred"] > thresh)
    return df_all_event
||
|
||
# Save results per sampNo and visualize them without overlap
# inner_test

# Per sampNo, save the rows that disagree with the labels separately and visualize them without overlap

# import shap
# explainer = shap.TreeExplainer()
# shap_values = explainer.shap_values()
||
|
||
if __name__ == '__main__':
    # Evaluate every KFold_* experiment directory found under the output path.
    all_output_path = list(exam_path.rglob("KFold_*"))
    for exam_index in range(len(all_output_path)):
        set_environment(exam_index)
        for split in ("test", "all_test"):
            test_and_analysis_and_visual(dataset_type=split)