knn 数据集准备
"""
Created on Mon Aug 26 20:57:24 2019

@author: huoqs

knn algorithm

Prepare a small random two-class data set for a k-nearest-neighbours demo
and draw it as a scatter plot.
"""
import numpy as np
import matplotlib.pyplot as plt


def generate_data(num_samples, num_features=2):
    """Create a random labelled data set.

    Args:
        num_samples: Number of samples (rows) to generate.
        num_features: Number of features (columns) per sample.

    Returns:
        A ``(data, labels)`` tuple. ``data`` has shape
        ``(num_samples, num_features)`` and holds random integers drawn
        from ``[0, 100)`` converted to ``float32``. ``labels`` has shape
        ``(num_samples, 1)`` and holds random integers drawn from
        ``{0, 1}``.
    """
    data_size = (num_samples, num_features)
    data = np.random.randint(0, 100, data_size)

    # The upper bound of randint is exclusive, so labels are 0 or 1.
    label_size = (num_samples, 1)
    labels = np.random.randint(0, 2, label_size)

    # Features must be float32.
    # NOTE(review): presumably a requirement of the KNN implementation that
    # consumes this data later -- confirm against the caller.
    return data.astype(np.float32), labels


def plot_data(all_blue, all_red):
    """Scatter-plot the two classes on the current matplotlib axes.

    Args:
        all_blue: 2-D array whose columns 0 and 1 are the x/y coordinates
            of the samples drawn as blue squares.
        all_red: 2-D array whose columns 0 and 1 are the x/y coordinates
            of the samples drawn as red triangles.

    The figure is only drawn; ``plt.show()`` is not called here.
    """
    plt.scatter(all_blue[:, 0], all_blue[:, 1], c='b', marker='s', s=180)
    plt.scatter(all_red[:, 0], all_red[:, 1], c='r', marker='^', s=180)
    plt.xlabel('x')
    plt.ylabel('y')


plt.style.use('ggplot')

# Fix the seed so the generated data set is reproducible between runs.
np.random.seed(42)

train_data, labels = generate_data(11)

# labels has shape (n, 1); ravel() flattens it to 1-D so it can be used as a
# boolean row mask on train_data.
blue = train_data[labels.ravel() == 0]
red = train_data[labels.ravel() == 1]

plot_data(blue, red)
知识点:
1、np.random.randint 函数,生成取值在 [low, high) 区间内(不包含 high)的随机整数数组,参数:low,high,size,dtype
https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.random.randint.html
2、ndarray.ravel(),将数组扁平化,变为一维数组
https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ravel.html