import datetime
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestRegressor
# Script: tune a random-forest regressor for the "当天最高温度" column of a CSV
# with a cross-validated grid search, then plot predictions against the
# held-out ground truth.

# Select the SimHei font for sans-serif text so the Chinese title and legend
# strings below can be rendered.
plt.rcParams["font.sans-serif"] = ["SimHei"]

file_path = r"../../机器学习数据/data_temps1.csv"
df = pd.read_csv(file_path)
# One-hot encode non-numeric columns so the frame can be fed to the regressor.
df = pd.get_dummies(df)

# 70/30 shuffled split; the fixed seed keeps the partition stable across runs.
data = train_test_split(df, shuffle=True, test_size=0.3, random_state=100)
train_data, test_data = data

# "当天最高温度" is the regression target; every other column is a feature.
train_feature = train_data.drop(["当天最高温度"], axis=1)
train_label = train_data["当天最高温度"]
test_feature = test_data.drop(["当天最高温度"], axis=1)
test_label = test_data["当天最高温度"]

# Hyper-parameter grid: 10 forest sizes x 2 depths x 2 bootstrap settings.
n_estimators = list(range(10, 101, 10))
max_depth = [2, 4]
bootstrap = [True, False]
param_grid = {
    "n_estimators": n_estimators,
    "max_depth": max_depth,
    "bootstrap": bootstrap,
}

# Seed the forest as well: without it the selected parameters and the scores
# printed below change from run to run even though the split is fixed.
rf = RandomForestRegressor(random_state=100)
clf = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, verbose=5)
clf.fit(train_feature, train_label)

print(clf.best_params_)
# Scores of the refit best estimator on the training and held-out data.
print(clf.score(train_feature, train_label))
print(clf.score(test_feature, test_label))

pre_label = clf.predict(test_feature)
# Drop the pandas index so both curves are plotted against position 0..n-1.
test_label = test_label.to_numpy()

# Attach each label to its own line. The previous code passed a *set* of
# labels to plt.legend(); set iteration order is arbitrary, so the two legend
# entries could be swapped between runs.
plt.plot(pre_label, label="预测曲线")
plt.plot(test_label, label="真实曲线")
plt.title("拟合图")
plt.legend()
plt.show()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通