第六讲 循环神经网络——SimpleRNN_onehot_4pred1

 1 import numpy as np
 2 import tensorflow as tf
 3 from tensorflow.keras.layers import Dense, SimpleRNN
 4 import matplotlib.pyplot as plt
 5 import os
 6 
 7 
 8 input_word = 'abcde'
 9 w_to_id = {'a':0, 'b':1, 'c':2, 'd':3, 'e':4}
10 id_to_onehot = {0:[1., 0., 0., 0., 0.], 1:[0., 1., 0., 0., 0.], 2:[0., 0., 1., 0., 0., ], 3:[0., 0., 0., 1., 0.],
11                 4:[0., 0., 0., 0., 1.]}
12 
13 
14 x_train = [[id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']]],
15            [id_to_onehot[w_to_id['b']], id_to
16            [id_to_onehot[w_to_id['d']], id__onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]],
17            [id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']]],to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']]],
18            [id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']]]]
19 y_train = [w_to_id['e'], w_to_id['a'], w_to_id['b'], w_to_id['c'], w_to_id['d']]
20 
21 
22 print(x_train)
23 print(y_train)
24 
25 
26 np.random.seed(7)
27 np.random.shuffle(x_train)
28 np.random.seed(7)
29 np.random.shuffle(y_train)
30 tf.random.set_seed(7)
31 
32 
33 # 使x_train符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。
34 # 此处整个数据集送入,送入样本数为len(x_train);输入4个字母出结果,循环核时间展开步数为4; 表示为独热码有5个输入特征,每个时间步输入特征个数为5
35 x_train = np.reshape(x_train, (len(x_train), 4, 5))
36 y_train = np.array(y_train)
37 
38 
39 model = tf.keras.models.Sequential([
40     SimpleRNN(3),
41     Dense(5, activation='softmax')
42 ])
43 
44 model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
45               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
46               metrics=['sparse_categorical_accuracy'])
47 
48 checkpoint_save_path = './checkpoint/rnn_onehot_4pre1.ckpt'
49 
50 if os.path.exists(checkpoint_save_path + '.index'):
51   print('-----------load the model-------------------')
52   model.load_weigts(checkpoint_save_path)
53 
54 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, 
55                                                   save_weights_only=True,
56                                                   save_best_only=True,
57                                                   monitor='loss')
58   
59 history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])
60 
61 model.summary()
62 
63 
64 with open('./weights.txt', 'w') as f:
65   for v in model.trainable_variables:
66     f.write(str(v.name) +'\n')
67     f.write(str(v.shape) + '\n')
68     f.write(str(v.numpy()) + '\n')
69 
70 
71 
72 acc = history.history['sparse_categorical_accuracy']
73 loss = history.history['loss']
74 
75 plt.subplot(1, 2, 1)
76 plt.plot(acc, label='Training Accuracy')
77 plt.title('Training Accuracy')
78 plt.legend()
79 
80 plt.subplot(1, 2, 2)
81 plt.plot(loss, label='Training Loss')
82 plt.title('Training Loss')
83 plt.legend()
84 plt.show()
85 
86 
87 
88 preNum = int(input("input the number of test alphabet:"))
89 for i in range(preNum):
90   alphabet1 = input("input test alphabet:")
91   alphabet = [id_to_onehot[w_to_id[a]] for a in alphabet1]
92   #使alphabet符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数,
93   #每个时间步输入特征个数]。此处验证效果送入了1个样本,送入样本数为1;输入4个字母出结果,
94   #所以循环核时间展开步数为4; 表示为独热码有5个输入特征,每个时间步输入特征个数为5
95   alphabet = np.reshape(alphabet, (1, 4, 5))
96   result = model.predict([alphabet])
97   pred = tf.argmax(result, axis=1)
98   pred = int(pred)
99   tf.print(alphabet1 + '->' + input_word[pred])

 

posted @ 2020-05-12 21:40  WWBlog  阅读(386)  评论(1)  编辑  收藏  举报