#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:test_analysis.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/02/21
"""

import os
import sys

import pandas as pd
import torch.cuda
import numpy as np
import yaml
from matplotlib import pyplot as plt
from tqdm import tqdm
from pathlib import Path
from torch.nn import functional as F
from torch.utils.data import DataLoader
from load_dataset import TestApneaDataset2, read_dataset
from utils.Draw_ConfusionMatrix import draw_confusionMatrix
from torch import nn
from utils.calc_metrics import CALC_METRICS
from my_augment import my_augment, my_segment_augment
from model.Hybrid_Net001 import HYBRIDNET001

plt.rcParams['font.sans-serif'] = ['SimHei']  # use the SimHei font so Chinese labels render correctly
exam_path = Path("./output/")

# Confidence (probability) threshold for a positive prediction
thresh = 0.5
# Minimum gap between positive predictions; shorter gaps are bridged when merging into events
thresh_event_interval = 2
# Minimum event length (in prediction windows)
thresh_event_length = 8

# Label-type threshold: label_type > event_thresh is treated as an apnea segment
event_thresh = 1

severity_path = Path(r"/home/marques/code/marques/apnea/dataset/loc_first_csa.xlsx")
severity_label = {"all": "none"}
severity_df = pd.read_excel(severity_path)
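# Map each sample ID (数据编号 column) to its severity label (程度 column)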
for one_data in severity_df.index:
    one_data = severity_df.loc[one_data]
    severity_label[str(one_data["数据编号"])] = one_data["程度"]

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
gpu = torch.cuda.is_available()

num_classes = 1
calc_metrics = CALC_METRICS(num_classes)

with open("./settings.yaml") as f:
    hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
data_path = hyp["Path"]["dataset"]
read_dataset(data_path, augment=my_augment)
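# read_dataset presumably pre-loads/caches the recordings from data_path with the training-time
# augmentation applied (its return value is not used here)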
del hyp

# Per-run globals, filled in by set_environment() for each output folder
all_output_path, output_path, segments_results_save_path, events_results_save_path = [None, ] * 4
model_path, label_path, data_path, model, model_name = [None, ] * 5
train_set, test_set = None, None
loss_func = nn.CrossEntropyLoss()
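# The model appears to output two logits per segment (see the softmax over dim=1 below), so
# CrossEntropyLoss takes integer class labels and the class-1 probability is used as the apnea score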

columns = ["sampNo", "segmentNo", "label_type", "new_label", "SP", "EP", "pred"]
columns2 = ["sampNo", "severity", "origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN",
            "acc", "recall", "spec", "pre", "NPV", "F1score", "support"]


def set_environment(i):
    global output_path, segments_results_save_path, events_results_save_path, model_path, label_path, data_path, \
        model, model_name, train_set, test_set

    output_path = all_output_path[i]
    print(output_path)
    segments_results_save_path = (output_path / "segments_results")
    segments_results_save_path.mkdir(exist_ok=True)
    events_results_save_path = (output_path / "events_results")
    events_results_save_path.mkdir(exist_ok=True)

    # Load the run configuration
    with open(output_path / "settings.yaml") as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
    data_path = hyp["Path"]["dataset"]
    label_path = hyp["Path"]["label"]
    train_set = hyp["train_set"]
    test_set = hyp["test_set"]

    model_path = output_path / "weights" / "best.pt"
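    # Instantiate the model class named in settings.yaml (resolved via eval, e.g. HYBRIDNET001) and load its best checkpoint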
    model = eval(hyp["model_name"])()
    model_name = hyp["model_name"]
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()


def test_and_analysis_and_visual(dataset_type):
    if dataset_type == "test":
        sampNo = train_set
    elif dataset_type == "all_test":
        sampNo = test_set
    test_dataset = TestApneaDataset2(data_path, label_path, select_sampno=sampNo, dataset_type=dataset_type,
                                     segment_augment=my_segment_augment)
    test_loader = DataLoader(test_dataset, batch_size=128, pin_memory=True, num_workers=0)

    test_loss = 0.0

    df_segment = pd.DataFrame(columns=columns)

    for one in tqdm(test_loader, total=len(test_loader)):
        resp, stft, labels = one[:3]
        other_info = one[3:]
        resp = resp.float().cuda() if gpu else resp.float()
        stft = stft.float().cuda() if gpu else stft.float()
        labels = labels.cuda() if gpu else labels
        with torch.no_grad():
            out = model(resp, stft)

            loss = loss_func(out, labels)

            test_loss += loss.item()

            labels = torch.unsqueeze(labels, dim=1)
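            # Softmax over the class logits; keep only the positive-class (apnea) probability as a
            # single column, matching the single-output format passed to calc_metrics.update below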
            out = F.softmax(out, dim=1)
            out = torch.unsqueeze(out[:, 1], dim=1)

            calc_metrics.update(out.cpu(), labels.cpu())
        # one[0] = list(one[0].cpu().numpy())
        # one[1] = list(one[1].cpu().numpy())
        # one = one[1:]
        # out = out.view(1, -1).cpu().numpy().tolist()
        # one += out
        # result_record += [i for i in list(np.array(one, dtype=object).transpose(1, 0))]

        one2 = np.array([i.cpu().numpy() for i in (other_info + [out.squeeze()])])
        one2 = one2.transpose((1, 0))
        df = pd.DataFrame(data=one2, columns=columns)
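        # NOTE: DataFrame.append was removed in pandas 2.0; this script assumes an older pandas
        # (with pandas >= 2.0, replace it with pd.concat)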
        df_segment = df_segment.append(df, ignore_index=True)

    test_loss /= len(test_loader)
    calc_metrics.compute()
    print(calc_metrics.get_matrix(loss=test_loss, epoch=0, epoch_type="test"))
    calc_metrics.reset()

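    # Binarize ground truth and predictions: label_type > event_thresh marks an apnea segment,
    # pred > thresh marks a positive prediction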
    df_segment["thresh_label"] = 1 * (df_segment["label_type"] > event_thresh).copy()
    df_segment["thresh_Pred"] = 1 * (df_segment["pred"] > thresh).copy()
    df_segment["pred"] = df_segment["pred"].copy().apply(lambda x: round(x, 3))

    # Segment-level analysis
    df_segment_metrics = analysis_results(df_segment, segments_results_save_path, dataset_type)

    # Plot confusion matrices, one per sample
    confusionMatrix(df_segment_metrics, segments_results_save_path, dataset_type)
    # Plot bar charts

    # Event-level analysis
    # For inner_test every segment number already corresponds to one event;
    # for the whole-night independence_test the events have to be computed separately
    df_all_event = segment_to_event(df_segment, dataset_type)
    df_event_metrics = analysis_results(df_all_event, events_results_save_path, dataset_type, is_event=True)
    confusionMatrix(df_event_metrics, events_results_save_path, dataset_type)

    # Drop poor-quality segments
    df_bad_segment = df_segment[
        (df_segment["label_type"].isin([2, 3])) & (df_segment["new_label"] == 2)]
    df_select_segment = df_segment.drop(df_bad_segment.index)
    df_select_segment_metrics = analysis_results(df_select_segment, segments_results_save_path / "remove_2",
                                                 dataset_type)
    df_select_event = segment_to_event(df_select_segment, dataset_type)
    df_event_metrics = analysis_results(df_select_event, events_results_save_path / "remove_2", dataset_type,
                                        is_event=True)


def analysis_results(df_result, base_path, dataset_type, is_event=False):
    if df_result.empty:
        print(base_path, dataset_type, "is_empty")
        return None

    (base_path / dataset_type).mkdir(exist_ok=True, parents=True)

    all_sampNo = df_result["sampNo"].unique()
    df_metrics = pd.DataFrame(columns=columns2)
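    # Row 0 aggregates the counts over all samples; rows 1..N hold per-sample metrics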
    df_metrics.loc[0] = 0
    df_metrics.loc[0]["sampNo"] = dataset_type

    for index, sampNo in enumerate(all_sampNo):
        df = df_result[df_result["sampNo"] == sampNo]
        df.to_csv(
            base_path / dataset_type /
            f"{int(sampNo)}_{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_result.csv",
            index=False)

        df_metrics.loc[index + 1] = np.nan
        df_metrics.loc[index + 1, "sampNo"] = str(int(sampNo))
        df_metrics.loc[index + 1, "support"] = df.shape[0]
        df_metrics.loc[index + 1, "severity"] = severity_label[str(int(sampNo))]

        # if dataset_type == "independence_test" or dataset_type == "train_all_test":
        #     continue
        # else:
        df_metrics.loc[index + 1]["origin_P"] = df[df["thresh_label"] == 1].shape[0]
        df_metrics.loc[index + 1]["origin_N"] = df[df["thresh_label"] == 0].shape[0]
        df_metrics.loc[index + 1]["pred_P"] = df[df["thresh_Pred"] == 1].shape[0]
        df_metrics.loc[index + 1]["pred_N"] = df[df["thresh_Pred"] == 0].shape[0]
        df_metrics.loc[index + 1]["T"] = df[df["thresh_Pred"] == df["thresh_label"]].shape[0]
        df_metrics.loc[index + 1]["F"] = df[df["thresh_Pred"] != df["thresh_label"]].shape[0]
        df_metrics.loc[index + 1]["TP"] = \
            df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[index + 1]["FP"] = \
            df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 1)].shape[0]
        df_metrics.loc[index + 1]["TN"] = \
            df[(df["thresh_Pred"] == df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]
        df_metrics.loc[index + 1]["FN"] = \
            df[(df["thresh_Pred"] != df["thresh_label"]) & (df["thresh_Pred"] == 0)].shape[0]

        df_metrics.loc[0]["origin_P"] += df_metrics.loc[index + 1]["origin_P"]
        df_metrics.loc[0]["origin_N"] += df_metrics.loc[index + 1]["origin_N"]
        df_metrics.loc[0]["pred_P"] += df_metrics.loc[index + 1]["pred_P"]
        df_metrics.loc[0]["pred_N"] += df_metrics.loc[index + 1]["pred_N"]
        df_metrics.loc[0]["T"] += df_metrics.loc[index + 1]["T"]
        df_metrics.loc[0]["F"] += df_metrics.loc[index + 1]["F"]
        df_metrics.loc[0]["TP"] += df_metrics.loc[index + 1]["TP"]
        df_metrics.loc[0]["FP"] += df_metrics.loc[index + 1]["FP"]
        df_metrics.loc[0]["TN"] += df_metrics.loc[index + 1]["TN"]
        df_metrics.loc[0]["FN"] += df_metrics.loc[index + 1]["FN"]
        df_metrics.loc[0]["support"] += df_metrics.loc[index + 1]["support"]

        for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]:
            df_metrics.loc[index + 1][col] = df_metrics.loc[index + 1][col] if df_metrics.loc[index + 1][
                                                                                   col] != 0 else np.NAN

        df_metrics.loc[index + 1]["acc"] = df_metrics.iloc[index + 1]["T"] / df_metrics.iloc[index + 1]["support"]
        df_metrics.loc[index + 1]["recall"] = df_metrics.iloc[index + 1]["TP"] / df_metrics.iloc[index + 1]["origin_P"]
        df_metrics.loc[index + 1]["spec"] = df_metrics.iloc[index + 1]["TN"] / df_metrics.iloc[index + 1]["origin_N"]
        df_metrics.loc[index + 1]["pre"] = df_metrics.iloc[index + 1]["TP"] / df_metrics.iloc[index + 1]["pred_P"]
        df_metrics.loc[index + 1]["NPV"] = df_metrics.iloc[index + 1]["TN"] / df_metrics.iloc[index + 1]["pred_N"]
        df_metrics.loc[index + 1]["F1score"] = 2 * df_metrics.iloc[index + 1]["recall"] * df_metrics.iloc[index + 1][
            "pre"] / (df_metrics.iloc[index + 1]["recall"] + df_metrics.iloc[index + 1]["pre"])
        for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN", "acc", "recall",
                    "spec", "pre", "NPV", "F1score"]:
            df_metrics.loc[index + 1][col] = 0 if pd.isna(df_metrics.loc[index + 1][col]) else \
                df_metrics.loc[index + 1][col]
            df_metrics.loc[index + 1][col] = round(df_metrics.loc[index + 1][col], 3)

    # if dataset_type == "independence_test" or dataset_type == "train_all_test":
    #     return None
    for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]:
        df_metrics.loc[0][col] = df_metrics.loc[0][col] if df_metrics.loc[0][col] != 0 else np.NAN

    df_metrics.loc[0]["acc"] = df_metrics.iloc[0]["T"] / df_metrics.iloc[0]["support"]
    df_metrics.loc[0]["recall"] = df_metrics.iloc[0]["TP"] / df_metrics.iloc[0]["origin_P"]
    df_metrics.loc[0]["spec"] = df_metrics.iloc[0]["TN"] / df_metrics.iloc[0]["origin_N"]
    df_metrics.loc[0]["pre"] = df_metrics.iloc[0]["TP"] / df_metrics.iloc[0]["pred_P"]
    df_metrics.loc[0]["NPV"] = df_metrics.iloc[0]["TN"] / df_metrics.iloc[0]["pred_N"]
    df_metrics.loc[0]["F1score"] = 2 * df_metrics.iloc[0]["recall"] * df_metrics.iloc[0]["pre"] / (
            df_metrics.iloc[0]["recall"] + df_metrics.iloc[0]["pre"])
    for col in ["TP", "TN", "FP", "FN", "acc", "recall", "spec", "pre", "NPV", "F1score"]:
        df_metrics.loc[0][col] = 0 if pd.isna(df_metrics.loc[0][col]) else df_metrics.loc[0][col]
        df_metrics.loc[0][col] = round(df_metrics.loc[0][col], 3)

    # For the inner test set, additionally aggregate the metrics by severity level
    if dataset_type == "test":
        all_severity = ["正常", "轻度", "中度", "重度"]
        for index, severity in enumerate(all_severity):
            df_event = df_metrics[df_metrics["severity"] == severity]
            df_temp = pd.DataFrame(columns=columns2)
            df_temp.loc[0] = 0
            df_temp.loc[0, "sampNo"] = severity
            df_temp.loc[0, "severity"] = str(index + 1)

            # Sum the per-sample counts of every sample with this severity
            for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "FP", "TN", "FN", "support"]:
                df_temp.loc[0, col] += df_event[col].sum()

            for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN"]:
                if df_temp.loc[0, col] == 0:
                    df_temp.loc[0, col] = np.nan

            df_temp.loc[0, "acc"] = df_temp.loc[0, "T"] / df_temp.loc[0, "support"]
            df_temp.loc[0, "recall"] = df_temp.loc[0, "TP"] / df_temp.loc[0, "origin_P"]
            df_temp.loc[0, "spec"] = df_temp.loc[0, "TN"] / df_temp.loc[0, "origin_N"]
            df_temp.loc[0, "pre"] = df_temp.loc[0, "TP"] / df_temp.loc[0, "pred_P"]
            df_temp.loc[0, "NPV"] = df_temp.loc[0, "TN"] / df_temp.loc[0, "pred_N"]
            df_temp.loc[0, "F1score"] = 2 * df_temp.loc[0, "recall"] * df_temp.loc[0, "pre"] / (
                    df_temp.loc[0, "recall"] + df_temp.loc[0, "pre"])

            for col in ["origin_P", "origin_N", "pred_P", "pred_N", "T", "F", "TP", "TN", "FP", "FN", "acc", "recall",
                        "spec", "pre", "NPV", "F1score"]:
                if pd.isna(df_temp.loc[0, col]):
                    df_temp.loc[0, col] = 0
                df_temp.loc[0, col] = round(df_temp.loc[0, col], 3)

            df_metrics = df_metrics.append(df_temp, ignore_index=True)

    df_backup = df_metrics
    df_metrics = df_metrics.astype("str")
    df_metrics = df_metrics.sort_values("severity")
    df_metrics.to_csv(base_path / dataset_type /
                      f"{model_name}_{dataset_type}_{'segment' if not is_event else 'event'}_all_metrics.csv",
                      index=False, encoding="gbk")

    return df_backup


def confusionMatrix(df_analysis, base_path, dataset_type):
    if df_analysis is None:
        print(base_path, dataset_type, "is None")
        return

    if df_analysis.empty:
        print(base_path, dataset_type, "is_empty")
        return
    classes = ["normal", "SA"]
    (base_path / dataset_type / "confusionMatrix").mkdir(exist_ok=True, parents=True)
    for one_samp in df_analysis.index:
        one_samp = df_analysis.loc[one_samp]
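        # 2x2 confusion matrix: rows = true class (normal, SA), columns = predicted class -> [[TN, FP], [FN, TP]]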
        cm = np.array([[one_samp["TN"], one_samp["FP"]], [one_samp["FN"], one_samp["TP"]]])
        draw_confusionMatrix(cm, classes=classes, title=str(one_samp["severity"]) + " " + one_samp["sampNo"],
                             save_path=base_path / dataset_type / "confusionMatrix" / f"{one_samp['sampNo']}.jpg")


def segment_to_event(df_segment, dataset_type):
    df_all_event = pd.DataFrame(columns=columns)
    all_sampNo = df_segment["sampNo"].unique()

    if dataset_type == "test":
        for index, sampNo in enumerate(all_sampNo):
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)
            all_segments_no = df["segmentNo"].unique()

            for index_se, segment_No in enumerate(all_segments_no):
                df_temp = df[df["segmentNo"] == segment_No].copy()
                SP = df_temp.iloc[0]["EP"]
                EP = df_temp.iloc[-1]["EP"] + 1
                df_event.loc[index_se] = [int(sampNo), segment_No, df_temp.iloc[0]["label_type"],
                                          df_temp.iloc[0]["new_label"], SP, EP, 0]

                thresh_Pred = df_temp["thresh_Pred"].values
                thresh_Pred2 = thresh_Pred.copy()

                # Dilation: extend every positive prediction forward by thresh_event_interval - 1 steps to bridge short gaps
                for index_pred, pred in enumerate(thresh_Pred):
                    if pred == 0:
                        continue

                    for interval in range(1, thresh_event_interval):
                        if pred == 1 and index_pred + interval < thresh_Pred.size:
                            thresh_Pred2[index_pred + interval] = 1
                        else:
                            continue

                # Run-length encode the dilated predictions to find contiguous positive runs
                same_ar = np.concatenate(([True], thresh_Pred2[:-1] != thresh_Pred2[1:], [True]))
                index_ar = np.where(same_ar)[0]
                count_ar = np.diff(index_ar)
                value_ar = thresh_Pred2[same_ar[:-1]] * count_ar
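                # Illustration of the run-length encoding:
                #   thresh_Pred2 = [0, 1, 1, 1, 0]
                #   same_ar  -> [True, True, False, False, True, True]  (run starts, padded at both ends)
                #   index_ar -> [0, 1, 4, 5]                            (positions of the run boundaries)
                #   count_ar -> [1, 3, 1]                               (run lengths)
                #   value_ar -> [0, 3, 0]                               (run value * run length; positive runs keep their length)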
                for i in value_ar:
                    if i > thresh_event_length:
                        df_event.iloc[index_se]["pred"] = 1

            # df_event.to_csv(events_results / dataset_type / f"{int(sampNo)}_event_results.csv", index=False,
            #                 encoding="gbk")
            df_all_event = df_all_event.append(df_event, ignore_index=True)
    else:
        for index, sampNo in enumerate(all_sampNo):
            df_event = pd.DataFrame(columns=columns)
            df = df_segment[df_segment["sampNo"] == sampNo].copy()
            df["thresh_label"] = 1 * (df["label_type"] > event_thresh)
            df["thresh_Pred"] = 1 * (df["pred"] > thresh)
            thresh_Pred = df["thresh_Pred"].values
            thresh_Pred2 = thresh_Pred.copy()
            # Dilation: extend every positive prediction forward by thresh_event_interval - 1 steps to bridge short gaps
            for index_pred, pred in enumerate(thresh_Pred):
                if pred == 0:
                    continue

                for interval in range(1, thresh_event_interval):
                    if pred == 1 and index_pred + interval < thresh_Pred.size:
                        thresh_Pred2[index_pred + interval] = 1
                    else:
                        continue

            # Run-length encode the dilated predictions to find contiguous positive runs
            same_ar = np.concatenate(([True], thresh_Pred2[:-1] != thresh_Pred2[1:], [True]))
            index_ar = np.where(same_ar)[0]
            count_ar = np.diff(index_ar)
            value_ar = thresh_Pred2[same_ar[:-1]] * count_ar

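            # SP/EP below are positional run boundaries within this sample's ordered rows; each segmentNo
            # appears to span 30 rows, hence SP // 30 recovers an approximate segment number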
            for value_index, value in enumerate(value_ar):
                SP = index_ar[value_index]
                EP = index_ar[value_index] + count_ar[value_index]
                # TP, FP
                if value > thresh_event_length:

                    # label_type = 1 if thresh_Pred2[SP:EP].sum() > 0 else 0
                    label_type = df["label_type"][SP:EP].max()
                    new_label = df["new_label"][SP:EP].max()
                    df_event = df_event.append(pd.DataFrame([[int(sampNo), SP // 30, label_type, new_label,
                                                              SP, EP, thresh_Pred2[SP]]], columns=columns),
                                               ignore_index=True)
                    if value > 30:
                        print([int(sampNo), SP // 30, label_type, new_label, SP, EP, thresh_Pred2[SP]])
                # Run too short: clear these predictions
                else:
                    df["thresh_Pred"][SP:EP] = 0

            # Count the segments with no predicted event as negatives
            for segment_no in df["segmentNo"].unique():
                df_temp = df[df["segmentNo"] == segment_no]
                if df_temp["thresh_Pred"].sum() > 0:
                    continue

                df_event = df_event.append(pd.DataFrame(
                    [[int(sampNo), segment_no, df_temp["label_type"].max(), df_temp["new_label"].max(), segment_no * 30,
                      (segment_no + 1) * 30, 0]], columns=columns),
                                           ignore_index=True)

            df_all_event = df_all_event.append(df_event, ignore_index=True)

    df_temp = df_all_event.loc[:, ["label_type", "pred"]]
    df_all_event["thresh_label"] = 1 * (df_temp["label_type"] > event_thresh)
    df_all_event["thresh_Pred"] = 1 * (df_temp["pred"] > thresh)
    return df_all_event


# Save the results per sampNo and visualize them without overlap
# inner_test

# Per sampNo, save the segments whose predictions disagree with the labels separately and visualize them without overlap

# import shap
# explainer = shap.TreeExplainer()
# shap_values = explainer.shap_values()

if __name__ == '__main__':
    all_output_path = list(exam_path.rglob("KFold_0"))
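    # Evaluate every KFold_0 run directory found under ./output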
    for exam_index, test_exam_path in enumerate(all_output_path):
        # test_exam_path = exam_path / test_exam_path
        set_environment(exam_index)
        # test_and_analysis_and_visual(dataset_type="test")
        test_and_analysis_and_visual(dataset_type="all_test")