Python 实现深度学习(2)
第一篇:基础知识简介
第一篇是基础知识简介,对于过于简单的知识点,不会详细叙述,分为两部分:
1. python基础知识:将后期需要的了解的知识点列出,并给出相关资料。
2. 神经网络基础知识:感知机是神经网络的前身,对感知机简单的介绍。
本篇的目的和内容主要为: 介绍感知机和python;
一、Python基础知识
本章会列出实现的神经网络所需要的基础知识,并给出参考资料。
TODO:
介绍numpy库和matplotlib库、读写二进制的方法、pkl等。这些知识会在后面用到,在本篇的最后会以mnist数据集为例,创建处理手写体图片的函数,供后使用。
1. class 和function
2. numpy
3. Matplotlib
4. 序列化
1.1 class 和function
1.2 Numpy
Nump(Numerical Python)是Python的运算库,支持大规模的数组和矩阵运算。在深度学习的实现中会使用矩阵进行计算,numpy中实现了很多数据组的运算方法,在后期会用到的有:
- Nmupy的数据结构ndarray
- Numpy的切片和索引
- 广播功能函数
- 算术函数
1.2.1 Ndarray
Numpy中主要的数据结构是Ndarray,用于存放同类型元素的多维数组。
Ndarray的内部如图1所示。
fig1. ndarray 的数据结构
数据类型:dtype,描述数据类型,可以计算每个元素大小;
数组形状:shape,描述数组的大小和形状;
跨度元组,stride:表示从前一个维度到下一个维度需要跨越的字节数;
data: 指向数组的地址;
ps: 后期会用到dtype, shape等成员变量
1.2.2 广播
Numpy对于不同形状的乘法采用了广播机制。
广播可以对不同形状的数组做点乘:将较小的形状按照一定的规则填充,填充的方向依次为由内向外;广播机制在cudnn、tensorflow等深度学习框架中同样会使用。
广播是一种ufunc的机制是 不同形状的数组之间执行算数运算的方式,需要遵循4个原则:
- 1.让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
- 2.输入数组的shape是输入数组shape的各个轴上的最大值
- 3.如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错。
- 4.输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值。
举例说明:
假设,两个矩阵要做乘法,第一个矩阵是2*2, 但第二个矩阵并不是2*2的,按照数学运算法则是不能做点乘的;
但如果有广播机制,会按照以下方式填充数据,并做乘法:
先从行广播,然后再从列广播,举例如下
Case1: 行列都不一致。先填充行,再填充列。
Case2: 行不一致,列一致。先填充行
Case3: 行一致,列不一致。由于行已经一致了,不需要填充,直接填充列。
Case4: 行列不一致,且有一个维度无法广播
更多关于广播机制,详见: basics.broadcasting
1.2.3 其他知识点
numpy的切片和索引的有关内容在 fancy-indexing-and-index-tricks 中可以找到。
至于算术运算等网上的资料已经足够多的了,不需要我再重复操作了,这里给出一个官方的资料:numpy-quickstart.html
1.3 matplotlib和skimage
matplotlib和skimage在可视化数据的时候会用到。网上的资料足够多的了,在此不多介绍,给出参考资料:
https://www.runoob.com/w3cnote/matplotlib-tutorial.html
https://www.runoob.com/numpy/numpy-matplotlib.html
scikit-image
https://cloud.tencent.com/developer/section/1414638
1.4 python序列化
Serialization序列化,是将内存中对象以二进制的方式存储起来,存到磁盘。如果将磁盘中的文件解析成一个对象,这个过程称为deSerialization。序列化的数据可以用于网络传输,不会因为编码方式而改变。Python中的序列化由pickle模块实现。以下是参考资料:
pickle:Python object serialization
https://docs.python.org/zh-cn/2.7/library/pickle.html
1.5 实践:mnist数据集解析
本章会撰写程序实现一下功能:
1.下载mnsit数据集,解析mnist数据放在numpy的array中;
2.将解析的数据先序列化,然后持久化
3.反序列化,读取mnist中的一样图像,用plt或者skimage显示。
1 # -*- coding: utf-8 -*- 2 # @File : mnist.py 3 # @Author: lizhen 4 # @Date : 2020/2/4 5 # @Desc : 工具类,datasets/mnist.py 6 7 import urllib.request # python3 8 import os.path 9 import gzip 10 import pickle 11 import os 12 import numpy as np 13 14 # http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 15 # http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 16 # http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 17 # http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 18 19 20 url_base = "http://yann.lecun.com/exdb/mnist/" 21 key_file = { 22 'train_img':'train-images-idx3-ubyte.gz', 23 'train_label':'train-labels-idx1-ubyte.gz', 24 'test_img':'t10k-images-idx3-ubyte.gz', 25 'test_label':'t10k-labels-idx1-ubyte.gz' 26 } 27 28 dataset_dir=os.path.dirname(os.path.abspath(__file__)) 29 save_file=dataset_dir + "/mnist.pkl" 30 31 train_num = 60000; 32 test_num = 10000; 33 img_dim = (1, 28, 28) 34 img_size = 28*28; 35 36 37 def _download(file_name): 38 """ 39 :param file_name: 下载mnist的文件 40 :return: null 41 """ 42 file_path = os.path.join(dataset_dir, file_name) 43 44 if os.path.exists(file_path): 45 return 46 47 print("downloading"+file_name+ "...") 48 urllib.request.urlretrieve(url_base + file_name , file_path) 49 print("Done.") 50 51 def download_mnist(): 52 """ 53 54 :return: 55 """ 56 for file_name in key_file.values(): 57 _download(file_name); 58 59 def _load_label(file_name): 60 """ 61 解析标签 62 :param file_name: 63 :return: 64 """ 65 file_path = dataset_dir+'/'+ file_name 66 67 print("converting "+file_name+" to numpy Array.") 68 with gzip.open(file_path) as f: 69 labels = np.frombuffer(f.read(), np.uint8, offset=8) 70 print("Done") 71 72 return labels 73 74 def _load_img(file_name): 75 """ 76 解析 压缩的图片 77 :param file_name: 78 :return: 79 """ 80 file_path = dataset_dir +'/' + file_name 81 82 print("converting "+ file_name + "to numpy Array") 83 with gzip.open(file_path) as f: 84 data = np.frombuffer(f.read(), np.uint8, offset=16) # 16*8= 85 data = data.reshape(-1, img_size) # N, (W*H*C)=[N,28*28*1] 86 print("Done") 87 88 return data 89 90 def _convert_numpy(): 91 """ 92 解析 image和label,将其转换为numpy 93 """ 94 dataset = {} 95 dataset['train_img'] = _load_img(key_file['train_img']) 96 dataset['train_label'] = _load_label(key_file['train_label']) 97 dataset['test_img'] = _load_img(key_file['test_img']) 98 dataset['test_label'] = _load_label(key_file['test_label']) 99 100 return dataset 101 102 def init_mnist(): 103 """ 104 初始化mnist数据集: 105 1. 下载mnist, 106 2. 以二进制的方式读取,并转换成numpy的ndarray对象 107 3. 将转换后的ndarray 序列化 108 109 :return: 110 """ 111 print("download mnist dataset...") 112 download_mnist() 113 print("convert to numpy array...") 114 dataset = _convert_numpy() 115 print("creating pickle file ...") 116 with open(save_file, 'wb') as f: 117 pickle.dump(dataset, f, -1) 118 print("Done!") 119 120 def _change_one_hot_label(Y): 121 T = np.zeros((Y.size,10)) 122 for idx,row in enumerate(T): 123 row[Y[idx]] = 1 124 return T 125 126 def load_mnist(normalize=True, flatten=True, one_hot_label=False): 127 """ 128 129 :param normalize: 将数据标准化到0.0~1.0 130 :param flatten: 是否要将数据拉伸层1D数组的形式 131 :param one_hot_label: 132 :return: (训练数据, 训练标签), (测试数据, 测试label) 133 """ 134 135 136 if not os.path.exists(save_file): 137 init_mnist() 138 139 with open(save_file,'rb') as f: 140 dataset = pickle.load(f) 141 142 if normalize: 143 for key in ('train_img','test_img'): 144 dataset[key] = dataset[key].astype(np.float32) 145 dataset[key] /=255.0 146 if one_hot_label: 147 dataset['train_label'] = _change_one_hot_label(dataset['train_label']) 148 dataset['test_label'] = _change_one_hot_label(dataset['test_label']) 149 150 if not flatten: 151 for key in ('train_img', 'test_img'): 152 dataset[key] = dataset[key].reshape(-1,1,28,28) # NCHW 153 154 return (dataset['train_img'],dataset['train_label']),(dataset['test_img'], dataset['test_label']) 155 156 if __name__ == '__main__': 157 init_mnist() 158
测试
1 # -*- coding: utf-8 -*- 2 # @File : show_mnist.py 3 # @Author: lizhen 4 # @Date : 2020/1/27 5 # @Desc : 显示图片 6 7 from src.datasets.mnist import load_mnist 8 9 from skimage import io 10 11 12 def img_show(data): 13 # pil_img = Image.fromarray(np.uint8(data)) 14 io.imshow(data) 15 io.show() 16 17 18 (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False) 19 img = x_train[0] 20 label = t_train[0] 21 print(label) 22 23 print(img.shape) 24 img = img.reshape(28,28) 25 print(img.shape) 26 27 img_show(img) 28
2020年2月11日 修改