线性回归-预测销量

一、代码

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso, Ridge
from sklearn.model_selection import GridSearchCV

if __name__ == "__main__":
    # Fit a Ridge regression on the Advertising data set (TV / Radio /
    # Newspaper spend -> Sales), pick alpha by 5-fold cross-validation,
    # report the test-set error and plot predictions against true values.

    # Load the data with pandas; columns used: TV, Radio, Newspaper, Sales.
    data = pd.read_csv('D:/data_set/Advertising.csv')
    print('data= \n', data)
    # To experiment with fewer features, select e.g. only ['TV', 'Radio'].
    x = data[['TV', 'Radio', 'Newspaper']]
    y = data['Sales']
    print('X= \n', x)
    print('Y= \n', y)

    # Hold out 20% of the rows for testing; the fixed seed makes the split
    # reproducible.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, random_state=1, train_size=0.8)

    # Ridge (L2) regression; Lasso() (already imported) can be swapped in
    # to compare against L1 regularisation.
    model = Ridge()

    # Candidate regularisation strengths: 10 values log-spaced in
    # [1e-3, 1e2]. Printed twice on purpose: first in NumPy's default
    # (scientific) notation, then with suppress=True, which turns
    # scientific notation OFF so the values print as plain decimals.
    alpha_can = np.logspace(-3, 2, 10)
    print(alpha_can)
    np.set_printoptions(suppress=True)
    print('alpha_can = \n', alpha_can)

    # 5-fold cross-validated grid search over alpha.
    grid_search = GridSearchCV(model, param_grid={'alpha': alpha_can}, cv=5)
    grid_search.fit(x_train, y_train)
    print('超参数:\n', grid_search.best_params_)

    # Sort the test samples by their true Sales value so that the plotted
    # "true" curve is monotonically increasing and easy to compare with
    # the predictions.
    order = y_test.argsort(axis=0)
    print('order:', order)
    y_test = y_test.values[order.values]
    print('y_test_now = \n', y_test)
    # Keep x_test as a DataFrame (positional reorder via iloc) so the
    # estimator sees the same feature names it was fitted with; print the
    # underlying array to keep the console output unchanged.
    x_test = x_test.iloc[order.values]
    print('x_test_now = \n', x_test.values)

    y_hat = grid_search.predict(x_test)  # predicted Sales for the test set
    print('lasso_model.score(x_test,y_test):', grid_search.score(x_test, y_test))
    mse = np.mean((y_hat - y_test) ** 2)  # Mean Squared Error
    rmse = np.sqrt(mse)  # Root Mean Squared Error
    print('mse = ', mse, 'rmse = ', rmse)

    # Plot true vs. predicted values against the (sorted) sample index.
    t = np.arange(len(x_test))
    # SimHei is needed to render the Chinese labels; unicode_minus=False
    # keeps the minus sign displayable with that font.
    mpl.rcParams['font.sans-serif'] = ['simHei']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.figure(facecolor='w')
    plt.plot(t, y_test, 'r-', linewidth=2, label='真实数据')
    plt.plot(t, y_hat, 'g-', linewidth=2, label='预测数据')
    plt.title('线性回归预测销量', fontsize=18)
    plt.legend(loc='upper left')
    # Pass the flag positionally: the old ``b=`` keyword was renamed to
    # ``visible=`` in Matplotlib 3.5 and later removed, so ``b=True``
    # fails on current releases.
    plt.grid(True, ls=':')
    plt.show()

二、运行结果

data= 
      Unnamed: 0     TV  Radio  Newspaper  Sales
