tensorflow中softmax多分类以及优化器的几个参数实例笔记

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
# 加载keras里的fashion_mnist数据
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 1us/step
40960/29515 [=========================================] - 0s 1us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 5s 0us/step
26435584/26421880 [==============================] - 5s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0s/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 1s 0us/step
4431872/4422102 [==============================] - 1s 0us/step
train_image.shape
# 60000张28*28的图片
(60000, 28, 28)
train_label.shape
(60000,)
test_image.shape,test_label.shape
((10000, 28, 28), (10000,))
plt.imshow(train_image[0])
# imshow()接收一张图像,只是画出该图,并不会立刻显示出来。
# imshow后还可以进行其他draw操作,比如scatter散点等。
# 所有画完后使用plt.show()才能进行结果的显示。如果前面加了%matplotlib inline就不需要plt.show()可以直接显示
<matplotlib.image.AxesImage at 0x23fc68281f0>

image

train_image[0]  # rgb: 0-255
array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,
          0,   0,  13,  73,   0,   0,   1,   4,   0,   0,   0,   0,   1,
          1,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,
          0,  36, 136, 127,  62,  54,   0,   0,   0,   1,   3,   4,   0,
          0,   3],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   6,
          0, 102, 204, 176, 134, 144, 123,  23,   0,   0,   0,   0,  12,
         10,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0, 155, 236, 207, 178, 107, 156, 161, 109,  64,  23,  77, 130,
         72,  15],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   0,
         69, 207, 223, 218, 216, 216, 163, 127, 121, 122, 146, 141,  88,
        172,  66],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   0,
        200, 232, 232, 233, 229, 223, 223, 215, 213, 164, 127, 123, 196,
        229,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
        183, 225, 216, 223, 228, 235, 227, 224, 222, 224, 221, 223, 245,
        173,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
        193, 228, 218, 213, 198, 180, 212, 210, 211, 213, 223, 220, 243,
        202,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   3,   0,  12,
        219, 220, 212, 218, 192, 169, 227, 208, 218, 224, 212, 226, 197,
        209,  52],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   6,   0,  99,
        244, 222, 220, 218, 203, 198, 221, 215, 213, 222, 220, 245, 119,
        167,  56],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   4,   0,   0,  55,
        236, 228, 230, 228, 240, 232, 213, 218, 223, 234, 217, 217, 209,
         92,   0],
       [  0,   0,   1,   4,   6,   7,   2,   0,   0,   0,   0,   0, 237,
        226, 217, 223, 222, 219, 222, 221, 216, 223, 229, 215, 218, 255,
         77,   0],
       [  0,   3,   0,   0,   0,   0,   0,   0,   0,  62, 145, 204, 228,
        207, 213, 221, 218, 208, 211, 218, 224, 223, 219, 215, 224, 244,
        159,   0],
       [  0,   0,   0,   0,  18,  44,  82, 107, 189, 228, 220, 222, 217,
        226, 200, 205, 211, 230, 224, 234, 176, 188, 250, 248, 233, 238,
        215,   0],
       [  0,  57, 187, 208, 224, 221, 224, 208, 204, 214, 208, 209, 200,
        159, 245, 193, 206, 223, 255, 255, 221, 234, 221, 211, 220, 232,
        246,   0],
       [  3, 202, 228, 224, 221, 211, 211, 214, 205, 205, 205, 220, 240,
         80, 150, 255, 229, 221, 188, 154, 191, 210, 204, 209, 222, 228,
        225,   0],
       [ 98, 233, 198, 210, 222, 229, 229, 234, 249, 220, 194, 215, 217,
        241,  65,  73, 106, 117, 168, 219, 221, 215, 217, 223, 223, 224,
        229,  29],
       [ 75, 204, 212, 204, 193, 205, 211, 225, 216, 185, 197, 206, 198,
        213, 240, 195, 227, 245, 239, 223, 218, 212, 209, 222, 220, 221,
        230,  67],
       [ 48, 203, 183, 194, 213, 197, 185, 190, 194, 192, 202, 214, 219,
        221, 220, 236, 225, 216, 199, 206, 186, 181, 177, 172, 181, 205,
        206, 115],
       [  0, 122, 219, 193, 179, 171, 183, 196, 204, 210, 213, 207, 211,
        210, 200, 196, 194, 191, 195, 191, 198, 192, 176, 156, 167, 177,
        210,  92],
       [  0,   0,  74, 189, 212, 191, 175, 172, 175, 181, 185, 188, 189,
        188, 193, 198, 204, 209, 210, 210, 211, 188, 188, 194, 192, 216,
        170,   0],
       [  2,   0,   0,   0,  66, 200, 222, 237, 239, 242, 246, 243, 244,
        221, 220, 193, 191, 179, 182, 182, 181, 176, 166, 168,  99,  58,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  40,  61,  44,  72,  41,  35,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)
