导航

Go读取MNIST数据

Posted on 2022-04-23 08:06  蝈蝈俊  阅读(178)  评论(0编辑  收藏  举报

MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集.

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

  • train-images-idx3-ubyte.gz: training set images (9912422 bytes),训练图像数据
  • train-labels-idx1-ubyte.gz: training set labels (28881 bytes),训练图像标签
  • t10k-images-idx3-ubyte.gz: test set images (1648877 bytes),测试图像数据
  • t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes),测试图像标签

以train-images-idx3-ubyte.gz为例,这是个压缩文件,需要解压为train-images.idx3-ubyte,用winhex打开此文件,如下图所示,00000803是固定,0000EA60表示有6万张图片,0000001C表示图片的宽度为28,接着0000001C表示图片的高度为28。

读取的代码可以看
https://gist.github.com/higuma/dbcd006546eb844c01e5102b4d0bcc93

使用上面代码读取第一个数字的例子


package main

import (
	"ghj1976/aigo/nn/mnist"
	"image"
	"image/color"
	"image/png"
	"log"
	"os"
)

func main() {

	dataSet, err := mnist.ReadTrainSet("../mnist")
	if err != nil {
		log.Fatal(err)
	}

	imCols := 28
	imRows := 28

	rect := image.Rect(0, 0, imCols, imRows)

	rgba := image.NewNRGBA(rect)

	log.Println(dataSet.Data[0].Digit)
	for dy := 0; dy < imCols; dy++ {
		for dx := 0; dx < imRows; dx++ {
			rgba.Set(dy, dx, color.Gray{dataSet.Data[0].Image[dx][dy]})
		}
	}

	fIm, err := os.Create("a0.png")

	if nil != err {
		log.Fatal(err)
	}

	err = png.Encode(fIm, rgba)

	if nil != err {
		log.Fatal(err)
	}

}

读取出来的是数字5