#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:calc_metrics.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/02/12
"""
import torch
import torchmetrics


class CALC_METRICS:
    """Accumulate classification metrics across batches and report them.

    Five torchmetrics objects (accuracy, recall, precision, specificity, F1)
    are updated together. ``get_matrix`` renders their current values as a
    percentage table. ``wandb_log`` pushes the last train/valid results
    recorded by ``get_matrix`` to a wandb-like logger.

    NOTE(review): the constructor arguments (``num_classes`` + ``multiclass``
    with no ``task=``) follow the legacy torchmetrics API; newer releases
    require ``task=`` -- confirm the pinned torchmetrics version.
    """

    # Table column headers, in the same order as ``self.metrics``.
    _HEADERS = ("Acc", "Rec", "Pre", "Spe", "F1")
    # wandb key suffixes, in the same order as ``self.metrics``.
    _WANDB_KEYWORDS = ("Accuracy", "Recall", "Precision", "Specificity", "F1Score")
    # Width of every column in the text table.
    _COL_WIDTH = 8

    def __init__(self, nc):
        """Create the metric objects.

        Args:
            nc: number of classes, forwarded as ``num_classes`` to every metric.
        """
        self.nc = nc
        # Must be an instance attribute: a class-level list would be shared by
        # all instances, so each new CALC_METRICS would append five more
        # metrics to the same list and every instance would update them all.
        metric_kwargs = dict(average="none", num_classes=nc, multiclass=False)
        self.metrics = [
            torchmetrics.Accuracy(**metric_kwargs),
            torchmetrics.Recall(**metric_kwargs),
            torchmetrics.Precision(**metric_kwargs),
            torchmetrics.Specificity(**metric_kwargs),
            torchmetrics.F1Score(**metric_kwargs),
        ]
        # Filled by get_matrix(epoch_type="train"/"valid"); read by wandb_log.
        # Each is [loss, acc, rec, pre, spe, f1] once set.
        self.valid_result = self.train_result = None

    def update(self, pred, target):
        """Feed one batch of predictions and targets into every metric.

        Inputs are copied to CPU first; the metric objects are never moved to
        another device in this class.
        """
        # Move to CPU once rather than once per metric.
        pred_cpu = pred.cpu()
        target_cpu = target.cpu()
        for metric in self.metrics:
            metric.update(pred_cpu, target_cpu)

    def compute(self):
        """Return the current value of every metric, in ``self.metrics`` order."""
        return [metric.compute() for metric in self.metrics]

    def reset(self):
        """Clear the accumulated state of every metric."""
        for metric in self.metrics:
            metric.reset()

    def get_matrix(self, loss=None, cur_lr=None, epoch=None, epoch_type=None):
        """Compute all metrics (as percentages) and render them as a text table.

        Args:
            loss: loss value to print and store alongside the metrics.
            cur_lr: learning rate to print.
            epoch: epoch number to print.
            epoch_type: ``"train"`` or ``"valid"`` stores ``[loss] + metrics``
                in ``self.train_result`` / ``self.valid_result`` for
                ``wandb_log``; any other value only prints.

        Returns:
            A multi-line string: a summary line, a header row, then either one
            "all" row (scalar metric results) or one row per class index
            (per-class metric results).
        """
        # Each entry is a float (0-d result) or a list of floats (1-d result).
        temp_result = [
            (metric.compute().cpu().numpy() * 100).tolist() for metric in self.metrics
        ]
        if epoch_type == "train":
            self.train_result = [loss] + temp_result
        elif epoch_type == "valid":
            self.valid_result = [loss] + temp_result

        width = self._COL_WIDTH
        text = f"{epoch_type} epoch: {epoch!s} loss: {loss!s} lr: {cur_lr!s} \n"
        text += " " * width + "".join(h.center(width) for h in self._HEADERS) + "\n"
        text += self._format_rows(temp_result)
        return text

    @classmethod
    def _format_rows(cls, values):
        """Render one value per metric as table rows (helper for get_matrix).

        If every value is a scalar, a single row labelled "all" is produced.
        If any value is a list (one entry per class), one row per class index
        is produced instead; the previous ``float(list)`` call crashed here.
        """
        width = cls._COL_WIDTH

        def cell(value):
            return str(round(float(value), 2)).center(width)

        if not any(isinstance(v, list) for v in values):
            return "all".center(width) + "".join(cell(v) for v in values) + "\n"

        columns = [v if isinstance(v, list) else [v] for v in values]
        return "".join(
            str(idx).center(width) + "".join(cell(v) for v in row) + "\n"
            for idx, row in enumerate(zip(*columns))
        )

    def wandb_log(self, wandb=None, cur_lr=None):
        """Log the stored train/valid results and the learning rate.

        Args:
            wandb: object exposing ``log(dict)`` (e.g. the ``wandb`` module).
                When ``None`` nothing is logged.
            cur_lr: learning rate, logged under the key ``"lr"``.

        Only results that ``get_matrix`` has stored are logged. Previously,
        a missing train or valid result (still ``None``) raised ``TypeError``.
        """
        if wandb is None:
            return
        log_dict = {}
        for epoch_type, result in (("train", self.train_result),
                                   ("valid", self.valid_result)):
            if result is None:
                continue
            keys = [epoch_type + "/" + "loss"]
            keys += [epoch_type + "/" + name for name in self._WANDB_KEYWORDS]
            log_dict.update(zip(keys, result))
        log_dict["lr"] = cur_lr
        wandb.log(log_dict)


if __name__ == '__main__':
    # Smoke test: ten scores in [0, 1] against ten 0/1 labels, nc=1.
    pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    true = [0, 0, 1, 0, 0, 0, 0, 0, 0, 1]
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.tensor(pred).to(device)
    true = torch.tensor(true).to(device)
    calc_metrics = CALC_METRICS(1)
    calc_metrics.update(pred, true)
    print(calc_metrics.get_matrix())