tencent_2.1_linear_regression_model

linear_regression_model.py

import tensorflow as tf
import numpy as np

class LinearRegressionModel:
    """Linear regression trained by mini-batch gradient descent in TensorFlow 1.x."""

    def __init__(self, x_dimen):
        self.x_dimen = x_dimen          # number of input features
        self._index_in_epoch = 0        # cursor into the shuffled training data
        self.constructModel()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def weight_variable(self, shape):
        # small random initialization for the weights
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(self, shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def constructModel(self):
        self.x = tf.placeholder(tf.float32, [None, self.x_dimen])
        self.y = tf.placeholder(tf.float32, [None, 1])
        self.w = self.weight_variable([self.x_dimen, 1])
        self.b = self.bias_variable([1])
        self.y_pred = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

        # loss = mean squared error + L2 penalty on the weights
        mse = tf.reduce_mean(tf.squared_difference(self.y_pred, self.y))
        l2 = tf.reduce_mean(tf.square(self.w))
        self.loss = mse + 0.15 * l2
        self.train_step = tf.train.AdamOptimizer(0.1).minimize(self.loss)

    def next_batch(self, batch_size):
        # serve sequential mini-batches; reshuffle whenever an epoch is exhausted
        assert batch_size <= self._num_examples
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._data = self._data[perm]
            self._labels = self._labels[perm]
            start = 0
            self._index_in_epoch = batch_size
        end = self._index_in_epoch
        return self._data[start:end], self._labels[start:end]

    def train(self, x_train, y_train):
        self._data = x_train
        self._labels = y_train
        self._num_examples = x_train.shape[0]
        for i in range(5000):
            batch_x, batch_y = self.next_batch(100)
            self.sess.run(self.train_step, feed_dict={self.x: batch_x, self.y: batch_y})
            if i % 10 == 0:
                # this is the loss on the current training batch, not a test loss
                train_loss = self.sess.run(self.loss, feed_dict={self.x: batch_x, self.y: batch_y})
                print("step %d, train_loss %f" % (i, train_loss))

    def predict_batch(self, arr, batch_size):
        # yield successive batch_size-sized chunks of arr
        for i in range(0, len(arr), batch_size):
            yield arr[i:i + batch_size]

    def predict(self, x_predict):
        pred_list = []
        for x_test_batch in self.predict_batch(x_predict, 100):
            pred = self.sess.run(self.y_pred, feed_dict={self.x: x_test_batch})
            pred_list.append(pred)
        return np.vstack(pred_list)
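
The class above targets the TensorFlow 1.x graph API (tf.Session, tf.placeholder, tf.train.AdamOptimizer). If you only have TensorFlow 2.x installed, a minimal sketch to keep it running (an assumption about your environment, not part of the original post) is to swap the import at the top of the file for the v1 compatibility module:

# Sketch for running this 1.x-style code under TensorFlow 2.x (assumes TF >= 2.0).
# Replace `import tensorflow as tf` at the top of linear_regression_model.py with:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restore graph mode so tf.Session / tf.placeholder work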

 

run.py

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from linear_regression_model import LinearRegressionModel as lrm

if __name__ == '__main__':
    x, y = make_regression(7000)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
    y_lrm_train = y_train.reshape(-1, 1)  # the TF model expects column-vector labels

    linear = lrm(x.shape[1])
    linear.train(x_train, y_lrm_train)
    y_predict = linear.predict(x_test)
    # r2_score expects (y_true, y_pred) in that order
    print("Tensorflow R2: ", r2_score(y_test, y_predict.ravel()))

    lr = LinearRegression()
    y_predict = lr.fit(x_train, y_train).predict(x_test)
    print("Sklearn R2: ", r2_score(y_test, y_predict))

 

Result: the training loss settles around 68.9 rather than 0 because the 0.15 * mean(w^2) penalty keeps the loss bounded away from zero, even though the fit itself is nearly perfect (R2 ≈ 1).

ubuntu@VM-12-146-ubuntu:~$ python run.py
step 0, train_loss 50554.742188
step 100, train_loss 38088.800781
step 200, train_loss 25407.130859
step 300, train_loss 18661.962891
step 400, train_loss 15185.415039
step 500, train_loss 11086.302734
step 600, train_loss 6200.811523
step 700, train_loss 5075.568848
step 800, train_loss 3389.883789
step 900, train_loss 2632.563477
step 1000, train_loss 1866.871338
step 1500, train_loss 276.935272
step 2000, train_loss 77.046547
step 2500, train_loss 69.126526
step 3000, train_loss 68.870636
step 3500, train_loss 68.842300
step 4000, train_loss 68.903976
step 4500, train_loss 68.848595
step 4990, train_loss 68.881592
('Tensorflow R2: ', 0.99999761462739589)
('Sklearn R2: ', 1.0)
ubuntu@VM-12-146-ubuntu:~$ 
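
A rough way to see where the ~68.9 plateau comes from (a sketch assuming scikit-learn's default make_regression settings, not taken from the original post): make_regression generates noise-free data by default, so ordinary least squares fits it exactly, and near the optimum the loss is dominated by the 0.15 * mean(w^2) penalty rather than by the MSE.

# Hypothetical sanity check: estimate the floor the L2 penalty puts under the loss.
import numpy as np
from sklearn.datasets import make_regression

x, y = make_regression(7000)               # noise=0 by default, so y is exactly linear in x
w, *_ = np.linalg.lstsq(x, y, rcond=None)  # exact OLS weights, zero training MSE
print(0.15 * np.mean(w ** 2))              # same order of magnitude as the ~68.9 plateau
                                           # (the exact value varies with the random draw)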