# Hand-written neural network (MNIST handwritten digit recognition)
# Hand-written neural network -- MNIST handwritten digit dataset
import numpy as np


# input layer: 784 nodes (28*28)
# hidden layer: three hidden layers with 20 nodes in each layer
# output layer: 10 nodes
class BP:
    """Fully connected sigmoid network (784-20-20-20-10) trained by backprop.

    The layers have no bias terms. Weights are drawn from ``np.random`` at
    construction time, so seed NumPy's global generator beforehand for
    reproducible runs.
    """

    def __init__(self):
        # Placeholders for the most recent activations; forward_prop
        # overwrites them with arrays shaped like the batch it receives.
        self.input = np.zeros((100, 784))  # 100 samples per round
        self.hidden_layer_1 = np.zeros((100, 20))
        self.hidden_layer_2 = np.zeros((100, 20))
        self.hidden_layer_3 = np.zeros((100, 20))
        self.output_layer = np.zeros((100, 10))
        # Uniform initial weights in [-1, 1).
        self.w1 = 2 * np.random.random((784, 20)) - 1
        self.w2 = 2 * np.random.random((20, 20)) - 1
        self.w3 = 2 * np.random.random((20, 20)) - 1
        self.w4 = 2 * np.random.random((20, 10)) - 1
        self.error = np.zeros(10)
        self.learning_rate = 0.1

    def sigmoid(self, x):
        """Return the element-wise logistic function of ``x``."""
        return 1 / (1 + np.exp(-x))

    def sigmoid_deri(self, x):
        """Return the sigmoid derivative given the sigmoid *output* ``x``.

        ``x`` must already be an activation (``sigmoid(z)``), not the
        pre-activation ``z``.
        """
        return x * (1 - x)

    def forward_prop(self, data, label=None):
        """Run a forward pass and return the output-layer activations.

        Args:
            data: Input of shape ``(batch, 784)`` or a single ``(784,)``
                sample.
            label: One-hot targets matching the output shape. When omitted
                (inference only) no error is stored, and ``backward_prop``
                refuses to run until a labelled forward pass is made.

        Returns:
            The activations of the output layer.
        """
        self.input = data
        self.hidden_layer_1 = self.sigmoid(np.dot(self.input, self.w1))
        self.hidden_layer_2 = self.sigmoid(np.dot(self.hidden_layer_1, self.w2))
        self.hidden_layer_3 = self.sigmoid(np.dot(self.hidden_layer_2, self.w3))
        self.output_layer = self.sigmoid(np.dot(self.hidden_layer_3, self.w4))
        # Stored for backward_prop; cleared on unlabelled passes so a stale
        # error can never be combined with fresh activations.
        self.error = None if label is None else label - self.output_layer
        return self.output_layer

    def backward_prop(self):
        """Update all weights from the last labelled forward pass.

        Raises:
            ValueError: If the last ``forward_prop`` call had no label.
        """
        if self.error is None:
            raise ValueError(
                'backward_prop requires a forward_prop call with a label')
        # All deltas are computed with the pre-update weights.
        output_diff = self.error * self.sigmoid_deri(self.output_layer)
        hidden_diff_3 = (np.dot(output_diff, self.w4.T)
                         * self.sigmoid_deri(self.hidden_layer_3))
        hidden_diff_2 = (np.dot(hidden_diff_3, self.w3.T)
                         * self.sigmoid_deri(self.hidden_layer_2))
        hidden_diff_1 = (np.dot(hidden_diff_2, self.w2.T)
                         * self.sigmoid_deri(self.hidden_layer_1))
        # update (error is label - output, hence "+=")
        self.w4 += self.learning_rate * np.dot(self.hidden_layer_3.T, output_diff)
        self.w3 += self.learning_rate * np.dot(self.hidden_layer_2.T, hidden_diff_3)
        self.w2 += self.learning_rate * np.dot(self.hidden_layer_1.T, hidden_diff_2)
        self.w1 += self.learning_rate * np.dot(self.input.T, hidden_diff_1)


def _one_hot(labels, num_classes=10):
    """Return a ``(len(labels), num_classes)`` float one-hot matrix."""
    encoded = np.zeros((len(labels), num_classes))
    encoded[np.arange(len(labels)), labels] = 1
    return encoded


# from torchvision load data
def load_data():
    """Load MNIST through torchvision as binarized, flattened NumPy arrays.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``: images are flattened to one
        row per sample and binarized to 0/1 (``pixel > 0``); labels are
        one-hot with ten columns. Only the training set is shuffled.
    """
    # Imported here, where they are used, so the network class above can be
    # imported and exercised without torchvision installed.
    import torchvision
    import torchvision.transforms as transforms

    # Add download=True to either call if the dataset is not on disk yet.
    datasets_train = torchvision.datasets.MNIST(
        root='../../data/', train=True, transform=transforms.ToTensor())
    datasets_test = torchvision.datasets.MNIST(
        root='../../data/', train=False, transform=transforms.ToTensor())

    X_train = datasets_train.data.numpy()
    X_test = datasets_test.data.numpy()
    # Flatten each image to a single row.
    X_train = X_train.reshape(len(X_train), -1)
    X_test = X_test.reshape(len(X_test), -1)

    # each y has ten dimensions
    real_train_y = _one_hot(datasets_train.targets.numpy())
    real_test_y = _one_hot(datasets_test.targets.numpy())

    # shuffle train_data (images and labels with the same permutation)
    index = np.arange(len(X_train))
    np.random.shuffle(index)
    X_train = X_train[index]
    real_train_y = real_train_y[index]

    X_train = (X_train > 0).astype(np.int64)
    X_test = (X_test > 0).astype(np.int64)

    return X_train, real_train_y, X_test, real_test_y


def bp_network(epochs=6000, batch_size=100, data=None):
    """Train a fresh ``BP`` network with mini-batch updates and return it.

    Args:
        epochs: Number of mini-batch update steps (not passes over the data).
        batch_size: Samples per update step.
        data: Optional ``(X_train, Y_train, X_test, Y_test)`` tuple as
            returned by ``load_data``; loaded via ``load_data`` when omitted.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    nn = BP()
    X_train, Y_train, _, _ = load_data() if data is None else data
    # Cycle through the full batches that fit in the training set
    # (600 for the 60000 MNIST training images at batch_size=100).
    n_batches = max(1, len(X_train) // batch_size)
    for step in range(epochs):
        start = (step % n_batches) * batch_size
        end = start + batch_size
        nn.forward_prop(X_train[start:end], Y_train[start:end])
        nn.backward_prop()

    return nn


def bp_test(data=None, epochs=6000, batch_size=100):
    """Train a network, print its accuracy on the test split and return it.

    Args:
        data: Optional ``(X_train, Y_train, X_test, Y_test)`` tuple; loaded
            once via ``load_data`` when omitted and shared with training.
        epochs: Forwarded to ``bp_network``.
        batch_size: Forwarded to ``bp_network``.

    Returns:
        The fraction of test samples whose arg-max output matches the label.
    """
    if data is None:
        data = load_data()
    nn = bp_network(epochs=epochs, batch_size=batch_size, data=data)
    _, _, X_test, Y_test = data
    # test: one batched forward pass; argmax picks the first maximal output.
    predicted = np.argmax(nn.forward_prop(X_test), axis=1)
    correct = np.count_nonzero(Y_test[np.arange(len(Y_test)), predicted] == 1)
    accuracy = correct / len(Y_test)

    print('accuracy:', accuracy)
    return accuracy


if __name__ == '__main__':
    bp_test()