第六讲 循环神经网络 --SimpleRNN_onehot_1pre1
"""Train a SimpleRNN to predict the next letter of "abcde" from one one-hot letter.

Each input is a single one-hot encoded letter (one time step); the label is the
id of the following letter, wrapping around (a->b, b->c, c->d, d->e, e->a).
After training, the script dumps the trainable weights to ./weights.txt, plots
the training accuracy/loss curves, and then interactively predicts the next
letter for letters typed by the user.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, SimpleRNN

input_word = "abcde"
# One one-hot feature per letter; assumes input_word contains distinct letters.
vocab_size = len(input_word)
# Letter -> integer id: {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}.
w_to_id = {ch: idx for idx, ch in enumerate(input_word)}
# Integer id -> one-hot float row, e.g. 0 -> [1., 0., 0., 0., 0.].
id_to_onehot = {idx: [1. if j == idx else 0. for j in range(vocab_size)]
                for idx in range(vocab_size)}

# Inputs are the one-hot letters in order; each label is the id of the next
# letter, with the last letter wrapping back to the first.
x_train = [id_to_onehot[w_to_id[ch]] for ch in input_word]
y_train = [w_to_id[input_word[(i + 1) % vocab_size]] for i in range(vocab_size)]

# Re-seeding with the same value before each shuffle applies the same
# permutation to inputs and labels, so every (x, y) pair stays aligned.
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)

# SimpleRNN expects input shaped [num_samples, time_steps, features_per_step].
# The whole dataset is fed at once (num_samples = len(x_train)); one letter in,
# one prediction out (time_steps = 1); each letter is a one-hot vector
# (features_per_step = vocab_size).
x_train = np.reshape(x_train, (len(x_train), 1, vocab_size))
y_train = np.array(y_train)

model = tf.keras.models.Sequential([
    SimpleRNN(3),
    Dense(vocab_size, activation='softmax'),
])

# The Dense layer already applies softmax, hence from_logits=False.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# NOTE(review): Keras 3 (TF >= 2.16) only accepts '.weights.h5' paths for
# save_weights_only checkpoints; this TF-checkpoint-style path targets Keras 2 —
# confirm the TensorFlow version in use.
checkpoint_save_path = "./checkpoint/rnn_onehot_1pre1.ckpt"

# The '.index' companion file marks an existing TF-format weights checkpoint
# from a previous run; resume from it if present.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# fit() gets no validation data, so the "best" model is chosen by training loss.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='loss')

history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])

model.summary()

# Dump each trainable variable's name, shape and values for inspection; the
# context manager closes the file even if a write fails.
with open('./weights.txt', 'w', encoding='utf-8') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')

############################################### show ###############################################

# Plot training accuracy and loss curves (no validation set is used).
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

############### predict #############

preNum = int(input("input the number of test alphabet:"))
for _ in range(preNum):
    alphabet1 = input("input test alphabet:").strip()
    if alphabet1 not in w_to_id:
        # Previously an unknown letter crashed the loop with KeyError.
        print(f"unknown alphabet {alphabet1!r}, expected one of {input_word!r}")
        continue
    # One sample, one time step, vocab_size one-hot features — the same
    # [num_samples, time_steps, features_per_step] layout used for training.
    alphabet = np.reshape([id_to_onehot[w_to_id[alphabet1]]], (1, 1, vocab_size))
    result = model.predict(alphabet)
    # result is (1, vocab_size); index the single sample explicitly instead of
    # converting a shape-(1,) array to int (deprecated in NumPy >= 1.25).
    pred = int(np.argmax(result, axis=1)[0])
    print(alphabet1 + '->' + input_word[pred])