sleep_apnea_hybrid/exam/021/utils/calc_metrics.py
2022-10-14 22:33:34 +08:00

85 lines
2.9 KiB
Python

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:calc_metrics.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/02/12
"""
import torch
import torchmetrics
class CALC_METRICS:
    """Keep a fixed set of torchmetrics metrics in sync and report them.

    Each instance owns five metrics, always in this order: Accuracy, Recall,
    Precision, Specificity, F1Score.  They are updated, computed and reset
    together, and ``get_matrix`` / ``wandb_log`` turn their values into a
    text table and a wandb log entry respectively.
    """

    # Table captions, in the same order as the metrics built in ``__init__``.
    _COLUMN_TITLES = ("Acc", "Rec", "Pre", "Spe", "F1")
    # wandb key suffixes, in the same order as the metrics built in ``__init__``.
    _LOG_NAMES = ("Accuracy", "Recall", "Precision", "Specificity", "F1Score")
    # Width of every table cell produced by ``get_matrix``.
    _CELL_WIDTH = 8

    nc = 0

    def __init__(self, nc):
        """Create the metric set.

        Args:
            nc: forwarded to every metric as ``num_classes``.
        """
        self.nc = nc
        metric_types = (
            torchmetrics.Accuracy,
            torchmetrics.Recall,
            torchmetrics.Precision,
            torchmetrics.Specificity,
            torchmetrics.F1Score,
        )
        # The list must live on the instance.  A class-level list would be
        # shared, so a second CALC_METRICS (e.g. one for train and one for
        # valid) would append five more metrics to the same list.
        self.metrics = [
            metric_type(average="none", num_classes=nc, multiclass=False)
            for metric_type in metric_types
        ]
        # Last "[loss, metric values...]" recorded by get_matrix per phase.
        self.valid_result = self.train_result = None

    def update(self, pred, target):
        """Feed one batch of predictions and targets to every metric.

        Both tensors are moved to the CPU once and then shared by all metrics.
        """
        pred_cpu = pred.cpu()
        target_cpu = target.cpu()
        for metric in self.metrics:
            metric.update(pred_cpu, target_cpu)

    def compute(self):
        """Return the current value of every metric, in metric order."""
        return [metric.compute() for metric in self.metrics]

    def reset(self):
        """Clear the accumulated state of every metric."""
        for metric in self.metrics:
            metric.reset()

    @staticmethod
    def _flatten(value):
        """Return ``value`` (a number or arbitrarily nested list) as a flat list."""
        if isinstance(value, (list, tuple)):
            flat = []
            for item in value:
                flat.extend(CALC_METRICS._flatten(item))
            return flat
        return [value]

    def get_matrix(self, loss=None, cur_lr=None, epoch=None, epoch_type=None):
        """Compute all metrics (as percentages) and format them as a text table.

        When ``epoch_type`` is ``"train"`` or ``"valid"`` the values are also
        stored, prefixed by ``loss``, in ``train_result`` / ``valid_result``
        for a later ``wandb_log`` call.  Any other ``epoch_type`` only formats.

        Args:
            loss: loss value shown in the header and stored with the result.
            cur_lr: learning rate shown in the header.
            epoch: epoch number shown in the header.
            epoch_type: ``"train"``, ``"valid"`` or anything else (not stored).

        Returns:
            A multi-line string: a header line, the column captions, then one
            row labelled ``all`` when each metric yields a single value, or
            one row per index when metrics yield several values.
        """
        temp_result = [
            (metric.compute().cpu().numpy() * 100).tolist()
            for metric in self.metrics
        ]
        if epoch_type == "train":
            self.train_result = [loss] + temp_result
        elif epoch_type == "valid":
            self.valid_result = [loss] + temp_result

        width = self._CELL_WIDTH
        lines = [
            f"{epoch_type} epoch: {str(epoch)} loss: {str(loss)} lr: {str(cur_lr)} ",
            " " * width + "".join(title.center(width) for title in self._COLUMN_TITLES),
        ]

        # ``tolist()`` yields a bare float for a 0-dim result but a list for a
        # 1-dim one; calling float() on that list raised TypeError before.
        # Which shape torchmetrics returns here is version dependent, so both
        # are normalised to flat lists.
        columns = [self._flatten(value) for value in temp_result]
        row_count = max((len(column) for column in columns), default=0)
        labels = ["all"] if row_count <= 1 else [str(k) for k in range(row_count)]
        for row, label in enumerate(labels):
            cells = "".join(
                str(round(float(column[row]), 2)).center(width)
                if row < len(column)
                else " " * width
                for column in columns
            )
            lines.append(label.center(width) + cells)
        return "\n".join(lines) + "\n"

    def wandb_log(self, wandb=None, cur_lr=None):
        """Send the stored train/valid results and the learning rate to wandb.

        Keys are ``"<phase>/loss"`` and ``"<phase>/<MetricName>"`` plus
        ``"lr"``.  A phase with no stored result yet is skipped instead of
        failing.

        Args:
            wandb: object exposing ``log(dict)``; nothing happens when ``None``.
            cur_lr: value logged under ``"lr"``.
        """
        if wandb is None:
            return
        names = ("loss",) + self._LOG_NAMES
        log_dict = {}
        for epoch_type, result in (("train", self.train_result),
                                   ("valid", self.valid_result)):
            if result is None:
                # get_matrix has not been called for this phase yet.
                continue
            log_dict.update(zip((f"{epoch_type}/{name}" for name in names), result))
        log_dict["lr"] = cur_lr
        wandb.log(log_dict)
if __name__ == '__main__':
    # Manual smoke test: score ten predictions against binary labels and
    # print the resulting metric table.
    # Alternative column-shaped input that was tried previously:
    #   pred = [[0.1], [0.2], ..., [1.0]];  true = [[0], [0], [1], ..., [1]]
    pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    true = [0, 0, 1, 0, 0, 0, 0, 0, 0, 1]
    # Use the GPU only when one is present so the demo also runs on CPU-only
    # machines; CALC_METRICS.update moves the tensors to the CPU either way.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.tensor(pred).to(device)
    true = torch.tensor(true).to(device)
    calc_metrics = CALC_METRICS(1)
    calc_metrics.update(pred, true)
    print(calc_metrics.get_matrix())