Predicting Letters with an RNN
# Letter prediction: input a predicts b, input b predicts c, input c predicts d, input d predicts e, input e predicts a
# One-hot encoding: 10000 -> a, 01000 -> b, 00100 -> c, 00010 -> d, 00001 -> e
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, SimpleRNN
import matplotlib.pyplot as plt
import os

input_word = 'abcde'
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}
id_to_onehot = {0: [1., 0., 0., 0., 0.], 1: [0., 1., 0., 0., 0.], 2: [0., 0., 1., 0., 0.],
                3: [0., 0., 0., 1., 0.], 4: [0., 0., 0., 0., 1.]}

x_train = [id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']],
           id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]]
y_train = [w_to_id['b'], w_to_id['c'], w_to_id['d'], w_to_id['e'], w_to_id['a']]

# Shuffle x_train and y_train with the same seed so that input/label pairs stay aligned
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)

# Reshape x_train to the SimpleRNN input format: [number of samples, number of time steps, features per time step].
# The whole dataset is fed in at once, so the number of samples is len(x_train); one letter is fed in to produce one
# result, so the number of time steps is 1; the one-hot code has 5 features, so each time step has 5 input features.
x_train = np.reshape(x_train, (len(x_train), 1, 5))
y_train = np.array(y_train)

model = tf.keras.Sequential([
    SimpleRNN(3),
    Dense(5, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = './checkpoint/rnn_onehot_lprel.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------------load the model--------------')
    model.load_weights(checkpoint_save_path)

# fit() is given no validation set, so no validation accuracy is computed; the best model is saved based on training loss.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='loss')

history = model.fit(x_train, y_train, batch_size=32, epochs=50, callbacks=[cp_callback])

model.summary()

# print(model.trainable_variables)
file = open('./rnn_weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show    ###############################################

# Plot the training accuracy and loss curves
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

preNum = int(input('input the number of test alphabet:'))
for i in range(preNum):
    alphabet1 = input('input test alphabet:')
    alphabet = [id_to_onehot[w_to_id[alphabet1]]]
    # Reshape alphabet to the SimpleRNN input format: [number of samples, number of time steps, features per time step].
    # A single sample is fed in for this prediction, so the number of samples is 1; one letter is fed in to produce one
    # result, so the number of time steps is 1; the one-hot code has 5 features, so each time step has 5 input features.
    alphabet = np.reshape(alphabet, (1, 1, 5))
    result = model.predict([alphabet])
    pred = tf.argmax(result, axis=1)
    pred = int(pred)
    tf.print(alphabet1 + '->' + input_word[pred])
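To make the input-shape convention and the recurrent computation concrete, here is a minimal sketch that recomputes one prediction by hand from the trained weights. It assumes the `model`, `w_to_id`, `id_to_onehot`, and `input_word` defined above are in scope; `manual_predict` is a hypothetical helper added for illustration, not part of the original script. For a single time step starting from a zero state, SimpleRNN computes h1 = tanh(x·W_xh + h0·W_hh + b), and the Dense layer maps h1 to softmax probabilities over the 5 letters.

# Sketch: reproduce one prediction of the trained model by hand (assumes `model` is trained or loaded above).
def manual_predict(letter):
    w_xh, w_hh, b_h = model.layers[0].get_weights()  # SimpleRNN weights: kernel (5, 3), recurrent kernel (3, 3), bias (3,)
    w_hy, b_y = model.layers[1].get_weights()        # Dense weights: kernel (3, 5), bias (5,)
    x = np.array(id_to_onehot[w_to_id[letter]])      # one-hot input for this single time step, shape (5,)
    h = np.zeros(3)                                  # initial hidden state h0 (zeros, the Keras default)
    h = np.tanh(x @ w_xh + h @ w_hh + b_h)           # one step of the recurrent cell (tanh is the default activation)
    logits = h @ w_hy + b_y
    probs = np.exp(logits) / np.exp(logits).sum()    # softmax over the 5 letters
    return input_word[int(np.argmax(probs))]

print('a ->', manual_predict('a'))  # should agree with model.predict on the same reshaped input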