代码改变世界

机器学习 对数据集进行批处理

2022-04-05 14:03  jym蒟蒻  阅读(136)  评论(0编辑  收藏  举报

 

只输入一张图像数据过程和一次性处理100张图像数据过程中,数组形状变换如下图所示:

在这里插入图片描述

这些数组形状可以在代码中输出出来:

def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


x, t = get_data()
network = init_network()
print(x.shape)
print(x[0].shape)
W1, W2, W3 = network['W1'], network['W2'], network['W3']
print(W1.shape)
print(W2.shape)
print(W3.shape)

输出结果:

(10000, 784)
(784,)
(784, 50)
(50, 100)
(100, 10)

基于批处理的代码实现:

batch_size=100

for i in range(0, len(x), batch_size):这句话的意义,使i从0开始每次增加100 。

x_batch = x[i:i+batch_size]可以取出第i个到第i+100个之间的数据。

这样数据就变成了x[0:100]、x[100:200]、…这样的批数据。

p = np.argmax(y_batch, axis=1),这句话获取y_batch取最大值时的y_batch数组的下标。 axis=1表示从行方向找最大值。也就是说,输入一个图片,输出一个y,找0-9下标里面y最大的那个下标,就是神经网络根据这个图片猜出来的数字。

import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()
print(x.shape)
print(x[0].shape)
W1, W2, W3 = network['W1'], network['W2'], network['W3']
print(W1.shape)
print(W2.shape)
print(W3.shape)

batch_size = 100 # 批数量
accuracy_cnt = 0

for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))