8-minist数据测试参数精度

一、数据读取

mnist.py有一个load_mnist()函数,调用这个函数按下述方式可以轻松读入MNIST数据。

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

这个函数原型为:

load_mnist(normalize=True, flatten=True, one_hot_label=False)

含有3 个 参 数。

  • 第 1 个参数 normalize设置是否将输入图像正规化为0.0~1.0的值,默认值为True,如果将该参数设置为False,则输入图像的像素会保持原来的0~255。
  • 第2个参数flatten设置是否展开输入图像(变成一维数组),默认值为True,如果将该参数设置为False,则输入图像为1×28×28的三维数组;若设置为True,则输入图像会保存为由784个 元素构成的一维数组。
  • 第3个参数one_hot_label设置是否将标签保存为onehot表示(one-hot representation)。 one-hot表示是仅正确解标签为1,其余 皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0] 这样。当one_hot_label为False时, 只是像7、2这样简单保存正确解标签;当one_hot_label为True时,标签则保存为one-hot表示。

二、神经网络推理处理

下面,我们对这个MNIST数据集实现神经网络的推理处理。

神经网络 的输入层有784个神经元,输出层有10个神经元。输入层的784这个数字来 源于图像大小的 28 × 28 = 784 28×28 = 784 28×28=784,输出层的10这个数字来源于10类别分类(数字0到9,共10类别)。

此外,这个神经网络有2个隐藏层,第1个隐藏层有 50个神经元,第2个隐藏层有100个神经元,这个50和100可以设置为任何值。 下面我们先定义get_data()init_network()predict()这3个函数。

def sigmoid(x):
    return 1/(1 + np.exp(-x))

def softmax(x):
    C = np.max(x)
    return np.exp(x-C)/np.sum(np.exp(x-C))

def get_data():   
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

def init_network():  
    with open(r"D:\dataset\MNIST_data\sample_weight.pkl", 'rb') as f:     #加载权重参数
        network = pickle.load(f)
    return network

def predict(network, x):   
    W1, W2, W3 = network['W1'], network['W2'], network['W3']  
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    
    a1 = np.dot(x, W1) + b1    
    z1 = sigmoid(a1)  
    
    a2 = np.dot(z1, W2) + b2    
    z2 = sigmoid(a2)    
    
    a3 = np.dot(z2, W3) + b3  
    y = softmax(a3)
    return y

init_network()会读入保存在pkl文件sample_weight.pkl中的学习到的权重参数。这个文件中以字典的形式保存了权重和偏置参数。

现在,我们用这3 个函数来实现神经网络的推理处理。然后,评价它的识别精度(accuracy), 即能在多大程度上正确分类。

x, t = get_data() 
network = init_network()
accuracy_cnt = 0 
for i in range(len(x)):  
    y = predict(network, x[i])  
    p = np.argmax(y) # 获取概率最高的元素的索引  
    if p == t[i]:      
        accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

首先获得MNIST数据集,生成网络。接着,用for语句逐一取出保存 在x中的图像数据,用predict()函数进行分类。

predict()函数以NumPy数组的形式输出各个标签对应的概率。比如输出[0.1, 0.3, 0.2, …, 0.04]的 数组,该数组表示“0”的概率为0.1,“ 1”的概率为0.3,等等。然后,我们 取出这个概率列表中的最大值的索引(第几个元素的概率最高),作为预测结 果。可以用np.argmax(x)函数取出数组中的最大值的索引,np.argmax(x)将 获取被赋给参数x的数组中的最大值元素的索引。

最后,比较神经网络所预测的答案和正确解标签,将回答正确的概率作为识别精度。

执行上面的代码后,会显示“Accuracy:0.9311”。这表示有93.11%的数据被正确分类了。

在这个例子中,我们把load_mnist函数的参数normalize设置成了 True。将normalize设置成True后,函数内部会进行转换,将图像的各个像 素值除以255,使得数据的值在0.0~1.0的范围内。像这样把数据限定到某 个范围内的处理称为正规化(normalization)。

此外,对神经网络的输入数据进行某种既定的转换称为预处理(pre-processing)。这里,作为对输入图像的 一种预处理,我们进行了正规化。

三、批处理

从 上面的这句代码 for i in range(len(x)) 可以看出,每次只是传入了一张图片的数据,即每次传入的数据 x 是一维数组的

在这里插入图片描述

现在我们来考虑打包输入多张图像的情形。比如,我们想用predict() 函数一次性打包处理100张图像。为此,可以把 x 的形状改为100×784,将 100张图像打包作为输入数据。用图表示的话,如下图所示:

在这里插入图片描述

在这个处理中,只是输出变为了二维数组 100×10,这表示输入的100张图像的结果被一次性输出了,比如,x[0]和y[0]中保存了第0张图像及其推理结果,x[1]和y[1]中保存了第1张图像及其推理结果,等等。

这种打包式的输入数据称为批(batch)。批有“捆”的意思,图像就如同 纸币一样扎成一捆。

下面我们进行基于批处理的代码实现:

x, t = get_data() 
network = init_network()
batch_size = 100 # 批数量 
accuracy_cnt = 0

for i in range(0, len(x), batch_size):  
    x_batch = x[i:i+batch_size]  #每次取100个数据
    y_batch = predict(network, x_batch)    
    p = np.argmax(y_batch, axis=1)   
    accuracy_cnt += np.sum(p == t[i:i+batch_size])
    
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

这里通过argmax()获取值最大的元素的索引,比之前多了参数 axis=1,这指定了在100×10的数组中,沿着第一维方向(这里第一维指行)找到值最大的元素(即每行的最大元素)的索引 。这里也来看一个例子:

>>> x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6],
...    		[0.2, 0.5, 0.3], [0.8, 0.1, 0.1]]) 
>>> x
array([[0.1, 0.8, 0.1], 
       [0.3, 0.1, 0.6],
       [0.2, 0.5, 0.3], 
       [0.8, 0.1, 0.1]])

>>> y = np.argmax(x, axis=1) 
>>> print(y) 
[1 2 1 0]

注意:矩阵的第0维是列方向,第1维是行方向。

最后,我们比较一下以批为单位进行分类的结果和实际的答案。为此, 需要在NumPy数组之间使用比较运算符==生成由True/False构成的布尔型数组,并计算True的个数。我们通过下面的例子进行确认。

>>> y = np.array([1, 2, 1, 0]) 
>>> t = np.array([1, 2, 0, 0]) 
>>> print(y==t) 
[True True False True] 

>>> np.sum(y==t)
3

至此,基于批处理的代码实现就介绍完了。使用批处理,可以实现高速且高效的运算。

posted @   aJream  阅读(183)  评论(0编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示