python读取,显示,保存mnist图片

python处理二进制

python的struct模块可以将整型(或者其它类型)转化为byte数组.看下面的代码.

# coding: utf-8
from struct import *
# 包装成大端的byte数组
print(pack('>hhl', 1, 2, 3))  # b'\x00\x01\x00\x02\x00\x00\x00\x03'

pack('>hhl', 1, 2, 3)作用是以大端的方式把1(h表示2字节整型),2,3(l表示4字节整型),转化为对于的byte数组.大端小端的区别看参数资料2,>hhl的含义见参考资料1.输出为长度为8的byte数组,2个h的长度为4,1个l的长度为4,加起来一共是8.
再体会下面代码的作用.

# coding: utf-8
from struct import *

# 包装成大端的byte数组
print(pack('>hhl', 1, 2, 3))  # b'\x00\x01\x00\x02\x00\x00\x00\x03'
# 以大端的方式还原成整型
print(unpack('>hhl', b'\x00\x01\x00\x02\x00\x00\x00\x03'))  # (1, 2, 3)


# 包装成小端的byte数组
print(pack('<hhl', 1, 2, 3))  # b'\x01\x00\x02\x00\x03\x00\x00\x00'
# 以小端的方式还原成整型
print(unpack('<hhl', b'\x00\x01\x00\x02\x00\x00\x00\x03'))  # (256, 512, 50331648)

mnist显示

以t10k-images.idx3-ubyte为例,t10k-images.idx3-ubyte是二进制文件.其格式为

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  10000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

前4个整型代表文件头的一些信息.之后的无符号byte数组才是图片的内容.所以要先越过前4个整型,然后再开始读取,代码如下

import numpy as np
import struct
import matplotlib.pyplot as plt

filename = r'D:\source\technology_source\data\t10k-images.idx3-ubyte'
binfile = open(filename, 'rb')
buf = binfile.read()

index = 0
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)  # 读取前4个字节的内容
index += struct.calcsize('>IIII')
im = struct.unpack_from('>784B', buf, index)  # 以大端方式读取一张图上28*28=784
index += struct.calcsize('>784B')
binfile.close()

im = np.array(im)
im = im.reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
# plt.savefig("test.png")  # 保存成文件
plt.close()

可以看到结果:

下面是读取多个图片并存盘的代码.

import numpy as np
import struct
import matplotlib.pyplot as plt

filename = r'D:\source\technology_source\data\t10k-images.idx3-ubyte'
binfile = open(filename, 'rb')
buf = binfile.read()

index = 0
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
index += struct.calcsize('>IIII')

for i in range(30):  # 读取前30张图片
    im = struct.unpack_from('>784B', buf, index)
    index += struct.calcsize('>784B')
    im = np.array(im)
    im = im.reshape(28, 28)
    fig = plt.figure()
    plotwindow = fig.add_subplot(111)
    plt.axis('off')
    plt.imshow(im, cmap='gray')
    plt.savefig("test" + str(i) + ".png")
    plt.close()
binfile.close()

另外一种方法

参考tensorflow中mnist模块的方法读取,代码如下

import gzip
import numpy
import matplotlib.pyplot as plt
filepath = r"D:\train-images-idx3-ubyte.gz"
def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def imagine_arr(filepath, index):
    with open(filepath, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
            num = _read32(bytestream)  # 几张图片
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            if index >= num:
                index = 0
            bytestream.read(rows * cols * index)
            buf = bytestream.read(rows * cols)
            data = numpy.frombuffer(buf, dtype=numpy.ubyte)
            return data.reshape(rows, cols)
im = imagine_arr(filepath, 0)  # 显示第0张
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
plt.close()

用的是numpy里面的方法.函数_read32作用是读取4个字节,以大端的方式转化成无符号整型.其余代码逻辑和之前的类似.
一次显示多张

import gzip
import numpy
import matplotlib.pyplot as plt
filepath = r"D:\PrjGit\AI\py35ts100\tstutorial\asset\data\mnist\train-images-idx3-ubyte.gz"
def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def imagine_arr(filepath):
    with open(filepath, 'rb') as f:
        with gzip.GzipFile(fileobj=f) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
            _read32(bytestream)  # 几张图片
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            img_num = 64
            buf = bytestream.read(rows * cols * img_num)
            data = numpy.frombuffer(buf, dtype=numpy.ubyte)
            return data.reshape(img_num, rows, cols, 1)
im_data = imagine_arr(filepath)
fig, axes = plt.subplots(8, 8)
for l, ax in enumerate(axes.flat):
    ax.imshow(im_data[l].reshape(28, 28), cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
plt.close()

参考资料

  1. python struct官方文档
  2. Big and Little Endian
  3. python读取mnist 2012
  4. mnist数据集官网
  5. Not another MNIST tutorial with TensorFlow 2016

posted on 2017-02-24 08:10  荷楠仁  阅读(7111)  评论(0编辑  收藏  举报

导航