0             1  230.1   37.8       69.2   22.1
1             2   44.5   39.3       45.1   10.4
2             3   17.2   45.9       69.3    9.3
3             4  151.5   41.3       58.5   18.5
4             5  180.8   10.8       58.4   12.9
5             6    8.7   48.9       75.0    7.2
6             7   57.5   32.8       23.5   11.8
7             8  120.2   19.6       11.6   13.2
8             9    8.6    2.1        1.0    4.8
9            10  199.8    2.6       21.2   10.6
10           11   66.1    5.8       24.2    8.6
11           12  214.7   24.0        4.0   17.4
12           13   23.8   35.1       65.9    9.2
13           14   97.5    7.6        7.2    9.7
14           15  204.1   32.9       46.0   19.0
15           16  195.4   47.7       52.9   22.4
16           17   67.8   36.6      114.0   12.5
17           18  281.4   39.6       55.8   24.4
18           19   69.2   20.5       18.3   11.3
19           20  147.3   23.9       19.1   14.6
20           21  218.4   27.7       53.4   18.0
21           22  237.4    5.1       23.5   12.5
22           23   13.2   15.9       49.6    5.6
23           24  228.3   16.9       26.2   15.5
24           25   62.3   12.6       18.3    9.7
25           26  262.9    3.5       19.5   12.0
26           27  142.9   29.3       12.6   15.0
27           28  240.1   16.7       22.9   15.9
28           29  248.8   27.1       22.9   18.9
29           30   70.6   16.0       40.8   10.5
..          ...    ...    ...        ...    ...
170         171   50.0   11.6       18.4    8.4
171         172  164.5   20.9       47.4   14.5
172         173   19.6   20.1       17.0    7.6
173         174  168.4    7.1       12.8   11.7
174         175  222.4    3.4       13.1   11.5
175         176  276.9   48.9       41.8   27.0
176         177  248.4   30.2       20.3   20.2
177         178  170.2    7.8       35.2   11.7
178         179  276.7    2.3       23.7   11.8
179         180  165.6   10.0       17.6   12.6
180         181  156.6    2.6        8.3   10.5
181         182  218.5    5.4       27.4   12.2
182         183   56.2    5.7       29.7    8.7
183         184  287.6   43.0       71.8   26.2
184         185  253.8   21.3       30.0   17.6
185         186  205.0   45.1       19.6   22.6
186         187  139.5    2.1       26.6   10.3
187         188  191.1   28.7       18.2   17.3
188         189  286.0   13.9        3.7   15.9
189         190   18.7   12.1       23.4    6.7
190         191   39.5   41.1        5.8   10.8
191         192   75.5   10.8        6.0    9.9
192         193   17.2    4.1       31.6    5.9
193         194  166.8   42.0        3.6   19.6
194         195  149.7   35.6        6.0   17.3
195         196   38.2    3.7       13.8    7.6
196         197   94.2    4.9        8.1    9.7
197         198  177.0    9.3        6.4   12.8
198         199  283.6   42.0       66.2   25.5
199         200  232.1    8.6        8.7   13.4

[200 rows x 5 columns]
X= 
         TV  Radio  Newspaper
0    230.1   37.8       69.2
1     44.5   39.3       45.1
2     17.2   45.9       69.3
3    151.5   41.3       58.5
4    180.8   10.8       58.4
5      8.7   48.9       75.0
6     57.5   32.8       23.5
7    120.2   19.6       11.6
8      8.6    2.1        1.0
9    199.8    2.6       21.2
10    66.1    5.8       24.2
11   214.7   24.0        4.0
12    23.8   35.1       65.9
13    97.5    7.6        7.2
14   204.1   32.9       46.0
15   195.4   47.7       52.9
16    67.8   36.6      114.0
17   281.4   39.6       55.8
18    69.2   20.5       18.3
19   147.3   23.9       19.1
20   218.4   27.7       53.4
21   237.4    5.1       23.5
22    13.2   15.9       49.6
23   228.3   16.9       26.2
24    62.3   12.6       18.3
25   262.9    3.5       19.5
26   142.9   29.3       12.6
27   240.1   16.7       22.9
28   248.8   27.1       22.9
29    70.6   16.0       40.8
..     ...    ...        ...
170   50.0   11.6       18.4
171  164.5   20.9       47.4
172   19.6   20.1       17.0
173  168.4    7.1       12.8
174  222.4    3.4       13.1
175  276.9   48.9       41.8
176  248.4   30.2       20.3
177  170.2    7.8       35.2
178  276.7    2.3       23.7
179  165.6   10.0       17.6
180  156.6    2.6        8.3
181  218.5    5.4       27.4
182   56.2    5.7       29.7
183  287.6   43.0       71.8
184  253.8   21.3       30.0
185  205.0   45.1       19.6
186  139.5    2.1       26.6
187  191.1   28.7       18.2
188  286.0   13.9        3.7
189   18.7   12.1       23.4
190   39.5   41.1        5.8
191   75.5   10.8        6.0
192   17.2    4.1       31.6
193  166.8   42.0        3.6
194  149.7   35.6        6.0
195   38.2    3.7       13.8
196   94.2    4.9        8.1
197  177.0    9.3        6.4
198  283.6   42.0       66.2
199  232.1    8.6        8.7

