8-minist数据测试参数精度
一、数据读取
mnist.py有一个load_mnist()
函数,调用这个函数按下述方式可以轻松读入MNIST数据。
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
这个函数原型为:
load_mnist(normalize=True, flatten=True, one_hot_label=False)
含有3 个 参 数。
- 第 1 个参数
normalize
设置是否将输入图像正规化为0.0~1.0的值,默认值为True,如果将该参数设置为False,则输入图像的像素会保持原来的0~255。 - 第2个参数
flatten
设置是否展开输入图像(变成一维数组),默认值为True,如果将该参数设置为False,则输入图像为1×28×28的三维数组;若设置为True,则输入图像会保存为由784个 元素构成的一维数组。 - 第3个参数
one_hot_label
设置是否将标签保存为onehot表示(one-hot representation)。 one-hot表示是仅正确解标签为1,其余 皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0] 这样。当one_hot_label为False时, 只是像7、2这样简单保存正确解标签;当one_hot_label为True时,标签则保存为one-hot表示。
二、神经网络推理处理
下面,我们对这个MNIST数据集实现神经网络的推理处理。
神经网络 的输入层有784个神经元,输出层有10个神经元。输入层的784这个数字来 源于图像大小的 28 × 28 = 784 28×28 = 784 28×28=784,输出层的10这个数字来源于10类别分类(数字0到9,共10类别)。
此外,这个神经网络有2个隐藏层,第1个隐藏层有 50个神经元,第2个隐藏层有100个神经元,这个50和100可以设置为任何值。 下面我们先定义get_data()
、init_network()
、predict()
这3个函数。
def sigmoid(x):
return 1/(1 + np.exp(-x))
def softmax(x):
C = np.max(x)
return np.exp(x-C)/np.sum(np.exp(x-C))
def get_data():
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
return x_test, t_test
def init_network():
with open(r"D:\dataset\MNIST_data\sample_weight.pkl", 'rb') as f: #加载权重参数
network = pickle.load(f)
return network
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
init_network()
会读入保存在pkl文件sample_weight.pkl
中的学习到的权重参数。这个文件中以字典的形式保存了权重和偏置参数。
现在,我们用这3 个函数来实现神经网络的推理处理。然后,评价它的识别精度(accuracy), 即能在多大程度上正确分类。
x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network, x[i])
p = np.argmax(y) # 获取概率最高的元素的索引
if p == t[i]:
accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
首先获得MNIST数据集,生成网络。接着,用for语句逐一取出保存 在x中的图像数据,用predict()函数进行分类。
predict()函数以NumPy数组的形式输出各个标签对应的概率。比如输出[0.1, 0.3, 0.2, …, 0.04]的 数组,该数组表示“0”的概率为0.1,“ 1”的概率为0.3,等等。然后,我们 取出这个概率列表中的最大值的索引(第几个元素的概率最高),作为预测结 果。可以用np.argmax(x)
函数取出数组中的最大值的索引,np.argmax(x)将 获取被赋给参数x的数组中的最大值元素的索引。
最后,比较神经网络所预测的答案和正确解标签,将回答正确的概率作为识别精度。
执行上面的代码后,会显示“Accuracy:0.9311”。这表示有93.11%的数据被正确分类了。
在这个例子中,我们把load_mnist
函数的参数normalize
设置成了 True。将normalize设置成True后,函数内部会进行转换,将图像的各个像 素值除以255,使得数据的值在0.0~1.0的范围内。像这样把数据限定到某 个范围内的处理称为正规化(normalization)。
此外,对神经网络的输入数据进行某种既定的转换称为预处理(pre-processing)。这里,作为对输入图像的 一种预处理,我们进行了正规化。
三、批处理
从 上面的这句代码 for i in range(len(x))
可以看出,每次只是传入了一张图片的数据,即每次传入的数据 x 是一维数组的
现在我们来考虑打包输入多张图像的情形。比如,我们想用predict() 函数一次性打包处理100张图像。为此,可以把 x 的形状改为100×784,将 100张图像打包作为输入数据。用图表示的话,如下图所示:
在这个处理中,只是输出变为了二维数组 100×10,这表示输入的100张图像的结果被一次性输出了,比如,x[0]和y[0]中保存了第0张图像及其推理结果,x[1]和y[1]中保存了第1张图像及其推理结果,等等。
这种打包式的输入数据称为批(batch)。批有“捆”的意思,图像就如同 纸币一样扎成一捆。
下面我们进行基于批处理的代码实现:
x, t = get_data()
network = init_network()
batch_size = 100 # 批数量
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
x_batch = x[i:i+batch_size] #每次取100个数据
y_batch = predict(network, x_batch)
p = np.argmax(y_batch, axis=1)
accuracy_cnt += np.sum(p == t[i:i+batch_size])
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
这里通过argmax()获取值最大的元素的索引,比之前多了参数 axis=1
,这指定了在100×10的数组中,沿着第一维方向(这里第一维指行)找到值最大的元素(即每行的最大元素)的索引 。这里也来看一个例子:
>>> x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6],
... [0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])
>>> x
array([[0.1, 0.8, 0.1],
[0.3, 0.1, 0.6],
[0.2, 0.5, 0.3],
[0.8, 0.1, 0.1]])
>>> y = np.argmax(x, axis=1)
>>> print(y)
[1 2 1 0]
注意:矩阵的第0维是列方向,第1维是行方向。
最后,我们比较一下以批为单位进行分类的结果和实际的答案。为此, 需要在NumPy数组之间使用比较运算符==
生成由True/False构成的布尔型数组,并计算True的个数。我们通过下面的例子进行确认。
>>> y = np.array([1, 2, 1, 0])
>>> t = np.array([1, 2, 0, 0])
>>> print(y==t)
[True True False True]
>>> np.sum(y==t)
3
至此,基于批处理的代码实现就介绍完了。使用批处理,可以实现高速且高效的运算。
本文来自博客园,作者:aJream,转载请记得标明出处:https://www.cnblogs.com/ajream/p/15383597.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人