Keras 识别手写数字

Keras 识别手写数字

from keras.utils import np_utils
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation
import numpy as np
# Load MNIST. The first pair is the training split and the second pair is the
# test split (named y_* in this script rather than the usual x_test/y_test).
(x_data, x_label), (y_data, y_label) = mnist.load_data()

# Hold out everything from index 50000 onward as a validation set and keep the
# first 50000 for training. Each right-hand side is fully evaluated before any
# name is rebound, so both slices come from the unsplit arrays.
x_val, x_data = x_data[50000:], x_data[:50000]
x_val_label, x_label = x_label[50000:], x_label[:50000]

# Inspect the array shapes (the tuple is only displayed when run as a notebook cell).
x_data.shape, x_label.shape, y_data.shape, y_label.shape, x_val.shape, x_val_label.shape

"""
 ((50000, 28, 28),(50000,),(10000, 28, 28),(10000,),(10000, 28, 28),(10000,))
"""

# Preprocess: flatten each 28x28 image into a 784-element row vector and
# scale the uint8 pixel values into [0, 1].
x_data = x_data.reshape(len(x_data), 28 * 28).astype('float32') / 255.0
y_data = y_data.reshape(len(y_data), 28 * 28).astype('float32') / 255.0
x_val = x_val.reshape(len(x_val), 28 * 28).astype('float32') / 255.0

# Draw a small random subset: 700 training and 300 validation indices.
# replace=False guarantees each sample is picked at most once.
train_rand = np.random.choice(len(x_data), 700, replace=False)
val_rand = np.random.choice(len(x_val), 300, replace=False)

# Rebuild the training and validation sets from the sampled indices. Images
# and labels are indexed with the same index array so they stay paired.
x_data = x_data[train_rand]
x_label = x_label[train_rand]


x_val = x_val[val_rand]
x_val_label = x_val_label[val_rand]

# Inspect the subset shapes (only displayed when run as a notebook cell).
x_data.shape, x_label.shape, x_val.shape, x_val_label.shape
"""
((700, 784), (700,), (300, 784), (300,))
"""

# One-hot encode the integer digit labels so they match the 10-way softmax
# output and the categorical_crossentropy loss used at compile time.
x_label, x_val_label, y_label = (
    np_utils.to_categorical(x_label),
    np_utils.to_categorical(x_val_label),
    np_utils.to_categorical(y_label),
)

# Build the network: a 784 -> 2 ReLU hidden layer, then a 10-unit softmax.
# NOTE(review): 2 hidden units is a very narrow bottleneck for 10 classes;
# consider widening it if accuracy matters.
model = Sequential([
    Dense(2, input_dim=28*28, activation='relu'),
    Dense(10, activation='softmax'),
])
model.summary()
"""
    Model: "sequential_1"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_1 (Dense)              (None, 2)                 1570      
    _________________________________________________________________
    dense_2 (Dense)              (None, 10)                30        
    =================================================================
    Total params: 1,600
    Trainable params: 1,600
    Non-trainable params: 0
    _________________________________________________________________
"""

# Compile: Adam optimizer, categorical cross-entropy on the one-hot labels,
# and accuracy reported after every epoch.
model.compile(optimizer='Adam',
              loss='categorical_crossentropy',
              metrics=['accuracy']
             )  
# Train for 1000 epochs in mini-batches of 10 samples.
# NOTE(review): validation runs on the 10,000-image test split (y_data, y_label),
# not on the x_val/x_val_label subset prepared earlier. The recorded log below
# ("validate on 10000 samples") reflects that choice. Consider
# validation_data=(x_val, x_val_label) and reserving the test split for a final
# evaluation only.
his = model.fit(x_data,
                x_label,
                epochs=1000,
                batch_size=10,
                validation_data=(y_data, y_label))
"""
    Train on 700 samples, validate on 10000 samples
    Epoch 1/1000


    2022-06-15 12:38:00.042906: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10


    700/700 [==============================] - 2s 3ms/step - loss: 2.2821 - accuracy: 0.1714 - val_loss: 2.2173 - val_accuracy: 0.2314
    Epoch 2/1000
    700/700 [==============================] - 2s 3ms/step - loss: 2.1893 - accuracy: 0.2114 - val_loss: 2.1363 - val_accuracy: 0.2008
    ..........................................................................................
    Epoch 3/1000
    Epoch 997/1000
    700/700 [==============================] - 2s 3ms/step - loss: 0.3267 - accuracy: 0.9114 - val_loss: 9.6746 - val_accuracy: 0.4252
    Epoch 998/1000
    700/700 [==============================] - 2s 3ms/step - loss: 0.3299 - accuracy: 0.9114 - val_loss: 9.9539 - val_accuracy: 0.4229
    Epoch 999/1000
    700/700 [==============================] - 2s 3ms/step - loss: 0.3312 - accuracy: 0.9071 - val_loss: 9.7498 - val_accuracy: 0.4248
    Epoch 1000/1000
    700/700 [==============================] - 2s 3ms/step - loss: 0.3347 - accuracy: 0.9014 - val_loss: 9.4837 - val_accuracy: 0.4229
"""

%matplotlib inline
import matplotlib.pyplot as plt
fig, loss_ax = plt.subplots()
acc_ax = loss_ax.twinx()


loss_ax.plot(his.history['loss'], 'y', label='train loss')
loss_ax.plot(his.history['val_loss'], 'r', label='val loss')

loss_ax.plot(his.history['accuracy'], 'b', label='train acc')
loss_ax.plot(his.history['val_accuracy'], 'g', label='val acc')



loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
acc_ax.set_xlabel('accuracy')

loss_ax.legend(loc='upper left')
acc_ax.legend(loc='lower left')
plt.show()
posted @ 2022-06-16 16:45  MKY-门可意  阅读(93)  评论(0)  编辑  收藏  举报