Implementing a Linear Regression Model in TensorFlow
Model Construction
1. Example code: linear_regression_model.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np


class linearRegressionModel:

    def __init__(self, x_dimen):
        self.x_dimen = x_dimen
        self._index_in_epoch = 0
        self.constructModel()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    # weight initialization
    def weight_variable(self, shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    # bias initialization
    def bias_variable(self, shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    # draw the next batch_size samples; reshuffle once the data set is exhausted
    def next_batch(self, batch_size):
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_datas:
            perm = np.arange(self._num_datas)
            np.random.shuffle(perm)
            self._datas = self._datas[perm]
            self._labels = self._labels[perm]
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_datas
        end = self._index_in_epoch
        return self._datas[start:end], self._labels[start:end]

    def constructModel(self):
        self.x = tf.placeholder(tf.float32, [None, self.x_dimen])
        self.y = tf.placeholder(tf.float32, [None, 1])
        self.w = self.weight_variable([self.x_dimen, 1])
        self.b = self.bias_variable([1])
        self.y_prec = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

        # loss = mean squared error + 0.15 * L2 weight penalty (ridge regression)
        mse = tf.reduce_mean(tf.squared_difference(self.y_prec, self.y))
        l2 = tf.reduce_mean(tf.square(self.w))
        self.loss = mse + 0.15 * l2
        self.train_step = tf.train.AdamOptimizer(0.1).minimize(self.loss)

    def train(self, x_train, y_train, x_test, y_test):
        # x_test and y_test are accepted but not used; the loss printed
        # below is computed on the current training batch
        self._datas = x_train
        self._labels = y_train
        self._num_datas = x_train.shape[0]
        for i in range(5000):
            batch = self.next_batch(100)
            self.sess.run(self.train_step,
                          feed_dict={self.x: batch[0], self.y: batch[1]})
            if i % 10 == 0:
                train_loss = self.sess.run(self.loss,
                                           feed_dict={self.x: batch[0], self.y: batch[1]})
                # despite the "test_loss" label, this is the training-batch loss
                print('step %d,test_loss %f' % (i, train_loss))

    # yield x_predict in chunks of batch_size
    def predict_batch(self, arr, batch_size):
        for i in range(0, len(arr), batch_size):
            yield arr[i:i + batch_size]

    def predict(self, x_predict):
        pred_list = []
        for x_test_batch in self.predict_batch(x_predict, 100):
            pred = self.sess.run(self.y_prec, {self.x: x_test_batch})
            pred_list.append(pred)
        return np.vstack(pred_list)
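Because the loss above is ordinary least squares plus an L2 weight penalty, this is exactly ridge regression, and the parameters the optimizer converges to can be sanity-checked against the closed-form solution. The helper below is a minimal NumPy sketch of my own (ridge_closed_form is not part of the original post); the penalty is rescaled by n/d because the TensorFlow loss averages both the squared errors and the squared weights.

import numpy as np

def ridge_closed_form(x, y, lam=0.15):
    # minimizes mean((x @ w + b - y)**2) + lam * mean(w**2);
    # centering x and y eliminates the unpenalized bias b,
    # leaving a standard ridge system in w
    n, d = x.shape
    x_mean, y_mean = x.mean(axis=0), y.mean()
    xc, yc = x - x_mean, y.ravel() - y_mean
    # stationarity condition: (xc.T @ xc + lam*n/d * I) @ w = xc.T @ yc
    w = np.linalg.solve(xc.T @ xc + (lam * n / d) * np.eye(d), xc.T @ yc)
    b = y_mean - x_mean @ w
    return w, b

With the training data from run.py below, ridge_closed_form(x_train, y_lrm_train) should agree closely with linear.sess.run(linear.w) and linear.sess.run(linear.b) once training has converged.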
2. Create run.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from linear_regression_model import linearRegressionModel as lrm

if __name__ == '__main__':
    # synthetic regression data: 7000 samples, 100 features (make_regression default)
    x, y = make_regression(7000)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
    y_lrm_train = y_train.reshape(-1, 1)
    y_lrm_test = y_test.reshape(-1, 1)

    linear = lrm(x.shape[1])
    linear.train(x_train, y_lrm_train, x_test, y_lrm_test)
    y_predict = linear.predict(x_test)
    # score with r2_score; its signature is r2_score(y_true, y_pred)
    print("Tensorflow R2: ", r2_score(y_lrm_test.ravel(), y_predict.ravel()))

    lr = LinearRegression()
    y_predict = lr.fit(x_train, y_train).predict(x_test)
    print("Sklearn R2: ", r2_score(y_test, y_predict))
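For readers on TensorFlow 2, the same model fits in a few lines of Keras. This is a minimal sketch under the assumption that TensorFlow 2.x is installed, not the code benchmarked below; note that Keras's l2 regularizer penalizes sum(w**2) rather than mean(w**2), so the 0.15 coefficient is not directly comparable, and the epoch count here is only illustrative.

import tensorflow as tf  # assumes TensorFlow 2.x

# a single Dense unit with no activation is a linear model;
# the kernel regularizer stands in for the hand-written 0.15 * l2 term
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, kernel_regularizer=tf.keras.regularizers.l2(0.15)),
])
model.compile(optimizer=tf.keras.optimizers.Adam(0.1), loss='mse')
model.fit(x_train, y_lrm_train, batch_size=100, epochs=100, verbose=0)
y_predict = model.predict(x_test)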
Run results (abridged):
step 0,test_loss 27078.781250
step 10,test_loss 29246.253906
step 20,test_loss 21168.052734
step 30,test_loss 22109.154297
step 40,test_loss 28030.435547
step 50,test_loss 24265.765625
...
step 500,test_loss 5174.365234
...
step 1000,test_loss 896.287964
...
step 1500,test_loss 116.955994
...
step 2000,test_loss 42.956898
...
step 2500,test_loss 41.053036
...
step 4990,test_loss 41.026733
Tensorflow R2: 0.999997486127
Sklearn R2: 1.0

The batch loss falls steadily and then plateaus around 41 rather than zero; that floor comes from the 0.15 L2 penalty, which biases the weights slightly away from the exact solution. Because make_regression produces noiseless data by default, the unregularized sklearn fit recovers the coefficients exactly (R2 = 1.0), while the regularized TensorFlow model scores just below it.
Author: 舆-风动名扬  Source: http://www.cnblogs.com/gnool/