TensorFlow 2（tf.keras）回调函数（callback）
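我们最常用的三个回调函数是：
TensorBoard：把训练日志写入指定目录，便于用 TensorBoard 可视化训练过程；
ModelCheckpoint：在训练过程中保存模型，设置 save_best_only=True 时只保留监控指标（默认 val_loss）最好的模型；
EarlyStopping：当监控指标（默认 val_loss）连续 patience 个 epoch 的提升都小于 min_delta 时，提前结束训练。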
用法如下：
# Directory that receives the TensorBoard logs and the saved model.
logdir = "./callback"
# makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir().
os.makedirs(logdir, exist_ok=True)
out_put_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
callbacks = [
    # Write training logs under logdir for visualization in TensorBoard.
    k.callbacks.TensorBoard(logdir),
    # Save the model to out_put_model_file, only when the monitored metric
    # (val_loss by default) improves.
    k.callbacks.ModelCheckpoint(out_put_model_file, save_best_only=True),
    # Stop early if the monitored metric (val_loss by default) improves by
    # less than 1e-3 for 5 consecutive epochs.
    k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
# Train on the standardized arrays produced by the StandardScaler step in the
# full code below, not on the raw pixel values.
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)
完整代码如下：
#!/usr/bin/env python
# coding: utf-8
"""Train a small dense classifier on Fashion-MNIST with tf.keras callbacks.

Demonstrates the TensorBoard, ModelCheckpoint and EarlyStopping callbacks,
then plots the recorded training curves.
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as k
from sklearn.preprocessing import StandardScaler

# ---------------------------------------------------------------- data ----
fashion_mnist = k.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Hold out the first 5000 training images for validation and train on the
# rest. (The previous assignment was inverted: it trained on only 5000
# images and validated on the remaining 55000.)
x_valid, x_train = x_train[:5000], x_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]

# ---------------------------------------------------------- scaling -------
# Reshaping every image to a single column means the scaler learns ONE
# global mean/std over all pixel values. It is fitted on the training split
# only, then applied unchanged to the validation and test splits.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(
    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)

# ------------------------------------------------------------ model -------
model = k.Sequential()
model.add(k.layers.Flatten(input_shape=[28, 28]))
model.add(k.layers.Dense(300, activation="relu"))
model.add(k.layers.Dense(100, activation="relu"))
model.add(k.layers.Dense(10, activation="softmax"))
# Labels are integer class ids, hence the "sparse" cross-entropy.
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])

# -------------------------------------------------------- callbacks -------
logdir = "./callback"
# makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir().
os.makedirs(logdir, exist_ok=True)
out_put_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
callbacks = [
    # Write training logs under logdir for visualization in TensorBoard.
    k.callbacks.TensorBoard(logdir),
    # Save the model only when the monitored metric (val_loss by default)
    # improves.
    k.callbacks.ModelCheckpoint(out_put_model_file, save_best_only=True),
    # Stop if the monitored metric (val_loss by default) improves by less
    # than 1e-3 for 5 consecutive epochs.
    k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
# Previously the raw (unscaled) arrays were passed here, so the
# standardization above had no effect; train on the scaled data instead.
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)

# Report loss/accuracy on the held-out test split (x_test_scaled was
# previously computed but never used).
model.evaluate(x_test_scaled, y_test)


# ---------------------------------------------------------- plotting ------
def plot_curve(history):
    """Plot every metric recorded during training.

    Args:
        history: the ``History`` object returned by ``model.fit``; its
            ``history`` attribute maps metric names to per-epoch values.
    """
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    # A fixed [0, 1] range suits the accuracy curves; any loss value above 1
    # falls outside the visible area.
    plt.gca().set_ylim(0, 1)
    plt.show()


plot_curve(history)