tensorflow 线性回归解决 iris 2分类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | # Combining Everything Together #---------------------------------- # This file will perform binary classification on the # iris dataset. We will only predict if a flower is # I.setosa or not. # # We will create a simple binary classifier by creating a line # and running everything through a sigmoid to get a binary predictor. # The two features we will use are pedal length and pedal width. # # We will use batch training, but this can be easily # adapted to stochastic training. import matplotlib.pyplot as plt import numpy as np from sklearn import datasets import tensorflow as tf from tensorflow.python.framework import ops ops.reset_default_graph() # Load the iris data # iris.target = {0, 1, 2}, where '0' is setosa # iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length] iris = datasets.load_iris() binary_target = np.array([ 1. if x = = 0 else 0. for x in iris.target]) iris_2d = np.array([[x[ 2 ], x[ 3 ]] for x in iris.data]) # Declare batch size batch_size = 20 # Create graph sess = tf.Session() # Declare placeholders x1_data = tf.placeholder(shape = [ None , 1 ], dtype = tf.float32) x2_data = tf.placeholder(shape = [ None , 1 ], dtype = tf.float32) y_target = tf.placeholder(shape = [ None , 1 ], dtype = tf.float32) # Create variables A and b (0 = x1 - A*x2 + b) A = tf.Variable(tf.random_normal(shape = [ 1 , 1 ])) b = tf.Variable(tf.random_normal(shape = [ 1 , 1 ])) # Add model to graph: # x1 - A*x2 + b my_mult = tf.matmul(x2_data, A) my_add = tf.add(my_mult, b) my_output = tf.subtract(x1_data, my_add) # Add classification loss (cross entropy) xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits = my_output, labels = y_target) # Create Optimizer my_opt = tf.train.GradientDescentOptimizer( 0.05 ) train_step = my_opt.minimize(xentropy) # Initialize variables init = tf.global_variables_initializer() sess.run(init) # Run Loop for i in range ( 1000 ): rand_index = np.random.choice( len (iris_2d), size = batch_size) #rand_x = np.transpose([iris_2d[rand_index]]) rand_x = iris_2d[rand_index] rand_x1 = np.array([[x[ 0 ]] for x in rand_x]) rand_x2 = np.array([[x[ 1 ]] for x in rand_x]) #rand_y = np.transpose([binary_target[rand_index]]) rand_y = np.array([[y] for y in binary_target[rand_index]]) sess.run(train_step, feed_dict = {x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y}) if (i + 1 ) % 200 = = 0 : print ( 'Step #' + str (i + 1 ) + ' A = ' + str (sess.run(A)) + ', b = ' + str (sess.run(b))) # Visualize Results # Pull out slope/intercept [[slope]] = sess.run(A) [[intercept]] = sess.run(b) # Create fitted line x = np.linspace( 0 , 3 , num = 50 ) ablineValues = [] for i in x: ablineValues.append(slope * i + intercept) # Plot the fitted line over the data setosa_x = [a[ 1 ] for i,a in enumerate (iris_2d) if binary_target[i] = = 1 ] setosa_y = [a[ 0 ] for i,a in enumerate (iris_2d) if binary_target[i] = = 1 ] non_setosa_x = [a[ 1 ] for i,a in enumerate (iris_2d) if binary_target[i] = = 0 ] non_setosa_y = [a[ 0 ] for i,a in enumerate (iris_2d) if binary_target[i] = = 0 ] plt.plot(setosa_x, setosa_y, 'rx' , ms = 10 , mew = 2 , label = 'setosa' ) plt.plot(non_setosa_x, non_setosa_y, 'ro' , label = 'Non-setosa' ) plt.plot(x, ablineValues, 'b-' ) plt.xlim([ 0.0 , 2.7 ]) plt.ylim([ 0.0 , 7.1 ]) plt.suptitle( 'Linear Separator For I.setosa' , fontsize = 20 ) plt.xlabel( 'Petal Length' ) plt.ylabel( 'Petal Width' ) plt.legend(loc = 'lower right' ) plt.show() |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
2017-05-05 influx测试——单条读性能很差,大约400条/s,批量写性能很高,7万条/s,总体说来适合IOT数据批量存,根据tag查和过滤场景,按照时间顺序返回
2017-05-05 乐视云监控数据存放到influxdb中
2017-05-05 influxdb入门——和mongodb一样可以动态增加字段
2017-05-05 InfluxDB 分布式时间序列数据库环境搭建——据qcon大会2016qiniu说集群很坑且闭源了