神经网络学习笔记2

 

预处理:将各个像素值除以255,进行了简单的正则化。

批处理:可以减轻数据总线的负荷,相对于数据读入,可以将更多的时间用在计算上。批处理一次性计算大型数组比分开逐步计算各个小型数组要快得多。

考虑打包输入多张图像情形,使用predict()一次性打包处理100张图像
x形状:100*784

for i in range(0,len(x),batch_size):
    x_batch=x[i:i+batch_size] # 通过x[i:i+batch_size]从输入数据中抽出批数据
    y_batch=predict(network,x_batch)
    p=np.argmax(y_batch,axis=1) #  通过argmax()获取值最大的元素索引
    accuracy_cnt+=np.sum(p==t[i:i+batch_size])

range():指定为range(start,end),生成一个由start到end-1之间的整数构成的列表

使用for语句逐一取出保存在x中的图像数据通过predict()函数进行分类。predict函数以numPy数组的形式输出各个标签对应的概率,取出概率列表中最大的值的索引作为预测结果。可以使用np.argmax(x)函数取出数组中最大值的索引。np.argmax(x)将获取被赋给参数x的数组中最大值元素的索引。最后,比较神经网络所预测的答案和正确解的标签,将回答正确的概率作为识别精度。

通过x[i:i+batch_size]从输入数据中抽出批数据:会取出从i到第i+batch_n个之间的数据,取出样例如x[0:100],x[100:200]...从头开始以100为单位将数据提取为批数据
然后通过argmax()函数获取值最大的元素的索引
axis=1:指定在100*10数组中,沿着第一位方向找到值最大的元素索引

使用批处理可以实现高速高效的运算

完整实现代码如下:

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax

def get_data():
    (x_train,t_train),(x_test,t_test) = load_mnist(normalize=True,flatten=True,one_hot_label=False)
    return x_test,t_test

def init_network():
    with open("sample_weight.pkl",'rb') as f:
        network =pickle.load(f)
    return network

def predict(network,x):
    W1,W2,W3=network["W1"],network["W2"],network["W3"]
    b1,b2,b3=network["b1"],network["b2"],network["b3"]
    a1=np.dot(x,W1)+b1
    z1=sigmoid(a1)
    a2=np.dot(z1,W2)+b2
    z2=sigmoid(a2)
    a3=np.dot(z2,W3)+b3
    y=softmax(a3)
    return y

x,t=get_data()
network=init_network()
batch_size=100
accuracy_cnt=0

for i in range(0,len(x),batch_size):
    x_batch=x[i:i+batch_size] # 通过x[i:i+batch_size]从输入数据中抽出批数据
    y_batch=predict(network,x_batch)
    p=np.argmax(y_batch,axis=1) #  通过argmax()获取值最大的元素索引
    accuracy_cnt+=np.sum(p==t[i:i+batch_size])

print("Accuracy:"+str(float(accuracy_cnt)/len(x)))

mini-batch 使用mini-batch进行学习


梯度:梯度指示的方向是各个点函数值减少最多的方向

posted @ 2020-08-10 01:53  -DP-  阅读(224)  评论(0编辑  收藏  举报