train_label
# 第一个图像的类别是9,第二个是0,...倒数第一个是5
array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)
train_image = train_image/255  # 对训练集进行归一化
test_image = test_image/255
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))  # 参考https://blog.csdn.net/qq_46244851/article/details/109584831
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))  # 输出十个类,用softmax转化成概率值激活
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # 当label使用数字编码时就使用这个损失函数
    metrics=['acc']
)
model.fit(train_image,train_label,epochs=5)
Epoch 1/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.5045 - acc: 0.8214
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.3787 - acc: 0.8632
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.3376 - acc: 0.8781
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.3125 - acc: 0.8847
Epoch 5/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.2980 - acc: 0.8899





<keras.callbacks.History at 0x23fc62a45e0>
model.evaluate(test_image,test_label)
313/313 [==============================] - 0s 779us/step - loss: 0.3516 - acc: 0.8758





[0.35156339406967163, 0.8758000135421753]
train_label
array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)

独热编码,比如:
上海用[1,0,0]表示
北京用[0,1,0]表示
深圳用[0,0,1]表示
独热编码即 One-Hot 编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。

train_label_onehot = tf.keras.utils.to_categorical(train_label)# 将train_label转换成独热编码
train_label_onehot 
array([[0., 0., 0., ..., 0., 0., 1.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
test_label_onehot = tf.keras.utils.to_categorical(test_label)
test_label_onehot
array([[0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))  # 参考https://blog.csdn.net/qq_46244851/article/details/109584831
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))  # 输出十个类,用softmax转化成概率值激活
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',  # 独热编码时,就用categorical_crossentropy
    metrics=['acc']
)
model.fit(train_image,train_label_onehot,epochs=5)
Epoch 1/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.4965 - acc: 0.8247
Epoch 2/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3729 - acc: 0.8636
Epoch 3/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3368 - acc: 0.8774
Epoch 4/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3124 - acc: 0.8853
Epoch 5/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.2946 - acc: 0.8922





<keras.callbacks.History at 0x23fca64b850>
predict=model.predict(test_image)
predict.shape
(10000, 10)
predict[0]
array([1.4467735e-05, 6.7211374e-08, 5.7228771e-08, 1.8838726e-06,
       1.9777593e-07, 1.3431358e-02, 5.4569400e-06, 6.4479321e-02,
       6.1685896e-05, 9.2200553e-01], dtype=float32)
np.argmax(predict[0])  # 第一个预测的参数
9
test_label[0]
9
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))  # 参考https://blog.csdn.net/qq_46244851/article/details/109584831
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))  # 输出十个类,用softmax转化成概率值激活
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),  # 定义学习率为0.01,摁下shift+tab可以查看函数的用法,摁下tab+tab可以提示
    loss='categorical_crossentropy',  # 独热编码时,就用categorical_crossentropy
    metrics=['acc']
)

model.fit(train_image,train_label_onehot,epochs=5)
Epoch 1/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3720 - acc: 0.8651
Epoch 2/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3728 - acc: 0.8668
Epoch 3/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3642 - acc: 0.8694
Epoch 4/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3641 - acc: 0.8682
Epoch 5/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3713 - acc: 0.8676





<keras.callbacks.History at 0x23fcba05fa0>
posted @ 2021-10-31 19:15  闲伯  阅读(154)  评论(0编辑  收藏  举报