代码改变世界

神经网络入门 第6章 识别手写字体

2017-04-22 18:28  JavaDaddy  阅读(457)  评论(0编辑  收藏  举报

 前言

    神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。

 

    关注微信号“逻辑编程"来获取本书的更多信息。

 

 

 

    这一章节我们将会解决一个真正的问题:手写字体识别。我们将识别像下面图中这样的手写数字。

 

 

 

    在开始之前,我们先要准备好相应的测试数据。我们不能像前边那样简单的产生手写字体,毕竟我们自己还不知道如何写出一个产生手写字体的算法。训练要达到一定的精度需要较多的训练数据。还好,前人栽树后人乘凉,先驱们已经收集了宝贵的训练材料。MNIST就是一个广泛使用的数据集。不但可以拿来用,我们还可以从网站上看到别人的识别准确率。这样我们就有了很好的参照。MNIST包含一套训练数据和一套测试数据,分别来自不同的人群的手写。

 

    MNIST网站: http://yann.lecun.com/exdb/mnist/

 

这个数据集是写在特定的二进制文件中的,并非普通图片格式。每个图片数据由28*28个像素组成。每个像素1个字节表示颜色灰度级。MNIST网站上有具体的介绍。

 

    我们写一个类来完成数据集的读取工作,并提供接口返回指定的训练或者测试数据。具体代码不做分析,仅将代码附在下面,供读者使用。代码执行前要先下载数据文件并保留GZIP格式。代码执行后将随机抽取20个生成PNG图片供读者自己查看和验证数据内容。

 

    下面我们写个测试类来识别手写字体。我们使用MNIST库的60000训练数据来反复训练我们的神经网络。每轮训练后使用MNIST库的10000个测试数据来测试识别率。

 

    下面是代码:

 


package com.luoxq.ann;

import java.util.Arrays;
import java.util.Random;

public class MnistTest {


public static void main(String... args) {
int[] shape = {28 * 28, 10};
NeuralNetwork nn = new NeuralNetwork(shape);
Mnist mnist = new Mnist();
mnist.load();
mnist.shuffle();
System.out.println("Shape: " + Arrays.toString(shape));
System.out.println("Initial correct rate: " + test(nn, mnist));
int epochs = 1000;
double rate = 0.5;
System.out.println("Learning rate: " + rate);
System.out.println("Epoch,Time,Correctness\n----------------------");
long time = System.currentTimeMillis();
Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);
for (int epoch = 1; epoch <= epochs; epoch++) {
for (int sample = 0; sample < data.length; sample++) {
nn.train(data[sample].input, data[sample].output, rate);
}
long seconds = (System.currentTimeMillis() - time) / 1000;
System.out.println(epoch + ", " + seconds + ", " +
test(nn, mnist));
}
}

private static int test(NeuralNetwork nn, Mnist mnist) {
int correct = 0;
Mnist.Data[] data = mnist.getTestSlice(0, 10000);
for (int sample = 0; sample < data.length; sample++) {
if (max(nn.f(data[sample].input)) == data[sample].label) {
correct++;
}
}
return correct;
}

private static int max(double[] d) {
double max = d[0];
int idx = 0;
for (int i = 1; i < d.length; i++) {
if (max < d[i]) {
max = d[i];
idx = i;
}
}
return idx;
}
}

 

    我们先用一个10个神经元的单层神经网络试试看。结果出乎意外的好。我们很快就获得了超过90%的正确率。单层网络几乎就是对每个数字的像素分布做简单统计。能获得如此高的识别率,还是很神奇的。 在达到90%之后再训练已经效果不大,达到饱和了。我们必须换一种方法来做了。 

 

Shape: [784, 10]

Initial correct rate: 1373

Learning rate: 0.5

Epoch,Time,Correctness

----------------------

1, 4, 6429

2, 8, 7663

3, 13, 8963

4, 17, 9029

5, 22, 9016

6, 27, 9062

7, 31, 9063

8, 36, 9066

9, 41, 9072

10, 45, 9057

11, 50, 9084

12, 55, 9072

13, 61, 9062

14, 66, 9050

15, 70, 9077

16, 75, 9052

17, 79, 9068

18, 84, 9055

19, 88, 9060

20, 93, 9064

 

 

    那么我们来使用三层神经网络试一试。在试了几个不同的中间层大小和学习率参数之后,我找到了下面这个较好的参数组合:

 

Shape: [784, 50, 10]

Initial correct rate: 944

Learning rate: 1.0

Epoch,Time,Correctness

----------------------

1, 24, 7459

2, 59, 9232

3, 99, 9313

4, 131, 9379

5, 153, 9412