[200 rows x 3 columns]
Y=
 0      22.1
1      10.4
2       9.3
3      18.5
4      12.9
5       7.2
6      11.8
7      13.2
8       4.8
9      10.6
10      8.6
11     17.4
12      9.2
13      9.7
14     19.0
15     22.4
16     12.5
17     24.4
18     11.3
19     14.6
20     18.0
21     12.5
22      5.6
23     15.5
24      9.7
25     12.0
26     15.0
27     15.9
28     18.9
29     10.5
       ...
170     8.4
171    14.5
172     7.6
173    11.7
174    11.5
175    27.0
176    20.2
177    11.7
178    11.8
179    12.6
180    10.5
181    12.2
182     8.7
183    26.2
184    17.6
185    22.6
186    10.3
187    17.3
188    15.9
189     6.7
190    10.8
191     9.9
192     5.9
193    19.6
194    17.3
195     7.6
196     9.7
197    12.8
198    25.5
199    13.4
Name: Sales, Length: 200, dtype: float64
C:\Users\87823\.conda\envs\tensorflow\lib\site-packages\sklearn\model_selection\_split.py:2010: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.      
  FutureWarning)
[  1.00000000e-03   3.59381366e-03   1.29154967e-02   4.64158883e-02
   1.66810054e-01   5.99484250e-01   2.15443469e+00   7.74263683e+00
   2.78255940e+01   1.00000000e+02]
alpha_can =
 [   0.001         0.00359381    0.0129155     0.04641589    0.16681005
    0.59948425    2.15443469    7.74263683   27.82559402  100.        ]
超参数:
 {'alpha': 7.7426368268112773}
order: 58     39
40     22
34      2
102    18
184    26
198     8
95     20
4      37
29     11
168    23
171    36
18     33
11     24
89     31
110    21
118    17
159     7
35     16
136    14
59     10
51      3
16     25
44     35
94     29
31     15
162     1
38     13
28      6
193     9
27     32
47     12
165     4
194    19
177    27
176    28
97     34
174    38
73     30
69      0
172     5
Name: Sales, dtype: int64
y_test_now =
 [  7.6   8.5   9.5   9.5  10.1  10.5  10.7  11.   11.3  11.5  11.5  11.7
  11.9  11.9  12.5  12.8  12.9  12.9  13.4  14.5  14.8  14.9  15.5  15.9
  15.9  16.6  16.7  16.9  17.1  17.3  17.4  17.6  18.4  18.9  19.6  20.2
  22.3  23.2  23.8  25.5]
x_test_now =
 [[  19.6   20.1   17. ]
 [  25.1   25.7   43.3]
 [  95.7    1.4    7.4]
 [  25.6   39.     9.3]
 [  43.1   26.7   35.1]
 [  70.6   16.    40.8]
 [ 100.4    9.6    3.6]
 [ 129.4    5.7   31.3]
 [  69.2   20.5   18.3]
 [ 107.4   14.    10.9]
 [ 222.4    3.4   13.1]
 [ 170.2    7.8   35.2]
 [ 112.9   17.4   38.6]
 [ 234.5    3.4   84.8]
 [  67.8   36.6  114. ]
 [ 290.7    4.1    8.5]
 [ 180.8   10.8   58.4]
 [ 131.7   18.4   34.6]
 [ 225.8    8.2   56.5]
 [ 164.5   20.9   47.4]
 [ 280.2   10.1   21.4]
 [ 188.4   18.1   25.6]
 [ 184.9   21.    22. ]
 [ 240.1   16.7   22.9]
 [ 125.7   36.9   79.2]
 [ 202.5   22.3   31.6]
 [ 109.8   47.8   51.4]
 [ 163.3   31.6   52.9]
 [ 215.4   23.6   57.6]
 [ 149.7   35.6    6. ]
 [ 214.7   24.     4. ]
 [ 253.8   21.3   30. ]
 [ 210.7   29.5    9.3]
 [ 248.8   27.1   22.9]
 [ 166.8   42.     3.6]
 [ 248.4   30.2   20.3]
 [ 216.8   43.9   27.2]
 [ 239.9   41.5   18.5]
 [ 210.8   49.6   37.7]
 [ 283.6   42.    66.2]]
lasso_model.score(x_test,y_test): 0.892714279041
mse =  1.99274576769 rmse =  1.41164647405

三、图片

posted @ 2020-06-22 21:27  小他_W  阅读(404)  评论(0)  编辑  收藏  举报