python-[pandas]-[sklearn]-[matplotlib]-线性预测

文章:

http://python.jobbole.com/81215/

python的函数库好强大!看完这篇博文再也不会用matlab了~~

这篇文章使用【pandas】读取csv的数据,使用【sklearn】中的linear_model训练模型并进行线性预测,使用【matplotlib】将拟合的情况用图表示出来。

                   下面的表格是用于训练模型的表格:

代码如下:

# -*- coding: utf-8 -*-
'''
Created on 2016/11/26

@author: chensi
'''
# Required Packages
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
from numpy.ma.core import getdata

# Function to get data
def get_data(file_name):
    """Load the training table and split it into features and targets.

    The reader is picked from the file extension: a ``.csv`` path is read
    with ``pd.read_csv``; anything else is handed to ``pd.read_excel``.

    Args:
        file_name: Path to a table with ``square_feet`` and ``price`` columns.

    Returns:
        A tuple ``(X_parameter, Y_parameter)``: a list of one-element
        ``[square_feet]`` rows and the list of matching prices, all floats.
    """
    # read_excel cannot parse a plain-text CSV file, which is what the demo
    # below passes in, so dispatch on the extension.
    if str(file_name).lower().endswith('.csv'):
        data = pd.read_csv(file_name)
    else:
        data = pd.read_excel(file_name)
    # Each sample is wrapped in its own list so the result is a 2D feature table.
    X_parameter = [[float(square_feet)] for square_feet in data['square_feet']]
    Y_parameter = [float(price) for price in data['price']]
    return X_parameter, Y_parameter

# Function for Fitting our data to Linear model
def linear_model_main(X_parameters,Y_parameters,predict_value):
    """Fit a linear regression and predict the outcome for ``predict_value``.

    Args:
        X_parameters: Training samples, one row of features per sample.
        Y_parameters: Target value for each training sample.
        predict_value: Input(s) to predict for. A scalar or flat sequence is
            reshaped into rows matching the fitted feature count; input that
            is already two-dimensional is passed through unchanged.

    Returns:
        A dict with the keys ``'intercept'``, ``'coefficient'`` and
        ``'predicted_value'`` taken from the fitted model.
    """
    # Create linear regression object
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)
    # predict() expects a 2D (n_samples, n_features) array; a bare scalar such
    # as 150 is rejected, so lift low-dimensional input into that shape.
    query = np.asarray(predict_value)
    if query.ndim < 2:
        query = query.reshape(-1, regr.coef_.shape[-1])
    predict_outcome = regr.predict(query)
    return {
        'intercept': regr.intercept_,
        'coefficient': regr.coef_,
        'predicted_value': predict_outcome,
    }



# Function to show the results of linear fit model
def show_linear_line(X_parameters,Y_parameters):
    """Plot the samples together with the fitted regression line.

    Fits a fresh ``LinearRegression`` on the given data, draws the samples as
    a blue scatter and the model's predictions as a red line, hides the axis
    ticks, and shows the figure.

    Args:
        X_parameters: Training samples, one row of features per sample.
        Y_parameters: Target value for each training sample.
    """
    model = linear_model.LinearRegression()
    model.fit(X_parameters, Y_parameters)
    fitted_values = model.predict(X_parameters)

    # Raw samples first, then the fitted line over them.
    plt.scatter(X_parameters, Y_parameters, color='blue')
    plt.plot(X_parameters, fitted_values, color='red', linewidth=4)

    # Empty tuples remove the tick marks on both axes.
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks(())
    plt.show()
#---------Test---------------
#----------------------------
# Load the training table, plot the fit, then predict for 150 square feet.
x,y = get_data("g:/input_data.csv")
show_linear_line(x,y)
# predict() expects a 2D array (one row per sample, one column per feature),
# so the single query value is wrapped as [[150]] instead of a bare scalar.
print(linear_model_main(x,y,[[150]]))
#----------------------------
#----------------------------

     输出的图:

 

例子二:

 

代码:

# -*- coding: utf-8 -*-
'''
Created on 2016/11/26
 
@author: chensi
'''


# Required Packages
import csv
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
 
# Function to get data
def get_data(file_name):
    """Load the viewer table and split it into per-show features and targets.

    The reader is picked from the file extension: a ``.csv`` path is read
    with ``pd.read_csv``; anything else is handed to ``pd.read_excel``.

    Args:
        file_name: Path to a table with the columns ``flash_episode_number``,
            ``flash_us_viewers``, ``arrow_episode_number`` and
            ``arrow_us_viewers``.

    Returns:
        A tuple ``(flash_x, flash_y, arrow_x, arrow_y)``. The ``*_x`` items
        are lists of one-element ``[episode_number]`` rows and the ``*_y``
        items are lists of viewer counts, all converted to float.
    """
    if str(file_name).lower().endswith('.csv'):
        data = pd.read_csv(file_name)
    else:
        data = pd.read_excel(file_name)
    # Each episode number is wrapped in its own list so the result is a 2D
    # feature table.
    flash_x_parameter = [[float(v)] for v in data['flash_episode_number']]
    flash_y_parameter = [float(v) for v in data['flash_us_viewers']]
    arrow_x_parameter = [[float(v)] for v in data['arrow_episode_number']]
    arrow_y_parameter = [float(v) for v in data['arrow_us_viewers']]
    return flash_x_parameter,flash_y_parameter,arrow_x_parameter,arrow_y_parameter
 
# Function to know which Tv show will have more viewers
def more_viewers(x1,y1,x2,y2,next_episode=9):
    """Predict viewers for both shows and print which one comes out ahead.

    A separate ``LinearRegression`` is fitted per show, each is asked for a
    prediction at ``next_episode``, and the two predictions are compared.
    The first show's prediction is also printed.

    Args:
        x1: Feature rows for the first show (reported as "The Flash").
        y1: Viewer counts for the first show.
        x2: Feature rows for the second show (reported as "Arrow").
        y2: Viewer counts for the second show.
        next_episode: Episode number to predict for. Defaults to 9, the
            value that was previously hard-coded.
    """
    # predict() expects a 2D (n_samples, n_features) array; a bare scalar is
    # rejected, so wrap the episode number as a single one-feature sample.
    query = [[next_episode]]

    regr1 = linear_model.LinearRegression()
    regr1.fit(x1, y1)
    predicted_value1 = regr1.predict(query)
    print(predicted_value1)

    regr2 = linear_model.LinearRegression()
    regr2.fit(x2, y2)
    predicted_value2 = regr2.predict(query)

    # Each prediction is a one-element array; compare the single values.
    if predicted_value1[0] > predicted_value2[0]:
        print ("The Flash Tv Show will have more viewers for next week")
    else:
        print ("Arrow Tv Show will have more viewers for next week")
 
# Load the per-show episode/viewer lists from the spreadsheet, then report
# which show is predicted to have more viewers.
x1,y1,x2,y2 = get_data('G:/input_data_2.xlsx')
# Debug aid (disabled): print x1,y1,x2,y2
more_viewers(x1,y1,x2,y2)

输出:

 

posted on 2016-11-26 15:10  xcshehe  阅读(1265)  评论(0)  编辑  收藏  举报

导航