import glob
import os
import numpy as np
import cv2
classification=[
'airplane',
'automobile',
'bird',
'cat',
'deer',
'dog',
'frog',
'horse',
'ship',
'truck']
def unpick(file):#这是cifar10官网提供的解压函数
import pickle
with open(file,'rb') as fo:
dict=pickle.load(fo,encoding='bytes')
return dict
folders='/home/ubuntu/WorkPlace/data_manager/data/cifar-10-batches-py'#cifar10源数据集
trfiles=glob.glob(folders+'/data_batch*')#获取训练样本的地址
data=[]
labels=[]
for file in trfiles:#各小包解压后数据存在data中,label存在labels中
dt=unpick(file)
data+=list(dt[b'data'])
labels+=list(dt[b'labels'])
print(labels)
#讲数据转换为4维度的数据(也就是直观的图片),cifar中图片32*32
imgs=np.reshape(data,[-1,3,32,32])#-1代表自动获取data的数量
for i in range(imgs.shape[0]):#shape[0]代表图片总量
im_data=imgs[i,...]
im_data=np.transpose(im_data,[1,2,0])#维度转换应为opencv非通道优先顺序存储
im_data=cv2.cvtColor(im_data,cv2.COLOR_RGB2BGR)#cv非RGB格式
f='{}/{}'.format('data/image/train',classification[labels[i]])#即将储存图片,这里定义每个图片的存放地址
if not os.path.exists(f):#判断路径是否存在
os.mkdir(f)
cv2.imwrite('{}/{}.jpg'.format(f,str(i)),im_data)