#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:calc_metrics.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/02/12
"""
import torch
import torchmetrics


class CALC_METRICS:
    """Accumulate classification metrics with torchmetrics and report them.

    Wraps five torchmetrics objects (Accuracy, Recall, Precision, Specificity,
    F1Score). All are created with ``average="none"``, ``num_classes=nc`` and
    ``multiclass=False``. ``get_matrix`` formats the results as percentages
    and caches the latest train/valid numbers for ``wandb_log``.
    """

    nc = 0

    def __init__(self, nc):
        """Create the metric objects.

        Args:
            nc: forwarded as ``num_classes`` to every torchmetrics metric.
        """
        self.nc = nc
        metric_kwargs = dict(average="none", num_classes=nc, multiclass=False)
        # The list is created per instance. A class-level list would be shared
        # by every CALC_METRICS object and grow with each construction.
        self.metrics = [
            torchmetrics.Accuracy(**metric_kwargs),
            torchmetrics.Recall(**metric_kwargs),
            torchmetrics.Precision(**metric_kwargs),
            torchmetrics.Specificity(**metric_kwargs),
            torchmetrics.F1Score(**metric_kwargs),
        ]
        # Filled by get_matrix(epoch_type="train"/"valid"); read by wandb_log.
        self.valid_result = self.train_result = None

    def update(self, pred, target):
        """Feed one batch of predictions and targets into every metric.

        Both tensors are copied to the CPU first. The metric objects are never
        moved to another device in this class, so inputs must match that.
        (Expected pred/target shapes follow torchmetrics' conventions for the
        chosen ``num_classes`` — TODO confirm against callers.)
        """
        pred_cpu = pred.cpu()
        target_cpu = target.cpu()
        for metric in self.metrics:
            metric.update(pred_cpu, target_cpu)

    def compute(self):
        """Return each metric's ``compute()`` result, in the order of ``self.metrics``."""
        return [metric.compute() for metric in self.metrics]

    def reset(self):
        """Clear the accumulated state of every metric."""
        for metric in self.metrics:
            metric.reset()

    def get_matrix(self, loss=None, cur_lr=None, epoch=None, epoch_type=None):
        """Compute all metrics and format them as a text table.

        Args:
            loss: loss value; printed and cached, never interpreted.
            cur_lr: learning rate to print.
            epoch: epoch number to print.
            epoch_type: ``"train"`` or ``"valid"`` stores
                ``[loss, *percent_metrics]`` in ``train_result`` /
                ``valid_result`` for ``wandb_log``. Any other value only
                formats the table.

        Returns:
            A multi-line string containing:
            - a header line;
            - a column-title line;
            - either a single ``all`` row, when every metric yields one value,
              or one row per class index, when metrics return per-class values.
        """
        # Each entry is a metric scaled to percent: a float for 0-d results,
        # or a list for per-class results.
        temp_result = [
            (metric.compute().cpu().numpy() * 100).tolist() for metric in self.metrics
        ]

        if epoch_type == "train":
            self.train_result = [loss] + temp_result
        elif epoch_type == "valid":
            self.valid_result = [loss] + temp_result

        # Normalise every result to a list so scalar and per-class outputs can
        # be handled uniformly. float() on a list would raise TypeError.
        per_metric = [values if isinstance(values, list) else [values] for values in temp_result]
        if all(len(values) == 1 for values in per_metric):
            rows = [("all", [values[0] for values in per_metric])]
        else:
            # One row per class index; column j holds metric j for that class.
            rows = [(str(index), list(column)) for index, column in enumerate(zip(*per_metric))]

        lines = [
            f"{epoch_type} epoch: {str(epoch)} loss: {str(loss)} lr: {str(cur_lr)} \n",
            " " * 8 + "Acc".center(8) + "Rec".center(8) + "Pre".center(8) + "Spe".center(8) + "F1".center(8) + "\n",
        ]
        for label, values in rows:
            lines.append(
                label.center(8) + "".join(str(round(float(v), 2)).center(8) for v in values) + "\n"
            )
        return "".join(lines)

    def wandb_log(self, wandb=None, cur_lr=None):
        """Log the cached train/valid results plus ``lr`` through ``wandb.log``.

        Keys look like ``"train/loss"`` or ``"valid/F1Score"``. A split whose
        result has not been produced by ``get_matrix`` yet (still ``None``) is
        skipped rather than raising. Does nothing when ``wandb`` is None.
        """
        if wandb is None:
            return

        keyword = ["Accuracy", "Recall", "Precision", "Specificity", "F1Score"]
        names = ["loss"] + keyword
        log_dict = {}
        for epoch_type, result in (("train", self.train_result), ("valid", self.valid_result)):
            if result is None:
                continue
            log_dict.update(zip((epoch_type + "/" + name for name in names), result))
        log_dict["lr"] = cur_lr
        wandb.log(log_dict)


if __name__ == '__main__':
    # Smoke test: ten scores for a single binary class (nc=1).
    pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    true = [0,   0,   1,   0,   0,   0,   0,   0,   0,   1]
    # Put the inputs on the GPU when one exists, so update()'s .cpu() copy is
    # exercised. Otherwise use the CPU instead of crashing inside .cuda().
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.tensor(pred, device=device)
    true = torch.tensor(true, device=device)
    calc_metrics = CALC_METRICS(1)
    calc_metrics.update(pred, true)
    print(calc_metrics.get_matrix())