tencent_2.1_linear_regression_model
linear_regression_model.py
import tensorflow as tf
import numpy as np


class linearRegressionModel(object):
    """Linear regression with an L2 penalty, trained by Adam (TF1 graph API).

    The model computes ``y_prec = x . w + b`` for inputs with ``x_dimen``
    features and a single output column.  The loss is the mean squared error
    plus ``l2_weight`` times the mean of the squared weights.

    The instance owns a ``tf.Session``; call :meth:`close` (or use the
    instance as a context manager) to release it when finished.
    """

    def __init__(self, x_dimen, learning_rate=0.1, l2_weight=0.15):
        """Build the graph and start a session with initialised variables.

        Args:
            x_dimen: number of input features (columns of ``x``).
            learning_rate: step size handed to ``AdamOptimizer``.
            l2_weight: multiplier applied to the L2 term of the loss.
        """
        self.x_dimen = x_dimen
        self.learning_rate = learning_rate
        self.l2_weight = l2_weight
        self._index_in_epoch = 0
        self.constructModel()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def weight_variable(self, shape):
        """Return a variable drawn from a truncated normal (stddev 0.1)."""
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(self, shape):
        """Return a variable filled with the constant 0.1."""
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def constructModel(self):
        """Create placeholders, parameters, prediction, loss and train op."""
        self.x = tf.placeholder(tf.float32, [None, self.x_dimen])
        self.y = tf.placeholder(tf.float32, [None, 1])
        self.w = self.weight_variable([self.x_dimen, 1])
        self.b = self.bias_variable([1])
        self.y_prec = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

        mse = tf.reduce_mean(tf.squared_difference(self.y_prec, self.y))
        l2 = tf.reduce_mean(tf.square(self.w))
        self.loss = mse + self.l2_weight * l2
        self.train_step = tf.train.AdamOptimizer(
            self.learning_rate).minimize(self.loss)

    def next_batch(self, batch_size):
        """Return the next ``(data, labels)`` slice of ``batch_size`` rows.

        When the remaining rows of the current pass cannot fill a whole
        batch, the stored data and labels are reshuffled together and the
        batch is taken from the start of the new ordering (the leftover
        tail of the old ordering is skipped).

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 or larger than
                the number of stored samples.
        """
        # Validate up front with a real exception: an ``assert`` disappears
        # under ``python -O``.  A batch equal to the whole data set is valid.
        if batch_size < 1 or batch_size > self._num_datas:
            raise ValueError(
                "batch_size must be between 1 and %d, got %r"
                % (self._num_datas, batch_size))

        start = self._index_in_epoch
        end = start + batch_size
        if end > self._num_datas:
            # Pass exhausted: shuffle data and labels with one permutation
            # so that rows stay aligned, then restart from the beginning.
            perm = np.arange(self._num_datas)
            np.random.shuffle(perm)
            self._datas = self._datas[perm]
            self._labels = self._labels[perm]
            start, end = 0, batch_size
        self._index_in_epoch = end
        return self._datas[start:end], self._labels[start:end]

    def train(self, x_train, y_train, steps=5000, batch_size=100,
              log_every=10):
        """Fit the model with mini-batch Adam updates.

        Args:
            x_train: training inputs, one row per sample.
            y_train: training targets; a 1-D array is reshaped to a column
                to match the ``[None, 1]`` label placeholder.
            steps: number of optimisation steps to run.
            batch_size: rows per step; clamped to the data set size.
            log_every: print the batch loss every this many steps; a falsy
                value disables logging.

        Raises:
            ValueError: if the data set is empty or inputs and targets have
                different numbers of rows.
        """
        datas = np.asarray(x_train)
        labels = np.asarray(y_train)
        if labels.ndim == 1:
            labels = labels.reshape(-1, 1)
        if datas.shape[0] != labels.shape[0]:
            raise ValueError(
                "x_train and y_train have different lengths: %d != %d"
                % (datas.shape[0], labels.shape[0]))
        if datas.shape[0] == 0:
            raise ValueError("cannot train on an empty data set")

        self._datas = datas
        self._labels = labels
        self._num_datas = datas.shape[0]
        # Start every training run from the beginning of the data instead of
        # inheriting the cursor left behind by a previous call.
        self._index_in_epoch = 0
        batch_size = min(batch_size, self._num_datas)

        for i in range(steps):
            batch_x, batch_y = self.next_batch(batch_size)
            feed = {self.x: batch_x, self.y: batch_y}
            self.sess.run(self.train_step, feed_dict=feed)
            if log_every and i % log_every == 0:
                # NOTE: the value is the loss on the current *training*
                # batch, measured after the update; the "test_loss" label is
                # kept unchanged so existing log output still matches.
                train_loss = self.sess.run(self.loss, feed_dict=feed)
                print("step %d,test_loss %f" % (i, train_loss))

    def predict_batch(self, arr, batch_size):
        """Yield consecutive slices of ``arr`` of at most ``batch_size``."""
        for i in range(0, len(arr), batch_size):
            yield arr[i:i + batch_size]

    def predict(self, x_predict, batch_size=100):
        """Return predictions for ``x_predict`` as an ``(n, 1)`` array.

        Inputs are evaluated in chunks of ``batch_size`` rows.  An empty
        input yields an empty ``(0, 1)`` array instead of failing inside
        ``np.vstack``.
        """
        x_predict = np.asarray(x_predict)
        if len(x_predict) == 0:
            return np.empty((0, 1), dtype=np.float32)
        pred_list = [
            self.sess.run(self.y_prec, feed_dict={self.x: x_test_batch})
            for x_test_batch in self.predict_batch(x_predict, batch_size)
        ]
        return np.vstack(pred_list)
