python读取,显示,保存mnist图片
python处理二进制
python的struct模块可以将整型(或者其它类型)转化为byte数组.看下面的代码.
# coding: utf-8
from struct import *
# 包装成大端的byte数组
print(pack('>hhl', 1, 2, 3)) # b'\x00\x01\x00\x02\x00\x00\x00\x03'
pack('>hhl', 1, 2, 3)作用是以大端的方式把1(h表示2字节整型),2,3(l表示4字节整型),转化为对于的byte数组.大端小端的区别看参数资料2,>hhl的含义见参考资料1.输出为长度为8的byte数组,2个h的长度为4,1个l的长度为4,加起来一共是8.
再体会下面代码的作用.
# coding: utf-8
from struct import *
# 包装成大端的byte数组
print(pack('>hhl', 1, 2, 3)) # b'\x00\x01\x00\x02\x00\x00\x00\x03'
# 以大端的方式还原成整型
print(unpack('>hhl', b'\x00\x01\x00\x02\x00\x00\x00\x03')) # (1, 2, 3)
# 包装成小端的byte数组
print(pack('<hhl', 1, 2, 3)) # b'\x01\x00\x02\x00\x03\x00\x00\x00'
# 以小端的方式还原成整型
print(unpack('<hhl', b'\x00\x01\x00\x02\x00\x00\x00\x03')) # (256, 512, 50331648)
mnist显示
以t10k-images.idx3-ubyte为例,t10k-images.idx3-ubyte是二进制文件.其格式为
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 10000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
前4个整型代表文件头的一些信息.之后的无符号byte数组才是图片的内容.所以要先越过前4个整型,然后再开始读取,代码如下
import numpy as np
import struct
import matplotlib.pyplot as plt
filename = r'D:\source\technology_source\data\t10k-images.idx3-ubyte'
binfile = open(filename, 'rb')
buf = binfile.read()
index = 0
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index) # 读取前4个字节的内容
index += struct.calcsize('>IIII')
im = struct.unpack_from('>784B', buf, index) # 以大端方式读取一张图上28*28=784
index += struct.calcsize('>784B')
binfile.close()
im = np.array(im)
im = im.reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
# plt.savefig("test.png") # 保存成文件
plt.close()
可以看到结果:
下面是读取多个图片并存盘的代码.
import numpy as np
import struct
import matplotlib.pyplot as plt
filename = r'D:\source\technology_source\data\t10k-images.idx3-ubyte'
binfile = open(filename, 'rb')
buf = binfile.read()
index = 0
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
index += struct.calcsize('>IIII')
for i in range(30): # 读取前30张图片
im = struct.unpack_from('>784B', buf, index)
index += struct.calcsize('>784B')
im = np.array(im)
im = im.reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.savefig("test" + str(i) + ".png")
plt.close()
binfile.close()
另外一种方法
参考tensorflow中mnist模块的方法读取,代码如下
import gzip
import numpy
import matplotlib.pyplot as plt
filepath = r"D:\train-images-idx3-ubyte.gz"
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def imagine_arr(filepath, index):
with open(filepath, 'rb') as f:
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
num = _read32(bytestream) # 几张图片
rows = _read32(bytestream)
cols = _read32(bytestream)
if index >= num:
index = 0
bytestream.read(rows * cols * index)
buf = bytestream.read(rows * cols)
data = numpy.frombuffer(buf, dtype=numpy.ubyte)
return data.reshape(rows, cols)
im = imagine_arr(filepath, 0) # 显示第0张
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
plt.close()
用的是numpy里面的方法.函数_read32作用是读取4个字节,以大端的方式转化成无符号整型.其余代码逻辑和之前的类似.
一次显示多张
import gzip
import numpy
import matplotlib.pyplot as plt
filepath = r"D:\PrjGit\AI\py35ts100\tstutorial\asset\data\mnist\train-images-idx3-ubyte.gz"
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def imagine_arr(filepath):
with open(filepath, 'rb') as f:
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name))
_read32(bytestream) # 几张图片
rows = _read32(bytestream)
cols = _read32(bytestream)
img_num = 64
buf = bytestream.read(rows * cols * img_num)
data = numpy.frombuffer(buf, dtype=numpy.ubyte)
return data.reshape(img_num, rows, cols, 1)
im_data = imagine_arr(filepath)
fig, axes = plt.subplots(8, 8)
for l, ax in enumerate(axes.flat):
ax.imshow(im_data[l].reshape(28, 28), cmap='gray')
ax.set_xticks([])
ax.set_yticks([])
plt.show()
plt.close()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?