RNN预测字母

#字母预测:输入a预测出b,输入b预测出c,输入c预测出d,输入d预测出e,输入e预测出a
#10000  a
#01000  b
#00100  c
#00010  d
#00001  e

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense,SimpleRNN
import matplotlib.pyplot as plt
import os

# Vocabulary for the letter-prediction task: each letter maps to an integer id,
# and each id maps to a 5-dimensional one-hot float vector.
input_word = 'abcde'
w_to_id = {ch: idx for idx, ch in enumerate(input_word)}
id_to_onehot = {idx: [1. if pos == idx else 0. for pos in range(5)]
                for idx in range(5)}

# Training pairs: each letter's one-hot vector, labelled with the id of the
# letter that follows it (wrapping around from 'e' back to 'a').
x_train = [id_to_onehot[w_to_id[ch]] for ch in input_word]
y_train = [w_to_id[input_word[(i + 1) % len(input_word)]]
           for i in range(len(input_word))]

# Shuffle features and labels with the SAME seed so the (x, y) pairing is
# preserved: resetting the seed before each shuffle makes both permutations
# identical.
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)

# Reshape x_train to SimpleRNN's required input shape:
# [batch size, time steps, features per step].
# The whole dataset is fed at once, so batch size is len(x_train); one letter
# per prediction gives 1 time step; the one-hot code has 5 features per step.
x_train=np.reshape(x_train,(len(x_train),1,5))
y_train=np.array(y_train)

# SimpleRNN with 3 hidden units feeding a 5-way softmax over the alphabet.
model=tf.keras.Sequential([SimpleRNN(3),Dense(5,activation='softmax')])

# from_logits=False because the Dense layer already applies softmax.
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path='./checkpoint/rnn_onehot_lprel.ckpt'

# Resume from a previous checkpoint if one exists (TF writes an '.index' file
# alongside the checkpoint weights).
if os.path.exists(checkpoint_save_path+'.index'):
    print('-------------------load the model--------------')
    model.load_weights(checkpoint_save_path)

cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,
                                               save_best_only=True,
                                               monitor='loss')  # fit() gets no validation set, so monitor training loss and keep the best weights

history=model.fit(x_train,y_train,batch_size=32,epochs=50,callbacks=[cp_callback])

model.summary()

# Dump every trainable variable (name, shape, values) to a text file for
# inspection. A context manager guarantees the file is closed even if a
# write fails (the original used open()/close() with no cleanup on error).
with open('./rnn_weights.txt', 'w') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')

###############################################    show   ###############################################

# Plot the training accuracy and loss curves side by side.
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

for panel, (series, title) in enumerate(
        [(acc, 'Training Accuracy'), (loss, 'Training Loss')], start=1):
    plt.subplot(1, 2, panel)
    plt.plot(series, label=title)
    plt.title(title)
    plt.legend()
plt.show()

# Interactive demo: read a count, then for each user-supplied letter predict
# the next letter in the cycle a->b->c->d->e->a.
preNum=int(input('input the number of test alphabet'))

for i in range(preNum):
    alphabet1=input('input test alphabet:')
    # Guard against letters outside the vocabulary (the original crashed
    # with a KeyError on any input other than a-e).
    if alphabet1 not in w_to_id:
        print('unknown letter, expected one of: ' + input_word)
        continue
    # Reshape to SimpleRNN's required input shape
    # [batch size, time steps, features per step]: one sample, one time step,
    # 5 one-hot features.
    alphabet=np.reshape([id_to_onehot[w_to_id[alphabet1]]],(1,1,5))
    result=model.predict(alphabet)  # fixed 'reseult' typo; no list wrapper needed
    pred=int(tf.argmax(result,axis=1))
    tf.print(alphabet1 + '->' + input_word[pred])

 

posted @ 2020-09-03 20:28  爬到牢底坐穿  阅读(281)  评论(0编辑  收藏  举报