6, 176, 9443

7, 200, 9412

8, 226, 9447

9, 248, 9462

10, 269, 9461

11, 290, 9465

12, 314, 9493

13, 343, 9477

14, 368, 9499

15, 392, 9502

16, 420, 9509

17, 447, 9482

18, 472, 9508

19, 496, 9491

20, 518, 9536

21, 545, 9523

22, 569, 9549

23, 593, 9527

24, 618, 9527

25, 643, 9520

26, 667, 9513

27, 689, 9507

28, 712, 9527

29, 734, 9501

30, 758, 9521

31, 781, 9508

32, 804, 9534

33, 827, 9534

34, 850, 9550

35, 875, 9569

 

 

 

    我们很快达到了95%以上的正确率。可见多层网络相对单层神经网络还是有优势的。虽然这个正确率还达不到产品水平,但是作为初次尝试结果还是很不错的。

 

    下面是MNIST文件读取源代码:

 

package com.luoxq.ann;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;

/**
* Created by luoxq on 17/4/15.
*/
public class Mnist {


static class Data {
public byte[] data;
public int label;
public double[] input;
public double[] output;
}

public static void main(String... args) throws Exception {
Mnist mnist = new Mnist();
mnist.load();
System.out.println("Data loaded.");
Random rand = new Random(System.nanoTime());
for (int i = 0; i < 20; i++) {
int idx = rand.nextInt(60000);
Data d = mnist.getTrainingData(idx);
BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
for (int x = 0; x < 28; x++) {
for (int y = 0; y < 28; y++) {
img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
}
}
File output = new File(i + "_" + d.label + ".png");
if (!output.exists()) {
output.createNewFile();
}
ImageIO.write(img, "png", output);
}
}

static int toRgb(byte bb) {
int b = (255 - (0xff & bb));
return (b << 16 | b << 8 | b) & 0xffffff;
}


Data[] trainingSet;
Data[] testSet;

public void shuffle() {
Random rand = new Random();
for (int i = 0; i < trainingSet.length; i++) {
int x = rand.nextInt(trainingSet.length);
Data d = trainingSet[i];
trainingSet[i] = trainingSet[x];
trainingSet[x] = trainingSet[i];
}
}

public Data getTrainingData(int idx) {
return trainingSet[idx];
}

public Data[] getTrainingSlice(int start, int count) {
Data[] ret = new Data[count];
System.arraycopy(trainingSet, start, ret, 0, count);
return ret;
}

public Data getTestData(int idx) {
return testSet[idx];
}

public Data[] getTestSlice(int start, int count) {
Data[] ret = new Data[count];
System.arraycopy(testSet, start, ret, 0, count);
return ret;
}


public void load() {
trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
if (trainingSet.length != 60000 || testSet.length != 10000) {
throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
}
}

private Data[] load(String imgFile, String labelFile) {
byte[][] images = loadImages(imgFile);
byte[] labels = loadLabels(labelFile);
if (images.length != labels.length) {
throw new RuntimeException("Images and label doesn't match: " + imgFile + " " + labelFile);
}
int len = images.length;
Data[] data = new Data[len];
for (int i = 0; i < len; i++) {
data[i] = new Data();
data[i].data = images[i];
data[i].label = 0xff & labels[i];
data[i].input = dataToInput(images[i]);
data[i].output = labelToOutput(labels[i]);
}
return data;
}

private double[] labelToOutput(byte label) {
double[] o = new double[10];
o[label] = 1;
return o;
}

private double[] dataToInput(byte[] b) {
double[] d = new double[b.length];
for (int i = 0; i < b.length; i++) {
d[i] = (b[i] & 0xff) / 255.0;
}
return d;
}

private byte[][] loadImages(String imgFile) {
try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)));) {
int magic = in.readInt();
if (magic != 0x00000803) {
throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
}
int count = in.readInt();
int rows = in.readInt();
int cols = in.readInt();
if (rows != 28 || cols != 28) {
throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
}
byte[][] data = new byte[count][rows * cols];
for (int i = 0; i < count; i++) {
in.readFully(data[i]);
}
return data;
} catch (Exception ex) {
throw new RuntimeException("Failed to read file: " + imgFile, ex);
}
}

private byte[] loadLabels(String labelFile) {
try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)));) {
int magic = in.readInt();
if (magic != 0x00000801) {
throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
}
int count = in.readInt();
byte[] data = new byte[count];
in.readFully(data);
return data;
} catch (Exception ex) {
throw new RuntimeException("Failed to read file: " + labelFile, ex);
}
}

}

 

 

欢迎关注订阅号逻辑编程阅读更多内容。

Logical Programming