TF2 (tf.keras) callbacks

摘自b站tf2视频教程

我们经常使用到的三个回调函数为:

  TensorBoard

  ModelCheckpoint

  EarlyStopping

可以这样使用:

# Directory that receives TensorBoard logs and the best saved model.
logdir = "./callback"
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
# any missing parent directories.
os.makedirs(logdir, exist_ok=True)
out_put_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
callbacks = [
    # Write training logs under `logdir` for visualization in TensorBoard.
    k.callbacks.TensorBoard(logdir),
    # Save the model file, overwriting it only when the monitored quantity
    # improves (Keras default monitor: val_loss).
    k.callbacks.ModelCheckpoint(out_put_model_file, save_best_only=True),
    # Stop once the monitored quantity (Keras default: val_loss) fails to
    # improve by at least 1e-3 for 5 consecutive epochs.
    k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
# Feed the standardized inputs; passing the raw x_train / x_valid pixels
# would leave the StandardScaler output unused.
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)

 

完整代码:

#!/usr/bin/env python
# coding: utf-8

# In[2]:


import tensorflow as tf
import tensorflow.keras as k
import numpy as np
import matplotlib.pyplot as plt
import os


# In[3]:


# Load Fashion-MNIST and carve a validation set out of the training images.
fashion_mnist = k.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
# Note the order: the FIRST 5000 images are held out for validation and the
# remaining ones are used for training (swapping these would train on only
# 5000 images).
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]


# In[4]:


from sklearn.preprocessing import StandardScaler


def _as_pixel_column(images):
    """Return ``images`` cast to float32 and flattened into a single column.

    Flattening to one column makes StandardScaler learn a single mean and
    standard deviation shared by every pixel, rather than one per pixel.
    """
    return images.astype(np.float32).reshape(-1, 1)


# Learn the statistics from the training pixels only, then apply the same
# transform to validation and test data; each result is reshaped back into
# 28x28 images.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(_as_pixel_column(x_train)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(_as_pixel_column(x_valid)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(_as_pixel_column(x_test)).reshape(-1, 28, 28)


# In[5]:


# Build the model: flatten each 28x28 input, pass it through two ReLU hidden
# layers (300 and 100 units), and finish with a 10-way softmax output.
model = k.Sequential([
    k.layers.Flatten(input_shape=[28, 28]),
    k.layers.Dense(300, activation="relu"),
    k.layers.Dense(100, activation="relu"),
    k.layers.Dense(10, activation="softmax"),
])

# sparse_categorical_crossentropy takes integer class ids (not one-hot
# vectors) as labels.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)


# In[7]:



# Directory that receives TensorBoard logs and the best saved model.
logdir = "./callback"
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
# any missing parent directories.
os.makedirs(logdir, exist_ok=True)
out_put_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
callbacks = [
    # Write training logs under `logdir` for visualization in TensorBoard.
    k.callbacks.TensorBoard(logdir),
    # Save the model file, overwriting it only when the monitored quantity
    # improves (Keras default monitor: val_loss).
    k.callbacks.ModelCheckpoint(out_put_model_file, save_best_only=True),
    # Stop once the monitored quantity (Keras default: val_loss) fails to
    # improve by at least 1e-3 for 5 consecutive epochs.
    k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
# Feed the standardized inputs computed by the StandardScaler cell; passing
# the raw x_train / x_valid pixels would leave that work unused.
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)


# In[ ]:


import pandas as pd


def plot_curve(history, ylim=(0, 1)):
    """Plot every per-epoch metric recorded in a Keras training history.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict
            maps metric names to lists of per-epoch values.
        ylim: ``(bottom, top)`` y-axis limits. The default ``(0, 1)`` suits
            accuracy; pass ``None`` to let matplotlib autoscale so that loss
            values above 1 remain visible.
    """
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    if ylim is not None:
        plt.gca().set_ylim(*ylim)
    plt.show()


plot_curve(history)


# In[ ]:
完整代码

 

posted @ 2020-02-25 20:12  超级学渣渣  阅读(200)  评论(0)  编辑  收藏  举报