线性回归
1 2 # -*- coding: UTF-8 -*- 3 """ 4 此脚本用于展示使用sklearn搭建线性回归模型 5 """ 6 7 8 import os 9 import sys 10 11 import numpy as np 12 import matplotlib.pyplot as plt 13 import pandas as pd 14 from sklearn import linear_model 15 16 17 def evaluateModel(model, testData, features, labels): 18 """ 19 计算线性模型的均方差和决定系数 20 参数 21 ---- 22 model : LinearRegression, 训练完成的线性模型 23 testData : DataFrame,测试数据 24 features : list[str],特征名列表 25 labels : list[str],标签名列表 26 返回 27 ---- 28 error : np.float64,均方差 29 score : np.float64,决定系数 30 """ 31 # 均方差(The mean squared error),均方差越小越好 32 error = np.mean( 33 (model.predict(testData[features]) - testData[labels]) ** 2) 34 # 决定系数(Coefficient of determination),决定系数越接近1越好 35 score = model.score(testData[features], testData[labels]) 36 return error, score 37 38 39 def visualizeModel(model, data, features, labels, error, score): 40 """ 41 模型可视化 42 """ 43 # 为在Matplotlib中显示中文,设置特殊字体 44 plt.rcParams['font.sans-serif']=['SimHei'] 45 # 创建一个图形框 46 fig = plt.figure(figsize=(6, 6), dpi=80) 47 # 在图形框里只画一幅图 48 ax = fig.add_subplot(111) 49 # 在Matplotlib中显示中文,需要使用unicode 50 # 在Python3中,str不需要decode 51 if sys.version_info[0] == 3: 52 ax.set_title(u'%s' % "线性回归示例") 53 else: 54 ax.set_title(u'%s' % "线性回归示例".decode("utf-8")) 55 ax.set_xlabel('$x$') 56 ax.set_ylabel('$y$') 57 # 画点图,用蓝色圆点表示原始数据 58 # 在Python3中,str不需要decode 59 if sys.version_info[0] == 3: 60 ax.scatter(data[features], data[labels], color='b', 61 label=u'%s: $y = x + \epsilon$' % "真实值") 62 else: 63 ax.scatter(data[features], data[labels], color='b', 64 label=u'%s: $y = x + \epsilon$' % "真实值".decode("utf-8")) 65 # 根据截距的正负,打印不同的标签 66 if model.intercept_ > 0: 67 # 画线图,用红色线条表示模型结果 68 # 在Python3中,str不需要decode 69 if sys.version_info[0] == 3: 70 ax.plot(data[features], model.predict(data[features]), color='r', 71 label=u'%s: $y = %.3fx$ + %.3f'\ 72 % ("预测值", model.coef_, model.intercept_)) 73 else: 74 ax.plot(data[features], model.predict(data[features]), color='r', 75 label=u'%s: $y = %.3fx$ + %.3f'\ 76 % ("预测值".decode("utf-8"), model.coef_, model.intercept_)) 77 else: 78 # 在Python3中,str不需要decode 79 if sys.version_info[0] == 3: 80 ax.plot(data[features], model.predict(data[features]), color='r', 81 label=u'%s: $y = %.3fx$ - %.3f'\ 82 % ("预测值", model.coef_, abs(model.intercept_))) 83 else: 84 ax.plot(data[features], model.predict(data[features]), color='r', 85 label=u'%s: $y = %.3fx$ - %.3f'\ 86 % ("预测值".decode("utf-8"), model.coef_, abs(model.intercept_))) 87 legend = plt.legend(shadow=True) 88 legend.get_frame().set_facecolor('#6F93AE') 89 # 显示均方差和决定系数 90 # 在Python3中,str不需要decode 91 if sys.version_info[0] == 3: 92 ax.text(0.99, 0.01, 93 u'%s%.3f\n%s%.3f'\ 94 % ("均方差:", error, "决定系数:", score), 95 style='italic', verticalalignment='bottom', horizontalalignment='right', 96 transform=ax.transAxes, color='m', fontsize=13) 97 else: 98 ax.text(0.99, 0.01, 99 u'%s%.3f\n%s%.3f'\ 100 % ("均方差:".decode("utf-8"), error, "决定系数:".decode("utf-8"), score), 101 style='italic', verticalalignment='bottom', horizontalalignment='right', 102 transform=ax.transAxes, color='m', fontsize=13) 103 # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭 104 # 在Python shell里面,可以设置参数"block=False",使阻断失效。 105 plt.show() 106 107 108 def trainModel(trainData, features, labels): 109 """ 110 利用训练数据,估计模型参数 111 参数 112 ---- 113 trainData : DataFrame,训练数据集,包含特征和标签 114 features : 特征名列表 115 labels : 标签名列表 116 返回 117 ---- 118 model : LinearRegression, 训练好的线性模型 119 """ 120 # 创建一个线性回归模型 121 model = linear_model.LinearRegression() 122 # 训练模型,估计模型参数 123 model.fit(trainData[features], trainData[labels]) 124 return model 125 126 127 def linearModel(data): 128 """ 129 线性回归模型建模步骤展示 130 参数 131 ---- 132 data : DataFrame,建模数据 133 """ 134 features = ["x"] 135 labels = ["y"] 136 # 划分训练集和测试集 137 trainData = data[:15] 138 testData = data[15:] 139 # 产生并训练模型 140 model = trainModel(trainData, features, labels) 141 # 评价模型效果 142 error, score = evaluateModel(model, testData, features, labels) 143 # 图形化模型结果 144 visualizeModel(model, data, features, labels, error, score) 145 146 147 def readData(path): 148 """ 149 使用pandas读取数据 150 """ 151 data = pd.read_csv(path) 152 return data 153 154 155 if __name__ == "__main__": #主模块的名字是__main__,import的模块名字是自己 156 homePath = os.path.dirname(os.path.abspath(__file__)) #os.path.dirname 是去掉文件名的路径 ,abspath获取当前文件路径 157 # Windows下的存储路径与Linux并不相同 158 if os.name == "nt": #判断当前使用的平台,nt为windows 159 dataPath = "%s\\data\\simple_example.csv" % homePath 160 else: 161 dataPath = "%s/data/simple_example.csv" % homePath 162 data = readData(dataPath) 163 linearModel(data) 164 © 2019 GitHub, Inc. 165 Terms 166 Privacy 167 Security 168 Status 169 Help 170 Contact GitHub 171 Pricing 172 API 173 Training 174 Blog 175 About