# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, make_scorer
)
# Step 1: 加载 iris 数据集并分割为训练集和测试集
iris = load_iris() # 加载数据集
X, y = iris.data, iris.target # 提取特征和标签
# 留出法分割数据集,测试集占 1/3,保证同分布
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=1/3, random_state=42, stratify=y
)
# Step 2: 初始化并训练随机森林分类器
# 初始化 RandomForestClassifier 模型
rf = RandomForestClassifier(n_estimators=100, random_state=42)
# 使用训练集训练模型
rf.fit(X_train, y_train)
# Step 3: 使用五折交叉验证评估模型性能
# 定义评估指标
scoring = {
'accuracy': make_scorer(accuracy_score),
'precision_macro': make_scorer(precision_score, average='macro'),
'recall_macro': make_scorer(recall_score, average='macro'),
'f1_macro': make_scorer(f1_score, average='macro')
}
# 五折交叉验证
cv_results = cross_validate(rf, X_train, y_train, cv=5, scoring=scoring)
# 打印交叉验证结果
print("五折交叉验证结果:")
for metric in scoring.keys():
mean = cv_results['test_' + metric].mean()
std = cv_results['test_' + metric].std()
print(f"{metric}: {mean:.4f} ± {std:.4f}")
# Step 4: 测试集评估模型性能
# 测试集预测
y_test_pred = rf.predict(X_test)
# 计算性能指标
print("\n测试集性能报告:")
print(f"准确度: {accuracy_score(y_test, y_test_pred):.4f}")
print(f"精度: {precision_score(y_test, y_test_pred, average='macro'):.4f}")
print(f"召回率: {recall_score(y_test, y_test_pred, average='macro'):.4f}")
print(f"F1 值: {f1_score(y_test, y_test_pred, average='macro'):.4f}")
package com.men.common;
import java.io.*;
/**
* 文件读取工具类
*/
public class FileUtil
{
/**
* 读取文件内容,作为字符串返回
*/
public static String readFileAsString(String filePath) throws IOException
{
File file = new File(filePath);
if (!file.exists())
{
throw new FileNotFoundException(filePath);
}
if (file.length() > 1024 * 1024 * 1024)
{
throw new IOException("File is too large");
}
StringBuilder sb = new StringBuilder((int) (file.length()));
// 创建字节输入流
FileInputStream fis = new FileInputStream(filePath);
// 创建一个长度为10240的Buffer
byte[] bbuf = new byte[10240];
// 用于保存实际读取的字节数
int hasRead = 0;
while ((hasRead = fis.read(bbuf)) > 0)
{
sb.append(new String(bbuf, 0, hasRead));
}
fis.close();
return sb.toString();
}
/**
* 根据文件路径读取byte[] 数组
*/
public static byte[] readFileByBytes(String filePath) throws IOException
{
File file = new File(filePath);
if (!file.exists())
{
throw new FileNotFoundException(filePath);
} else
{
ByteArrayOutputStream bos = new ByteArrayOutputStream((int) file.length());
BufferedInputStream in = null;
try
{
in = new BufferedInputStream(new FileInputStream(file));
short bufSize = 1024;
byte[] buffer = new byte[bufSize];
int len1;
while (-1 != (len1 = in.read(buffer, 0, bufSize)))
{
bos.write(buffer, 0, len1);
}
byte[] var7 = bos.toByteArray();
return var7;
}
finally
{
try
{
if (in != null)
{
in.close();
}
}
catch (IOException var14)
{
var14.printStackTrace();
}
bos.close();
}
}
}
}
package com.men.common;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
/**
* Http 工具类
*/
public class HttpUtil {
/**
* 发送 POST 请求,默认 contentType 为 application/json
* @param requestUrl 请求 URL
* @param accessToken 百度的 accessToken
* @param params 请求参数 (JSON 格式)
* @return 请求响应的字符串
* @throws Exception
*/
public static String post(String requestUrl, String accessToken, String params)
throws Exception {
String contentType = "application/json"; // 对应 API 需要传递 JSON 格式数据
return HttpUtil.post(requestUrl, accessToken, contentType, params);
}
/**
* 发送 POST 请求,允许指定 contentType
* @param requestUrl 请求 URL
* @param accessToken 百度的 accessToken
* @param contentType 请求内容的类型
* @param params 请求参数 (JSON 格式)
* @return 请求响应的字符串
* @throws Exception
*/
public static String post(String requestUrl, String accessToken, String contentType, String params)
throws Exception {
String encoding = "UTF-8"; // 默认编码为 UTF-8
return HttpUtil.post(requestUrl, accessToken, contentType, params, encoding);
}
/**
* 发送 POST 请求,支持指定请求的编码
* @param requestUrl 请求 URL
* @param accessToken 百度的 accessToken
* @param contentType 请求内容的类型
* @param params 请求参数 (JSON 格式)
* @param encoding 编码格式
* @return 请求响应的字符串
* @throws Exception
*/
public static String post(String requestUrl, String accessToken, String contentType, String params, String encoding)
throws Exception {
// 为了在请求 URL 中包含 access_token
String url = requestUrl + "?access_token=" + accessToken;
return HttpUtil.postGeneralUrl(url, contentType, params, encoding);
}
/**
* 通用的 POST 请求方法
* @param generalUrl 请求的 URL
* @param contentType 请求内容类型
* @param params 请求的参数 (JSON 格式)
* @param encoding 编码方式
* @return 请求响应的字符串
* @throws Exception
*/
public static String postGeneralUrl(String generalUrl, String contentType, String params, String encoding)
throws Exception {
// 连接目标 URL
URL url = new URL(generalUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
// 设置请求方法为 POST
connection.setRequestMethod("POST");
// 设置请求头,指定 Content-Type 为 application/json
connection.setRequestProperty("Content-Type", contentType);
connection.setRequestProperty("Connection", "Keep-Alive");
connection.setUseCaches(false);
connection.setDoOutput(true); // 允许输出内容
connection.setDoInput(true); // 允许输入内容
// 写入请求参数(JSON)
try (DataOutputStream out = new DataOutputStream(connection.getOutputStream())) {
out.write(params.getBytes(encoding)); // 发送请求参数
out.flush();
}
// 建立连接
connection.connect();
// 获取响应内容
StringBuilder result = new StringBuilder();
try (BufferedReader in = new BufferedReader(
new InputStreamReader(connection.getInputStream(), encoding))) {
String getLine;
while ((getLine = in.readLine()) != null) {
result.append(getLine);
}
}
return result.toString();
}
}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人