85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
#!/usr/bin/python
|
|
# -*- coding: UTF-8 -*-
|
|
"""
|
|
@author:Marques
|
|
@file:calc_metrics.py
|
|
@email:admin@marques22.com
|
|
@email:2021022362@m.scnu.edu.cn
|
|
@time:2022/02/12
|
|
"""
|
|
import torch
|
|
import torchmetrics
|
|
|
|
|
|
class CALC_METRICS:
    """Accumulate a fixed set of torchmetrics classification metrics and report them.

    The tracked metrics, in order, are Accuracy, Recall, Precision,
    Specificity and F1Score. Each is built with ``average="none"``,
    ``num_classes=nc`` and ``multiclass=False``.

    ``get_matrix`` caches the most recent results for the "train" and "valid"
    phases in ``train_result`` / ``valid_result``. ``wandb_log`` sends those
    cached results to wandb.
    """

    # Metric names in the same order as ``self.metrics``; used as wandb keys.
    _METRIC_NAMES = ("Accuracy", "Recall", "Precision", "Specificity", "F1Score")
    # Short column headers for the text table, same order as ``_METRIC_NAMES``.
    _METRIC_ABBREVS = ("Acc", "Rec", "Pre", "Spe", "F1")

    nc = 0

    def __init__(self, nc):
        """Create one instance of every tracked metric.

        :param nc: value passed as ``num_classes`` to every torchmetrics metric.
        """
        self.nc = nc
        # Build a fresh list per instance. Appending to a class-level list
        # would share it between all instances, so a second CALC_METRICS would
        # also update and compute the first one's metric objects.
        self.metrics = [
            torchmetrics.Accuracy(average="none", num_classes=nc, multiclass=False),
            torchmetrics.Recall(average="none", num_classes=nc, multiclass=False),
            torchmetrics.Precision(average="none", num_classes=nc, multiclass=False),
            torchmetrics.Specificity(average="none", num_classes=nc, multiclass=False),
            torchmetrics.F1Score(average="none", num_classes=nc, multiclass=False),
        ]
        # Filled by get_matrix(): [loss, acc, rec, pre, spe, f1] for that phase.
        self.valid_result = self.train_result = None

    def update(self, pred, target):
        """Feed one batch to every metric.

        Both tensors are copied to the CPU once and then passed to each
        metric's ``update``. NOTE(review): ``pred`` presumably holds
        positive-class scores and ``target`` 0/1 labels (as in the
        ``__main__`` demo) — confirm for other callers.
        """
        pred_cpu = pred.cpu()
        target_cpu = target.cpu()
        for metric in self.metrics:
            metric.update(pred_cpu, target_cpu)

    def compute(self):
        """Return each metric's ``compute()`` result, in ``self.metrics`` order."""
        return [metric.compute() for metric in self.metrics]

    def reset(self):
        """Reset the accumulated state of every metric."""
        for metric in self.metrics:
            metric.reset()

    @staticmethod
    def _format_cell(value):
        """Render one metric value for the text table.

        With ``average="none"`` a metric's ``.tolist()`` result may be a plain
        float or a per-class list. Each number is rounded to 2 decimals, and
        list entries are joined with "/".
        """
        if isinstance(value, (list, tuple)):
            return "/".join(str(round(float(v), 2)) for v in value)
        return str(round(float(value), 2))

    def get_matrix(self, loss=None, cur_lr=None, epoch=None, epoch_type=None):
        """Compute all metrics (in percent), cache them, and return a text table.

        :param loss: loss value shown in the header and cached first in the result list.
        :param cur_lr: learning rate shown in the header.
        :param epoch: epoch number shown in the header.
        :param epoch_type: "train" or "valid" selects which cache attribute is
            overwritten with ``[loss] + metric_values``; any other value caches nothing.
        :returns: three newline-terminated lines: summary, column header, "all" row.
        """
        # Each metric value is scaled to percent and converted with .tolist(),
        # so an entry is a float (0-d result) or a list (per-class result).
        temp_result = [
            (metric.compute().cpu().numpy() * 100).tolist() for metric in self.metrics
        ]

        if epoch_type == "train":
            self.train_result = [loss] + temp_result
        elif epoch_type == "valid":
            self.valid_result = [loss] + temp_result

        summary = f"{epoch_type} epoch: {str(epoch)} loss: {str(loss)} lr: {str(cur_lr)} \n"
        header = " " * 8 + "".join(name.center(8) for name in self._METRIC_ABBREVS) + "\n"
        row = "all".center(8) + "".join(
            self._format_cell(value).center(8) for value in temp_result
        ) + "\n"
        return summary + header + row

    def wandb_log(self, wandb=None, cur_lr=None):
        """Log the cached train/valid results plus the learning rate to wandb.

        Keys are ``"<phase>/loss"`` and ``"<phase>/<MetricName>"`` for each
        phase whose result has been cached by ``get_matrix``, plus ``"lr"``.
        A phase that has not been summarised yet is skipped rather than
        crashing. Does nothing when ``wandb`` is None.
        """
        if wandb is None:
            return

        log_dict = {}
        for epoch_type, result in (("train", self.train_result), ("valid", self.valid_result)):
            if result is None:
                continue
            keys = [epoch_type + "/" + "loss"] + [
                epoch_type + "/" + name for name in self._METRIC_NAMES
            ]
            log_dict.update(zip(keys, result))

        log_dict["lr"] = cur_lr
        wandb.log(log_dict)
|
|
|
|
|
|
if __name__ == '__main__':
    # Small smoke test: ten binary samples with scores in [0, 1] and 0/1 labels.
    pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    true = [0, 0, 1, 0, 0, 0, 0, 0, 0, 1]

    # CALC_METRICS.update() copies inputs to the CPU anyway, so only use CUDA
    # when it exists instead of crashing on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.tensor(pred, device=device)
    true = torch.tensor(true, device=device)

    calc_metrics = CALC_METRICS(1)
    calc_metrics.update(pred, true)
    print(calc_metrics.get_matrix())
|