kaggle House_Price_XGBoost

kaggle House_Price_final

代码

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from xgboost import XGBRegressor

# The imputer class moved between scikit-learn releases:
# sklearn.impute.SimpleImputer was added in 0.20 and the legacy
# sklearn.preprocessing.Imputer was removed in 0.22. Import whichever one
# exists and expose it under both names so either spelling works.
try:
    from sklearn.impute import SimpleImputer
    Imputer = SimpleImputer
except ImportError:
    from sklearn.preprocessing import Imputer
    SimpleImputer = Imputer

# Input CSV locations and the destination of the submission file.
train_path = r"C:\Users\cbattle\Desktop\train.csv"
test_path = r"C:\Users\cbattle\Desktop\test.csv"
out_path = r"C:\Users\cbattle\Desktop\out.csv"

# Load both datasets and report their shapes.
train, test = (pd.read_csv(csv_path) for csv_path in (train_path, test_path))
for label, frame in (('train:', train), ('test :', test)):
    print(label, frame.shape)

# Split into features (X / Xtest) and target (y). 'Id' is excluded from the
# feature matrices; 'SalePrice' is the value being predicted.
y = train['SalePrice']
X = train.drop(labels=['Id', 'SalePrice'], axis=1)
Xtest = test.drop(labels=['Id'], axis=1)
for label, frame in (('X    :', X), ('y    :', y), ('Xtest:', Xtest)):
    print(label, frame.shape)
# Debug aid: uncomment to list the dtype of every feature column.
# for col in X:
#     print(X[col].dtype, col)

# Feature selection: keep every int64/float64 column, plus text (object)
# columns that have fewer than 10 distinct values in the training data, so
# one-hot encoding does not produce an excessive number of columns.
# (This selection previously appeared twice in a row; once is sufficient.)
key = [
    col for col in X
    if X[col].dtype in ['int64', 'float64']
    or (X[col].dtype == 'object' and X[col].nunique() < 10)
]
X = X[key]
Xtest = Xtest[key]

# One-hot encoding
print(X.shape, Xtest.shape)
X = pd.get_dummies(X)
Xtest = pd.get_dummies(Xtest)
# Give the test matrix exactly the training columns, in the same order.
# Indicator columns for categories seen only in the test set are dropped.
# Categories seen only in the training set become all-zero indicator columns
# rather than NaN, which the imputation step that follows would otherwise
# replace with the training-set mean of that indicator.
Xtest = Xtest.reindex(columns=X.columns, fill_value=0)
print(X.shape, Xtest.shape)

# Fill missing values with the per-column mean learned from the training set;
# the same fitted statistics are then applied to the test set.
# SimpleImputer is the supported replacement for the legacy
# sklearn.preprocessing.Imputer, which was removed in scikit-learn 0.22.
my_imputer = SimpleImputer(strategy='mean')
X = my_imputer.fit_transform(X)
Xtest = my_imputer.transform(Xtest)
print(X.shape, Xtest.shape)

# Baseline kept for reference: a plain decision tree.
# decisionTree = DecisionTreeRegressor()
# decisionTree.fit(X,y)
# ans = decisionTree.predict(Xtest)

# XGBoost with default hyper-parameters: fit on the full training set, then
# predict sale prices for the test set.
xgb = XGBRegressor().fit(X, y, verbose=False)
ans = xgb.predict(Xtest)

# Alternative kept for reference: more trees with early stopping. It needs a
# train/validation split (train_X, train_y, val_X, val_y) that is not created
# in the code shown here.
# my_model = XGBRegressor(n_estimators=1000)
# my_model.fit(train_X, train_y, early_stopping_rounds=5,
#              eval_set=[(val_X, val_y)], verbose=False)
# ans = my_model.predict(Xtest)

# Write the submission file: one row per test 'Id' with its predicted
# 'SalePrice'.
myAns = pd.DataFrame({'Id': test['Id'], 'SalePrice': ans})
# Use out_path, defined alongside the input paths, instead of repeating the
# same literal here so the destination is configured in a single place.
myAns.to_csv(out_path, index=False)
print('ok')
posted @ 2018-04-12 22:22  cbattle  阅读(385)  评论(0)  编辑  收藏  举报