超好用的tensorboardX!

# TensorBoard setup for validation metrics.
# View with:
#   tensorboard --logdir /home/zy/pycharm/project/MetaSAug-main/cifar/logger/Accuracy
# NOTE(review): `logdir` below is 'cifar/logger/Accuracyage/test' (relative to the
# working directory), which is not inside the directory in the command above;
# point --logdir at wherever the event files are actually written.
import time  # needed for the timestamp suffix of `logdirclass` (was missing -> NameError)

from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix

# Writer for overall top-1 validation accuracy, logged under `title`.
logdir = 'cifar/logger/Accuracyage/test'
writerTensor = SummaryWriter(logdir)
title = 'Validate/Accuracy/test'

# Writer for per-class accuracy in its own run directory; the local-time HHMMSS
# suffix separates runs started at different times of day.
logdirclass = 'checkpoint/writerTensor/Cifar/logger/ClassAccuracy/Baseline' + time.strftime("%H%M%S")
writerTensorclass = SummaryWriter(logdirclass)
titleclass = 'Validate/ClassAccuracy/BKD'
# Every `args.print_freq` batches, print progress for batch `i`: the current value
# and running average of batch time, loss and top-1 precision.
if i % args.print_freq == 0:
    print("---------------------------Begin Test--------------------------")
    print('Test: [{0}/{1}]\t'
          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
              i, len(val_loader), batch_time=batch_time, loss=losses,
              top1=top1))

# NOTE(review): the summary below uses running averages and is logged with `epoch`
# as the step, so it presumably belongs after the loop over val_loader (the
# original indentation was lost) — confirm placement.
# Error is the complement of the same average precision printed beside it; the
# original used top1.val (presumably only the last batch's value), which did not
# match the printed top1.avg.
print(' * Prec@1 {top1.avg:.3f}\t Error:{Error:.3f}'.format(top1=top1, Error=(100 - top1.avg)))
print("---------------------------End All Test--------------------------")
writerTensor.add_scalar(title, top1.avg, epoch)
# -------------------------- Per-class accuracy --------------------------
from sklearn.metrics import confusion_matrix

# Number of classes from the model config, defaulting to 100 when the config does
# not set 'num_classes'; also stored on self.num_classes.
self.num_classes = config._config['arch']['args'].get('num_classes', 100)
num_classes = self.num_classes

# NOTE(review): the accumulators are (re)created right here, so the confusion
# matrix below covers only the single `output`/`target` batch appended next.
# To measure the whole validation set, create these two lists before the loop
# over val_loader and move the two append() calls inside it (once per batch).
all_preds = []
all_targets = []

# Predicted label = index of the largest score along dim 1
# (assumes `output` is (batch, num_classes) scores/logits — TODO confirm).
preds = torch.argmax(output, dim=1)
all_preds.append(preds.cpu().numpy())
all_targets.append(target.cpu().numpy())

# Merge the per-batch arrays into flat label vectors.
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

# Rows are true classes, columns predicted classes; passing `labels` fixes the
# matrix at num_classes x num_classes even if some classes never appear.
conf_matrix = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes))

# Per-class accuracy = correctly predicted samples / samples of that class.
# Classes with no samples get NaN explicitly instead of a 0/0 RuntimeWarning.
class_counts = conf_matrix.sum(axis=1)
per_class_accuracy = np.divide(
    conf_matrix.diagonal(),
    class_counts,
    out=np.full(num_classes, np.nan),
    where=class_counts > 0,
)

out_cls_acc = '%s Class Accuracy: %s' % (
    'Validation',
    np.array2string(per_class_accuracy, separator=',',
                    formatter={'float_kind': lambda x: "%.3f" % x}),
)
print(out_cls_acc)

# Only at epoch 199 (hard-coded; presumably the last epoch of a 200-epoch run —
# TODO derive from the configured epoch count), log one point per class with the
# class index as the step, which draws accuracy against class index.
# A dedicated loop variable avoids clobbering the batch index `i`.
if epoch == 199:
    for class_idx, acc in enumerate(per_class_accuracy):
        writerTensorclass.add_scalar(titleclass, acc, class_idx)

本文作者:太好了还有脑子可以用

本文链接:https://www.cnblogs.com/ZarkY/p/18122287

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @ 太好了还有脑子可以用  阅读(23)  评论(0)  编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起