Binary Classification Metrics in a TensorFlow 2.0 Custom Training Loop

When you write your own training loop, you have to set each of these metrics up by hand, which is tedious and time-consuming, so I'm recording it here for future reference.

Straight to the code.

# Imports (surely some of these are redundant)
import pandas as pd
import tensorflow as tf
from glob import glob
import numpy as np
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn import metrics
import matplotlib.pyplot as plt

# Prepare the data (the actual arrays are omitted here)
train_x  # placeholder: training features
train_y  # placeholder: training labels
val_x    # placeholder: validation features
val_y    # placeholder: validation labels
# Batch the datasets so each step receives a 2-D batch of samples; shuffle buffer and batch size are arbitrary choices here
train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(1024).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y)).batch(32)
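The post leaves the actual arrays out. Purely as a hypothetical stand-in to make the snippet runnable end to end (the shapes and class balance are made up for illustration), the placeholders above could be filled like this:

# Hypothetical synthetic data, only so the rest of the pipeline can be exercised.
rng = np.random.RandomState(42)
X = rng.randn(1000, 8).astype('float32')                       # 1000 samples, 8 features
y = (X[:, 0] + 0.5 * rng.randn(1000) > 0).astype('float32')    # binary labels

train_x, val_x, train_y, val_y = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)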


# Build a custom model; just a quick throwaway example
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense = tf.keras.layers.Dense(16, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.5)
        # With the sigmoid here, outputs all fall in 0~1; without it you get raw logits (both negative and positive values), and you just configure the loss function accordingly (see the note after the loss definition below)
        self.d2 = tf.keras.layers.Dense(1, activation='sigmoid')
    
    def call(self, x):
        x = self.dropout(self.dense(x))
        return self.d2(x)

# Instantiate the model, then define the loss function and optimizer
model = MyModel()
loss_object = tf.keras.losses.BinaryCrossentropy()  # the binary classification loss
optimizer = tf.keras.optimizers.Nadam()  # arbitrary choice; reportedly a bit better than Adam
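If you leave the last Dense layer unactivated, the loss just needs from_logits=True; a minimal sketch of that variant (the name loss_object_logits is my own, the rest of the post keeps the sigmoid version):

# Logits variant: drop the sigmoid from the last Dense layer and let the loss apply it internally.
loss_object_logits = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# With this variant, run the raw outputs through tf.sigmoid() before updating
# the threshold-based metrics (accuracy/precision/recall), which expect probabilities.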

# Metric definitions
# I like to keep the loss as its own separate metric
train_loss = tf.keras.metrics.Mean(name='train_loss')
# And below, all sorts of binary classification metrics
METRICS = [
    tf.keras.metrics.TruePositives(name='tp'),
    tf.keras.metrics.FalsePositives(name='fp'),
    tf.keras.metrics.TrueNegatives(name='tn'),
    tf.keras.metrics.FalseNegatives(name='fn'), 
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

# The validation ones; of course you could write a factory function and reuse it (a sketch follows after the second list)
val_loss = tf.keras.metrics.Mean(name='val_loss')

METRICS_ = [
    tf.keras.metrics.TruePositives(name='tp'),
    tf.keras.metrics.FalsePositives(name='fp'),
    tf.keras.metrics.TrueNegatives(name='tn'),
    tf.keras.metrics.FalseNegatives(name='fn'), 
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]
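Since the two lists above are identical, one way to avoid the copy-paste is a small factory function; this is just my own sketch, not part of the original post:

# Build a fresh list of binary classification metrics on demand,
# so the train and validation lists don't have to be written out twice.
def make_binary_metrics():
    return [
        tf.keras.metrics.TruePositives(name='tp'),
        tf.keras.metrics.FalsePositives(name='fp'),
        tf.keras.metrics.TrueNegatives(name='tn'),
        tf.keras.metrics.FalseNegatives(name='fn'),
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.AUC(name='prc', curve='PR'),
    ]

# Equivalent to the two hand-written lists above:
# METRICS = make_binary_metrics()
# METRICS_ = make_binary_metrics()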

# Define the training and validation steps
@tf.function
def train_step(x1, labels):
    with tf.GradientTape() as tape:
        # training=True is only needed if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(x1, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    for i in METRICS:
        i(labels, predictions)

@tf.function
def val_step(x1, labels):
    predictions = model(x1, training=False)
    loss = loss_object(labels, predictions)
    val_loss(loss)
    for i in METRICS_:
        i(labels, predictions)

# Now run it; F1 is also computed from the metrics, and paired with TensorBoard everything is clear at a glance (a logging sketch follows after the loop)
EPOCHS = 10
for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch
    # Training
    train_loss.reset_states()
    for i in METRICS:
        i.reset_states()
    train_par = tqdm(train_ds)
    for x1, labels in train_par:
        train_step(x1, labels)
        str_ = ''
        f1 = {}
        for i in METRICS:
            if i.name in ['tp', 'fp', 'tn', 'fn']:
                str_ += f'{i.name}:{i.result()} \n'
            else:
                if i.name in ['precision', 'recall']:
                    f1[i.name] = i.result()
                str_ += f'{i.name}:{i.result():0.5f} \n' 
        str_ += f"f1: {(2*f1['precision']*f1['recall'])/(f1['precision']+f1['recall']):0.5f} \n"
        train_par.set_description(
            f'{epoch+1} 🌏,'
            f'Loss: {train_loss.result():0.5f}, '
            f'{str_}'
        )
        
    # Validation
    val_loss.reset_states()
    for i in METRICS_:
        i.reset_states()
    val_par = tqdm(val_ds)
    for x1, labels in val_par:
        val_step(x1, labels)
        str_ = ''
        f1 = {}
        for i in METRICS_:
            if i.name in ['tp', 'fp', 'tn', 'fn']:
                str_ += f'{i.name}:{i.result()} \n'
            else:
                if i.name in ['precision', 'recall']:
                    f1[i.name] = i.result()
                str_ += f'{i.name}:{i.result():0.5f} \n' 
        f1_score = (2*f1['precision']*f1['recall'])/(f1['precision']+f1['recall'])
        str_ += f"f1: {f1_score:0.5f} \n"
        val_par.set_description(
            f'{epoch+1} 🌙,'
            f'Loss: {val_loss.result()}, '
            f'{str_}'
        )
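The post mentions that pairing this with TensorBoard makes everything clear at a glance; here is a minimal sketch of per-epoch logging with tf.summary (the 'logs/train' and 'logs/val' directories are arbitrary names of my own):

# One summary writer per split; create these once before the epoch loop.
train_writer = tf.summary.create_file_writer('logs/train')
val_writer = tf.summary.create_file_writer('logs/val')

# At the end of each epoch (inside the loop above, before the metrics are reset), write the scalars:
# with train_writer.as_default():
#     tf.summary.scalar('loss', train_loss.result(), step=epoch)
#     for m in METRICS:
#         tf.summary.scalar(m.name, m.result(), step=epoch)
# with val_writer.as_default():
#     tf.summary.scalar('loss', val_loss.result(), step=epoch)
#     for m in METRICS_:
#         tf.summary.scalar(m.name, m.result(), step=epoch)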