TensorFlow生成三色线性训练集

Python生成三色训练集(可直接使用)

 

 1 import numpy as np
 2 import tensorflow as tf
 3 import matplotlib.pyplot as plt
 4 
 5 # 总训练数量
 6 batchSize=100
 7 # 单个训练集数量
 8 eachBatchSize = 10
 9 
10 # Red points
11 redPointsX = np.random.normal(loc=2.0, scale=1.5, size=batchSize)
12 print("redX=", redPointsX)
13 redPointsY = np.random.normal(loc=2.0, scale=1.0, size=batchSize)
14 print("redY=", redPointsY)
15 
16 '''
17 numpy中reshape函数的三种常见相关用法
18 reshape(1,-1)转化成1行:
19 reshape(2,-1)转换成两行:
20 reshape(-1,1)转换成1列: 
21 reshape(-1,2)转化成两列
22 '''
23 redP = np.concatenate(([redPointsX], [redPointsY]), 0)
24 print("red=", redP)
25 print("redPxEachBatch=", redP[0, 1:eachBatchSize], "redPyEachBatch=", redP[1, 1:eachBatchSize])
26 
27 # Blue Points
28 bluePointsX = np.random.normal(loc=5.0, scale=1.5, size=batchSize)
29 print("blueX=", bluePointsX)
30 bluePointsY = np.random.normal(loc=5.0, scale=2.3, size=batchSize)
31 print("blueY=", bluePointsY)
32 
33 blueP = np.concatenate(([bluePointsX], [bluePointsY]), 0)
34 print("blue=", blueP)
35 print("bluePxEachBatch=", blueP[0, 1:eachBatchSize], "bluePyEachBatch=", blueP[1, 1:eachBatchSize])
36 
37 # Yellow points
38 yellowPointsX = np.random.normal(loc=5.0, scale=1.2, size=batchSize)
39 print("yellowX=", yellowPointsX)
40 yellowPointsY = np.random.normal(loc=2.0, scale=1.6, size=batchSize)
41 print("yellowY=", yellowPointsY)
42 
43 yellowP = np.concatenate(([yellowPointsX], [yellowPointsY]), 0)
44 print("yellow=", yellowP)
45 print("yellowPxEachBatch=", yellowP[0, 1:eachBatchSize], "yellowPyEachBatch=", yellowP[1, 1:eachBatchSize])
46 
47 # one-hot编码
48 X0 = np.concatenate(([redPointsX], [bluePointsX], [yellowPointsX]), 0)
49 Y0 = np.concatenate(([redPointsX], [bluePointsY], [yellowPointsY]), 0)
50 oneHotIndex = np.concatenate((np.zeros((1, len(redPointsX))), np.zeros((1, len(bluePointsX))),
51                               np.zeros((1, len(yellowPointsX)))), 1)
52 print("oneHotIndex=", oneHotIndex)
53 
54 plt.figure(1)
55 plt.scatter(redP[0, :], redP[1, :], c='red')
56 plt.scatter(blueP[0, :], blueP[1, :], c='blue')
57 plt.scatter(yellowP[0, :], yellowP[1, :], c='yellow')
58 plt.title('Exercise Set')
59 plt.show()

 

posted @ 2020-04-09 18:35  20岁博客少女  阅读(125)  评论(0)  编辑  收藏  举报