参数提取，写入文本文件
1. 提取可训练参数
model.trainable_variables：返回模型中所有可训练的参数（变量）
2.配置print的输出格式
np.set_printoptions(precision=小数点后按四舍五入保留的位数, threshold=触发省略显示的数组元素数量门槛值)
数组元素总数少于或等于门槛值时，打印全部元素；超过门槛值时省略显示：每个维度只打印首尾各 edgeitems 个元素（默认 3 个），中间用省略号代替。
注:threshold=np.inf 可以打印全部数组元素
模型参数的打印结果写入文件 weights_mnist.txt。
完整代码
# Train a small fully connected MNIST classifier (resuming from a checkpoint
# when one exists), then dump every trainable variable's name, shape and
# values to ./weights_mnist.txt.
import os

import numpy as np
import tensorflow as tf

# Print arrays in full instead of summarizing them with "...", so the weights
# written to the text file below are complete.
np.set_printoptions(threshold=np.inf)

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize the images by 255. Each split is scaled from its own images:
# x_test must come from x_test, not from the label array y_train.
x_train, x_test = x_train / 255.0, x_test / 255.0

# Sequential takes the layers as a single list; passing them positionally
# would bind the second layer to the `name` parameter.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(
    optimizer='adam',
    # The last layer already applies softmax, so the outputs are
    # probabilities rather than logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

checkpoint_save_path = './checkpoint/mnist.ckpt'
# A weights checkpoint saved under this prefix produces a "<prefix>.index"
# file. Its presence means there are weights to resume from.
if os.path.exists(checkpoint_save_path + '.index'):
    print('--------------- load the model ---------------')
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when val_loss improves on the best so far.
# NOTE(review): this targets tf.keras 2.x. Keras 3 requires a ".weights.h5"
# filepath when save_weights_only=True, so confirm the TF version in use.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    monitor='val_loss',
    save_best_only=True,
)

history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()

print(model.trainable_variables)

# Open for writing ('w'); the default read-only mode would make every write
# fail. The `with` block guarantees the file is closed even if a write raises.
with open('./weights_mnist.txt', 'w', encoding='utf-8') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        # Call numpy() to get the values; str(v.numpy) would only write the
        # bound-method repr.
        file.write(str(v.numpy()) + '\n')