"""
Created on Thu Oct 25 13:41:35 2018
@author: lg
"""
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Hyper-parameters and dataset sizes
# ---------------------------------------------------------------------------
# Each image is fed to the RNN as a sequence: TIME_STEPS rows of 28 values.
TIME_STEPS = 28
# Number of examples sliced out of the training arrays per optimisation step.
BATCH_SIZE = 128
# Width of the first (lower) LSTM layer in each direction.
HIDDEN_UNITS1 = 30
# Width of the second (top) LSTM layer; its output is used directly as the
# class logits, so it has to equal the width of the one-hot label placeholder.
HIDDEN_UNITS = 10
# Step size handed to the Adam optimiser.
LEARNING_RATE = 1e-3
# Number of full passes over the training data.
EPOCH = 50
# Row count used to derive the number of mini-batches per epoch.
TRAIN_EXAMPLES = 42000
# Row count of the test set (not referenced by the visible part of the script).
TEST_EXAMPLES = 28000
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
# Both CSV files hold one flattened image per row; the training file also
# carries a "label" column, which is split off before the pixels are converted.
train_frame = pd.read_csv("../Mnist/train.csv")
test_frame = pd.read_csv("../Mnist/test.csv")
train_labels_frame = train_frame.pop("label")

# One-hot encode the labels: one indicator column per distinct label value.
y_train = pd.get_dummies(train_labels_frame).values

# Convert the pixel columns to float32 and fold every row into a 28x28 grid,
# giving arrays of shape (num_examples, 28, 28).
X_train = train_frame.astype(np.float32).values.reshape(-1, 28, 28)
X_test = test_frame.astype(np.float32).values.reshape(-1, 28, 28)
# ---------------------------------------------------------------------------
# Model definition: two-layer bidirectional LSTM classifier
# ---------------------------------------------------------------------------
graph = tf.Graph()
with graph.as_default():
    # Inputs: a batch of sequences (TIME_STEPS steps of 28 features each) and
    # the matching one-hot labels over 10 classes.
    X_p = tf.placeholder(
        dtype=tf.float32,
        shape=(None, TIME_STEPS, 28),
        name="input_placeholder",
    )
    y_p = tf.placeholder(
        dtype=tf.float32,
        shape=(None, 10),
        name="pred_placeholder",
    )

    # Forward direction: HIDDEN_UNITS1-wide layer feeding a HIDDEN_UNITS-wide one.
    lstm_forward_1 = rnn.BasicLSTMCell(num_units=HIDDEN_UNITS1)
    lstm_forward_2 = rnn.BasicLSTMCell(num_units=HIDDEN_UNITS)
    lstm_forward = rnn.MultiRNNCell(cells=[lstm_forward_1, lstm_forward_2])

    # Backward direction: same layout, separate cells.
    lstm_backward_1 = rnn.BasicLSTMCell(num_units=HIDDEN_UNITS1)
    lstm_backward_2 = rnn.BasicLSTMCell(num_units=HIDDEN_UNITS)
    lstm_backward = rnn.MultiRNNCell(cells=[lstm_backward_1, lstm_backward_2])

    # `outputs` is an (output_fw, output_bw) pair, each of shape
    # (batch, TIME_STEPS, HIDDEN_UNITS).
    outputs, states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=lstm_forward,
        cell_bw=lstm_backward,
        inputs=X_p,
        dtype=tf.float32,
    )
    outputs_fw = outputs[0]
    outputs_bw = outputs[1]

    # Combine the two directions into one summary vector per example.
    # The backward outputs are returned in the same time order as the inputs,
    # so the backward cell has consumed the whole sequence at index 0, not at
    # index -1 (where it has only seen the final step). The forward cell has
    # consumed the whole sequence at index -1.
    h = outputs_fw[:, -1, :] + outputs_bw[:, 0, :]

    # `h` is used directly as the class logits, which is why HIDDEN_UNITS must
    # equal the number of label columns in `y_p`.
    cross_loss = tf.losses.softmax_cross_entropy(onehot_labels=y_p, logits=h)

    # Fraction of examples whose highest-scoring class matches the label.
    correct_prediction = tf.equal(tf.argmax(h, 1), tf.argmax(y_p, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss=cross_loss)

    # Must be created inside this graph (and after the optimizer) so that it
    # covers the model variables as well as Adam's slot variables.
    init = tf.global_variables_initializer()
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
with tf.Session(graph=graph) as sess:
    sess.run(init)

    # Only whole batches are used; a trailing partial batch is skipped.
    batches_per_epoch = TRAIN_EXAMPLES // BATCH_SIZE

    for epoch in range(1, EPOCH + 1):
        print("epoch:", epoch)
        epoch_losses = []
        epoch_accuracies = []

        # Walk the training arrays in order, one fixed-size window at a time.
        for start in range(0, batches_per_epoch * BATCH_SIZE, BATCH_SIZE):
            window = slice(start, start + BATCH_SIZE)
            # One optimisation step; also fetch this batch's loss and accuracy.
            _, batch_loss, batch_accuracy = sess.run(
                fetches=(optimizer, cross_loss, accuracy),
                feed_dict={
                    X_p: X_train[window],
                    y_p: y_train[window],
                },
            )
            epoch_losses.append(batch_loss)
            epoch_accuracies.append(batch_accuracy)

        # Per-epoch summary: unweighted mean over the batches of this epoch.
        print("average training loss:", sum(epoch_losses) / len(epoch_losses))
        print("accuracy:", sum(epoch_accuracies) / len(epoch_accuracies))
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· .NET10 - 预览版1新功能体验(一)