Tensorflow2.0 自定义训练时的二分类各种评价指标
在自己写训练步骤的时候,需要挨个写一遍,费时费事,特此记录一下。
直接上代码
# 引包,肯定是冗余的
import pandas as pd
import tensorflow as tf
from glob import glob
import numpy as np
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn import metrics
import matplotlib.pyplot as plt
# 做数据
train_x
train_y
val_x
val_y
train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y))
val_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y))
# 自定义搭模型 随便写写
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense = tf.keras.layers.Dense(16, activation='relu')
self.dropout = tf.keras.layers.Dropout(0.5)
# 这里要是激活了,出来都是0~1,要是不激活,负数正数都有,在损失函数设置一下就成了
self.d2 = tf.keras.layers.Dense(1, activation='sigmoid')
def call(self, x):
x = self.dropout(self.dense(x))
return self.d2(x)
# 损失函数,以及优化器
loss_object = tf.keras.losses.BinaryCrossentropy() # 二分类专属
optimizer = tf.keras.optimizers.Nadam() # 随便写写,听说这个比Adam还好点
# 指标定义
# 我喜欢把损失函数单放出来
train_loss = tf.keras.metrics.Mean(name='train_loss')
# 下面的各种各样的二分类指标
METRICS = [
tf.keras.metrics.TruePositives(name='tp'),
tf.keras.metrics.FalsePositives(name='fp'),
tf.keras.metrics.TrueNegatives(name='tn'),
tf.keras.metrics.FalseNegatives(name='fn'),
tf.keras.metrics.BinaryAccuracy(name='accuracy'),
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.Recall(name='recall'),
tf.keras.metrics.AUC(name='auc'),
tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]
# 验证的,当然可以写一个方法重复用
val_loss = tf.keras.metrics.Mean(name='val_loss')
METRICS_ = [
tf.keras.metrics.TruePositives(name='tp'),
tf.keras.metrics.FalsePositives(name='fp'),
tf.keras.metrics.TrueNegatives(name='tn'),
tf.keras.metrics.FalseNegatives(name='fn'),
tf.keras.metrics.BinaryAccuracy(name='accuracy'),
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.Recall(name='recall'),
tf.keras.metrics.AUC(name='auc'),
tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]
# 定义训练和验证的方法
@tf.function
def train_step(x1, labels):
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
predictions = model(x1, training=True)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
for i in METRICS:
i(labels, predictions)
@tf.function
def val_step(x1, labels):
predictions = model(x1, training=False)
loss = loss_object(labels, predictions)
val_loss(loss)
for i in METRICS_:
i(labels, predictions)
#下面就是跑起来了,指标还算了下f1,配合tensorboard,一目了然
EPOCHS = 10
for epoch in range(EPOCHS):
# Reset the metrics at the start of the next epoch
# 训练
train_loss.reset_states()
for i in METRICS:
i.reset_states()
train_par = tqdm(train_ds)
for x1, labels in train_par:
train_step(x1, labels)
str_ = ''
f1 = {}
for i in METRICS:
if i.name in ['tp', 'fp', 'tn', 'fn']:
str_ += f'{i.name}:{i.result()} \n'
else:
if i.name in ['precision', 'recall']:
f1[i.name] = i.result()
str_ += f'{i.name}:{i.result():0.5f} \n'
str_ += f"f1: {(2*f1['precision']*f1['recall'])/(f1['precision']+f1['recall']):0.5f} \n"
train_par.set_description(
f'{epoch+1} 🌏,'
f'Loss: {train_loss.result():0.5f}, '
f'{str_}'
)
# 验证
val_loss.reset_states()
for i in METRICS_:
i.reset_states()
val_par = tqdm(val_ds)
for x1, labels in val_par:
val_step(x1, labels)
str_ = ''
f1 = {}
for i in METRICS_:
if i.name in ['tp', 'fp', 'tn', 'fn']:
str_ += f'{i.name}:{i.result()} \n'
else:
if i.name in ['precision', 'recall']:
f1[i.name] = i.result()
str_ += f'{i.name}:{i.result():0.5f} \n'
m = (2*f1['precision']*f1['recall'])/(f1['precision']+f1['recall'])
str_ += f"f1: {(2*f1['precision']*f1['recall'])/(f1['precision']+f1['recall']):0.5f} \n"
val_par.set_description(
f'{epoch+1} 🌙,'
f'Loss: {val_loss.result()}, '
f'{str_}'
)
本文来自博客园,作者:赫凯,转载请注明原文链接:https://www.cnblogs.com/heKaiii/p/17137378.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 零经验选手,Compose 一天开发一款小游戏!
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!