# 7 useful pandas code snippets

# Scatter-plot each numerical feature against the target SalePrice to
# visualise whether the relationship looks linear.
# NOTE(review): cols_to_use[:-1] assumes the last entry of cols_to_use is
# 'SalePrice' (the target) -- confirm against where cols_to_use is defined.
for col in cols_to_use[:-1]:
    data.plot.scatter(x=col, y='SalePrice', ylim=(0, 800000))
    plt.show()

# Fit a univariate linear regression of the target on a single feature and
# report the mean squared error on the train and test splits.
col = 'OverallQual'
linreg = LinearRegression()
# to_frame() turns the selected Series into a one-column DataFrame, the 2-D
# input shape fit/predict expect.
linreg.fit(X_train[col].to_frame(), y_train)

print('Train set')
pred = linreg.predict(X_train[col].to_frame())
print('Linear Regression mse: {}'.format(mean_squared_error(y_train, pred)))

print('Test set')
pred = linreg.predict(X_test[col].to_frame())
print('Linear Regression mse: {}'.format(mean_squared_error(y_test, pred)))
print()

# Residuals on the test set. The model was fit against the separate target
# y_train, so the matching target here is y_test, not a 'SalePrice' column
# of the feature matrix X_test.
# NOTE(review): this adds an 'error' column to X_test; drop it before reusing
# X_test as model input.
X_test['error'] = y_test - pred
print('Error Stats')
print(X_test['error'].describe())
X_test.plot.scatter(x=col, y='error')
# Display the residual plot, consistent with the feature plots above.
plt.show()

 

# Posted 2020-02-24 21:23 by 纯洁的小兄弟 (112 views, 0 comments)