run.py
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from linear_regression_model import linearRegressionModel as lrm

if __name__ == '__main__':
    # Synthetic regression problem with 7000 samples, split 50/50 into
    # training and test halves.
    x, y = make_regression(7000)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)

    # The TensorFlow model feeds labels through a [None, 1] placeholder, so
    # it needs column vectors rather than 1-D arrays.
    y_lrm_train = y_train.reshape(-1, 1)
    y_lrm_test = y_test.reshape(-1, 1)

    linear = lrm(x.shape[1])
    linear.train(x_train, y_lrm_train)
    y_predict = linear.predict(x_test)
    # r2_score expects (y_true, y_pred); the score is not symmetric, so the
    # ground truth must come first.
    print("Tensorflow R2: ", r2_score(y_lrm_test.ravel(), y_predict.ravel()))

    # Reference fit with scikit-learn's closed-form linear regression.
    lr = LinearRegression()
    y_predict = lr.fit(x_train, y_train).predict(x_test)
    print("Sklearn R2: ", r2_score(y_test, y_predict))
结果:
ubuntu@VM-12-146-ubuntu:~$ python run.py step 0,test_loss 50554.742188 step 10,test_loss 53487.046875 step 20,test_loss 36099.449219 step 30,test_loss 50567.339844 step 40,test_loss 45398.449219 step 50,test_loss 40298.109375 step 60,test_loss 40552.335938 step 70,test_loss 33812.503906 step 80,test_loss 39265.847656 step 90,test_loss 36639.019531 step 100,test_loss 38088.800781 step 110,test_loss 34145.976562 step 120,test_loss 34928.343750 step 130,test_loss 27798.576172 step 140,test_loss 27276.896484 step 150,test_loss 24163.076172 step 160,test_loss 27816.298828 step 170,test_loss 27190.076172 step 180,test_loss 23125.751953 step 190,test_loss 22233.732422 step 200,test_loss 25407.130859 step 210,test_loss 24661.605469 step 220,test_loss 26708.384766 step 230,test_loss 19841.687500 step 240,test_loss 24072.408203 step 250,test_loss 17266.611328 step 260,test_loss 26461.554688 step 270,test_loss 22570.013672 step 280,test_loss 24905.404297 step 290,test_loss 15179.116211 step 300,test_loss 18661.962891 step 310,test_loss 19683.906250 step 320,test_loss 15327.764648 step 330,test_loss 13168.517578 step 340,test_loss 18685.867188 step 350,test_loss 15487.893555 step 360,test_loss 14875.776367 step 370,test_loss 13526.083008 step 380,test_loss 12865.832031 step 390,test_loss 15644.985352 step 400,test_loss 15185.415039 step 410,test_loss 14463.912109 step 420,test_loss 15174.072266 step 430,test_loss 12386.041016 step 440,test_loss 11713.977539 step 450,test_loss 13448.644531 step 460,test_loss 11875.729492 step 470,test_loss 10373.699219 step 480,test_loss 10216.433594 step 490,test_loss 10364.944336 step 500,test_loss 11086.302734 step 510,test_loss 5693.416992 step 520,test_loss 8974.305664 step 530,test_loss 10542.807617 step 540,test_loss 8704.875977 step 550,test_loss 7724.854980 step 560,test_loss 10312.858398 step 570,test_loss 8256.900391 step 580,test_loss 7817.630859 step 590,test_loss 5888.268555 step 600,test_loss 6200.811523 step 610,test_loss 
8451.625000 step 620,test_loss 5280.057617 step 630,test_loss 6405.470215 step 640,test_loss 7642.217773 step 650,test_loss 7243.946777 step 660,test_loss 5507.767090 step 670,test_loss 5810.758789 step 680,test_loss 5605.552246 step 690,test_loss 4809.159180 step 700,test_loss 5075.568848 step 710,test_loss 6186.606445 step 720,test_loss 4644.177734 step 730,test_loss 3405.376953 step 740,test_loss 3963.750732 step 750,test_loss 4160.150879 step 760,test_loss 3322.133301 step 770,test_loss 4439.647461 step 780,test_loss 2879.551758 step 790,test_loss 4686.266602 step 800,test_loss 3389.883789 step 810,test_loss 2878.820801 step 820,test_loss 2689.784668 step 830,test_loss 3682.542969 step 840,test_loss 4084.625977 step 850,test_loss 2865.651611 step 860,test_loss 2581.408936 step 870,test_loss 2442.209717 step 880,test_loss 2742.542969 step 890,test_loss 2850.451172 step 900,test_loss 2632.563477 step 910,test_loss 2156.909180 step 920,test_loss 2305.271729 step 930,test_loss 1993.062134 step 940,test_loss 2311.316162 step 950,test_loss 1890.035156 step 960,test_loss 1763.053955 step 970,test_loss 1725.859131 step 980,test_loss 1759.168091 step 990,test_loss 1323.621216 step 1000,test_loss 1866.871338 step 1010,test_loss 1352.615479 step 1020,test_loss 1616.773560 step 1030,test_loss 1143.031250 step 1040,test_loss 1489.623291 step 1050,test_loss 1476.880371 step 1060,test_loss 1370.258667 step 1070,test_loss 1141.999878 step 1080,test_loss 1166.933228 step 1090,test_loss 1108.787476 step 1100,test_loss 960.834412 step 1110,test_loss 1265.723999 step 1120,test_loss 845.209717 step 1130,test_loss 1160.951294 step 1140,test_loss 1252.884766 step 1150,test_loss 947.760193 step 1160,test_loss 949.110107 step 1170,test_loss 703.865845 step 1180,test_loss 785.185425 step 1190,test_loss 909.500916 step 1200,test_loss 739.722473 step 1210,test_loss 752.640686 step 1220,test_loss 621.177429 step 1230,test_loss 664.028809 step 1240,test_loss 842.765198 step 1250,test_loss 
537.212891 step 1260,test_loss 612.894226 step 1270,test_loss 515.576599 step 1280,test_loss 595.646362 step 1290,test_loss 591.783936 step 1300,test_loss 636.358215 step 1310,test_loss 421.886414 step 1320,test_loss 454.036652 step 1330,test_loss 441.143738 step 1340,test_loss 393.075867 step 1350,test_loss 415.702820 step 1360,test_loss 448.445557 step 1370,test_loss 350.026550 step 1380,test_loss 391.277832 step 1390,test_loss 458.568481 step 1400,test_loss 423.181671 step 1410,test_loss 410.131195 step 1420,test_loss 336.751404 step 1430,test_loss 279.585114 step 1440,test_loss 244.326080 step 1450,test_loss 322.283997 step 1460,test_loss 275.522095 step 1470,test_loss 240.525375 step 1480,test_loss 272.072205 step 1490,test_loss 202.847000 step 1500,test_loss 276.935272 step 1510,test_loss 211.451935 step 1520,test_loss 235.174835 step 1530,test_loss 201.162201 step 1540,test_loss 190.168976 step 1550,test_loss 184.773560 step 1560,test_loss 214.716644 step 1570,test_loss 171.767059 step 1580,test_loss 174.465897 step 1590,test_loss 167.292999 step 1600,test_loss 169.164307 step 1610,test_loss 169.313507 step 1620,test_loss 172.326813 step 1630,test_loss 157.986420 step 1640,test_loss 146.000031 step 1650,test_loss 134.339294 step 1660,test_loss 134.111725 step 1670,test_loss 129.241135 step 1680,test_loss 133.466339 step 1690,test_loss 128.226562 step 1700,test_loss 125.525810 step 1710,test_loss 115.412491 step 1720,test_loss 111.956238 step 1730,test_loss 114.153198 step 1740,test_loss 115.078041 step 1750,test_loss 107.463058 step 1760,test_loss 100.881134 step 1770,test_loss 108.238266 step 1780,test_loss 105.143570 step 1790,test_loss 105.204819 step 1800,test_loss 92.495026 step 1810,test_loss 96.282784 step 1820,test_loss 94.634003 step 1830,test_loss 88.530350 step 1840,test_loss 94.419586 step 1850,test_loss 90.118423 step 1860,test_loss 90.100471 step 1870,test_loss 92.007225 step 1880,test_loss 83.837082 step 1890,test_loss 88.835732 step 
1900,test_loss 84.472664 step 1910,test_loss 84.264526 step 1920,test_loss 83.923073 step 1930,test_loss 81.012611 step 1940,test_loss 77.510666 step 1950,test_loss 81.193970 step 1960,test_loss 81.355026 step 1970,test_loss 79.240715 step 1980,test_loss 75.555099 step 1990,test_loss 76.817108 step 2000,test_loss 77.046547 step 2010,test_loss 75.549843 step 2020,test_loss 77.543091 step 2030,test_loss 74.993935 step 2040,test_loss 75.710793 step 2050,test_loss 74.185966 step 2060,test_loss 77.117058 step 2070,test_loss 74.278481 step 2080,test_loss 73.415520 step 2090,test_loss 71.878136 step 2100,test_loss 73.699829 step 2110,test_loss 73.163803 step 2120,test_loss 72.371719 step 2130,test_loss 73.948982 step 2140,test_loss 72.047447 step 2150,test_loss 71.245361 step 2160,test_loss 72.283531 step 2170,test_loss 71.045090 step 2180,test_loss 70.661865 step 2190,test_loss 70.973740 step 2200,test_loss 71.079689 step 2210,test_loss 71.110039 step 2220,test_loss 70.308311 step 2230,test_loss 69.957397 step 2240,test_loss 70.038406 step 2250,test_loss 70.325066 step 2260,test_loss 70.040375 step 2270,test_loss 70.044968 step 2280,test_loss 69.895622 step 2290,test_loss 69.759949 step 2300,test_loss 69.858398 step 2310,test_loss 69.752213 step 2320,test_loss 69.294426 step 2330,test_loss 69.490250 step 2340,test_loss 69.516228 step 2350,test_loss 69.264801 step 2360,test_loss 69.474541 step 2370,test_loss 69.354004 step 2380,test_loss 69.242920 step 2390,test_loss 69.212044 step 2400,test_loss 69.350334 step 2410,test_loss 69.142731 step 2420,test_loss 69.137589 step 2430,test_loss 69.149040 step 2440,test_loss 69.254677 step 2450,test_loss 69.260796 step 2460,test_loss 69.097015 step 2470,test_loss 69.137276 step 2480,test_loss 69.074081 step 2490,test_loss 69.078163 step 2500,test_loss 69.126526 step 2510,test_loss 69.039513 step 2520,test_loss 68.991615 step 2530,test_loss 68.940292 step 2540,test_loss 68.944313 step 2550,test_loss 68.915138 step 2560,test_loss 
68.998390 step 2570,test_loss 68.909355 step 2580,test_loss 68.906288 step 2590,test_loss 68.861900 step 2600,test_loss 68.936501 step 2610,test_loss 68.967476 step 2620,test_loss 68.890762 step 2630,test_loss 68.898155 step 2640,test_loss 68.884277 step 2650,test_loss 68.880043 step 2660,test_loss 68.900604 step 2670,test_loss 68.914810 step 2680,test_loss 68.907867 step 2690,test_loss 68.896652 step 2700,test_loss 68.859848 step 2710,test_loss 68.878159 step 2720,test_loss 68.902023 step 2730,test_loss 68.873718 step 2740,test_loss 68.916138 step 2750,test_loss 68.877785 step 2760,test_loss 68.864380 step 2770,test_loss 68.869530 step 2780,test_loss 68.861702 step 2790,test_loss 68.882233 step 2800,test_loss 68.869751 step 2810,test_loss 68.869904 step 2820,test_loss 68.888641 step 2830,test_loss 68.850166 step 2840,test_loss 68.856880 step 2850,test_loss 68.889809 step 2860,test_loss 68.859596 step 2870,test_loss 68.855354 step 2880,test_loss 68.882050 step 2890,test_loss 68.885651 step 2900,test_loss 68.859543 step 2910,test_loss 68.818733 step 2920,test_loss 68.832703 step 2930,test_loss 68.852486 step 2940,test_loss 68.870384 step 2950,test_loss 68.855949 step 2960,test_loss 68.841423 step 2970,test_loss 68.872490 step 2980,test_loss 68.883705 step 2990,test_loss 68.855988 step 3000,test_loss 68.870636 step 3010,test_loss 68.862991 step 3020,test_loss 68.870255 step 3030,test_loss 68.863777 step 3040,test_loss 68.900009 step 3050,test_loss 68.891266 step 3060,test_loss 68.879257 step 3070,test_loss 68.858406 step 3080,test_loss 68.860107 step 3090,test_loss 68.874466 step 3100,test_loss 68.860497 step 3110,test_loss 68.858696 step 3120,test_loss 68.883636 step 3130,test_loss 68.876640 step 3140,test_loss 68.868874 step 3150,test_loss 68.874413 step 3160,test_loss 68.827744 step 3170,test_loss 68.875488 step 3180,test_loss 68.881233 step 3190,test_loss 68.859863 step 3200,test_loss 68.885933 step 3210,test_loss 68.869545 step 3220,test_loss 68.842995 step 
3230,test_loss 68.863731 step 3240,test_loss 68.871620 step 3250,test_loss 68.887390 step 3260,test_loss 68.847641 step 3270,test_loss 68.856567 step 3280,test_loss 68.860397 step 3290,test_loss 68.862595 step 3300,test_loss 68.884712 step 3310,test_loss 68.877060 step 3320,test_loss 68.854836 step 3330,test_loss 68.852036 step 3340,test_loss 68.841431 step 3350,test_loss 68.856216 step 3360,test_loss 68.901741 step 3370,test_loss 68.879402 step 3380,test_loss 68.848961 step 3390,test_loss 68.894516 step 3400,test_loss 68.845978 step 3410,test_loss 68.843407 step 3420,test_loss 68.880623 step 3430,test_loss 68.874138 step 3440,test_loss 68.875328 step 3450,test_loss 68.847839 step 3460,test_loss 68.856407 step 3470,test_loss 68.871277 step 3480,test_loss 68.878189 step 3490,test_loss 68.884209 step 3500,test_loss 68.842300 step 3510,test_loss 68.860901 step 3520,test_loss 68.872070 step 3530,test_loss 68.859085 step 3540,test_loss 68.865288 step 3550,test_loss 68.876419 step 3560,test_loss 68.869156 step 3570,test_loss 68.886154 step 3580,test_loss 68.872124 step 3590,test_loss 68.894897 step 3600,test_loss 68.877914 step 3610,test_loss 68.880173 step 3620,test_loss 68.901749 step 3630,test_loss 68.838867 step 3640,test_loss 68.871284 step 3650,test_loss 68.844742 step 3660,test_loss 68.849792 step 3670,test_loss 68.876732 step 3680,test_loss 68.837036 step 3690,test_loss 68.862450 step 3700,test_loss 68.875610 step 3710,test_loss 68.856522 step 3720,test_loss 68.859993 step 3730,test_loss 68.869888 step 3740,test_loss 68.853325 step 3750,test_loss 68.863533 step 3760,test_loss 68.858566 step 3770,test_loss 68.843681 step 3780,test_loss 68.858315 step 3790,test_loss 68.876717 step 3800,test_loss 68.875565 step 3810,test_loss 68.901901 step 3820,test_loss 68.881500 step 3830,test_loss 68.873932 step 3840,test_loss 68.833977 step 3850,test_loss 68.854820 step 3860,test_loss 68.882072 step 3870,test_loss 68.848206 step 3880,test_loss 68.874969 step 3890,test_loss 
68.866570 step 3900,test_loss 68.849907 step 3910,test_loss 68.879288 step 3920,test_loss 68.866447 step 3930,test_loss 68.845490 step 3940,test_loss 68.854675 step 3950,test_loss 68.858597 step 3960,test_loss 68.850449 step 3970,test_loss 68.853188 step 3980,test_loss 68.881699 step 3990,test_loss 68.852928 step 4000,test_loss 68.903976 step 4010,test_loss 68.883759 step 4020,test_loss 68.873917 step 4030,test_loss 68.858192 step 4040,test_loss 68.883629 step 4050,test_loss 68.888466 step 4060,test_loss 68.867790 step 4070,test_loss 68.853691 step 4080,test_loss 68.888000 step 4090,test_loss 68.863617 step 4100,test_loss 68.844299 step 4110,test_loss 68.887550 step 4120,test_loss 68.867096 step 4130,test_loss 68.840004 step 4140,test_loss 68.872429 step 4150,test_loss 68.880051 step 4160,test_loss 68.840759 step 4170,test_loss 68.876610 step 4180,test_loss 68.866318 step 4190,test_loss 68.864395 step 4200,test_loss 68.866844 step 4210,test_loss 68.870064 step 4220,test_loss 68.835907 step 4230,test_loss 68.853455 step 4240,test_loss 68.866196 step 4250,test_loss 68.892769 step 4260,test_loss 68.853798 step 4270,test_loss 68.830009 step 4280,test_loss 68.863434 step 4290,test_loss 68.843948 step 4300,test_loss 68.881226 step 4310,test_loss 68.859001 step 4320,test_loss 68.858551 step 4330,test_loss 68.865173 step 4340,test_loss 68.851601 step 4350,test_loss 68.886436 step 4360,test_loss 68.861244 step 4370,test_loss 68.834076 step 4380,test_loss 68.848862 step 4390,test_loss 68.852104 step 4400,test_loss 68.858315 step 4410,test_loss 68.841331 step 4420,test_loss 68.855446 step 4430,test_loss 68.870331 step 4440,test_loss 68.888527 step 4450,test_loss 68.869141 step 4460,test_loss 68.870499 step 4470,test_loss 68.873802 step 4480,test_loss 68.856995 step 4490,test_loss 68.861732 step 4500,test_loss 68.848595 step 4510,test_loss 68.877945 step 4520,test_loss 68.862206 step 4530,test_loss 68.872231 step 4540,test_loss 68.858627 step 4550,test_loss 68.879288 step 
4560,test_loss 68.857910 step 4570,test_loss 68.884712 step 4580,test_loss 68.886894 step 4590,test_loss 68.882500 step 4600,test_loss 68.884560 step 4610,test_loss 68.845879 step 4620,test_loss 68.844276 step 4630,test_loss 68.860344 step 4640,test_loss 68.869537 step 4650,test_loss 68.862511 step 4660,test_loss 68.886688 step 4670,test_loss 68.858307 step 4680,test_loss 68.870544 step 4690,test_loss 68.895294 step 4700,test_loss 68.836151 step 4710,test_loss 68.854034 step 4720,test_loss 68.896736 step 4730,test_loss 68.873703 step 4740,test_loss 68.850273 step 4750,test_loss 68.877182 step 4760,test_loss 68.850876 step 4770,test_loss 68.841743 step 4780,test_loss 68.863495 step 4790,test_loss 68.880394 step 4800,test_loss 68.865532 step 4810,test_loss 68.880829 step 4820,test_loss 68.898735 step 4830,test_loss 68.868607 step 4840,test_loss 68.851143 step 4850,test_loss 68.888962 step 4860,test_loss 68.861221 step 4870,test_loss 68.856529 step 4880,test_loss 68.870476 step 4890,test_loss 68.859734 step 4900,test_loss 68.859306 step 4910,test_loss 68.870979 step 4920,test_loss 68.879097 step 4930,test_loss 68.886711 step 4940,test_loss 68.869553 step 4950,test_loss 68.861496 step 4960,test_loss 68.846848 step 4970,test_loss 68.867828 step 4980,test_loss 68.883766 step 4990,test_loss 68.881592 ('Tensorflow R2: ', 0.99999761462739589) ('Sklearn R2: ', 1.0) ubuntu@VM-12-146-ubuntu:~$