2024.12.17

# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, make_scorer
)

# Step 1: 加载 iris 数据集并分割为训练集和测试集
iris = load_iris() # 加载数据集
X, y = iris.data, iris.target # 提取特征和标签

# 留出法分割数据集,测试集占 1/3,保证同分布
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=1/3, random_state=42, stratify=y
)

# Step 2: 初始化并训练随机森林分类器
# 初始化 RandomForestClassifier 模型
rf = RandomForestClassifier(n_estimators=100, random_state=42)

# 使用训练集训练模型
rf.fit(X_train, y_train)

# Step 3: 使用五折交叉验证评估模型性能
# 定义评估指标
scoring = {
'accuracy': make_scorer(accuracy_score),
'precision_macro': make_scorer(precision_score, average='macro'),
'recall_macro': make_scorer(recall_score, average='macro'),
'f1_macro': make_scorer(f1_score, average='macro')
}

# 五折交叉验证
cv_results = cross_validate(rf, X_train, y_train, cv=5, scoring=scoring)

# 打印交叉验证结果
print("五折交叉验证结果:")
for metric in scoring.keys():
mean = cv_results['test_' + metric].mean()
std = cv_results['test_' + metric].std()
print(f"{metric}: {mean:.4f} ± {std:.4f}")

# Step 4: 测试集评估模型性能
# 测试集预测
y_test_pred = rf.predict(X_test)

# 计算性能指标
print("\n测试集性能报告:")
print(f"准确度: {accuracy_score(y_test, y_test_pred):.4f}")
print(f"精度: {precision_score(y_test, y_test_pred, average='macro'):.4f}")
print(f"召回率: {recall_score(y_test, y_test_pred, average='macro'):.4f}")
print(f"F1 值: {f1_score(y_test, y_test_pred, average='macro'):.4f}")
package com.men.common;

import java.io.*;

/**
* 文件读取工具类
*/
public class FileUtil
{

/**
* 读取文件内容,作为字符串返回
*/
public static String readFileAsString(String filePath) throws IOException
{
File file = new File(filePath);
if (!file.exists())
{
throw new FileNotFoundException(filePath);
}

if (file.length() > 1024 * 1024 * 1024)
{
throw new IOException("File is too large");
}

StringBuilder sb = new StringBuilder((int) (file.length()));
// 创建字节输入流
FileInputStream fis = new FileInputStream(filePath);
// 创建一个长度为10240的Buffer
byte[] bbuf = new byte[10240];
// 用于保存实际读取的字节数
int hasRead = 0;
while ((hasRead = fis.read(bbuf)) > 0)
{
sb.append(new String(bbuf, 0, hasRead));
}
fis.close();
return sb.toString();
}

/**
* 根据文件路径读取byte[] 数组
*/
public static byte[] readFileByBytes(String filePath) throws IOException
{
File file = new File(filePath);
if (!file.exists())
{
throw new FileNotFoundException(filePath);
} else
{
ByteArrayOutputStream bos = new ByteArrayOutputStream((int) file.length());
BufferedInputStream in = null;

try
{
in = new BufferedInputStream(new FileInputStream(file));
short bufSize = 1024;
byte[] buffer = new byte[bufSize];
int len1;
while (-1 != (len1 = in.read(buffer, 0, bufSize)))
{
bos.write(buffer, 0, len1);
}

byte[] var7 = bos.toByteArray();
return var7;
}
finally
{
try
{
if (in != null)
{
in.close();
}
}
catch (IOException var14)
{
var14.printStackTrace();
}

bos.close();
}
}
}
}
package com.men.common;

import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;

/**
* Http 工具类
*/
public class HttpUtil {

/**
* 发送 POST 请求,默认 contentType 为 application/json
* @param requestUrl 请求 URL
* @param accessToken 百度的 accessToken
* @param params 请求参数 (JSON 格式)
* @return 请求响应的字符串
* @throws Exception
*/
public static String post(String requestUrl, String accessToken, String params)
throws Exception {
String contentType = "application/json"; // 对应 API 需要传递 JSON 格式数据
return HttpUtil.post(requestUrl, accessToken, contentType, params);
}

/**
* 发送 POST 请求,允许指定 contentType
* @param requestUrl 请求 URL
* @param accessToken 百度的 accessToken
* @param contentType 请求内容的类型
* @param params 请求参数 (JSON 格式)
* @return 请求响应的字符串
* @throws Exception
*/
public static String post(String requestUrl, String accessToken, String contentType, String params)
throws Exception {
String encoding = "UTF-8"; // 默认编码为 UTF-8
return HttpUtil.post(requestUrl, accessToken, contentType, params, encoding);
}

/**
* 发送 POST 请求,支持指定请求的编码
* @param requestUrl 请求 URL
* @param accessToken 百度的 accessToken
* @param contentType 请求内容的类型
* @param params 请求参数 (JSON 格式)
* @param encoding 编码格式
* @return 请求响应的字符串
* @throws Exception
*/
public static String post(String requestUrl, String accessToken, String contentType, String params, String encoding)
throws Exception {
// 为了在请求 URL 中包含 access_token
String url = requestUrl + "?access_token=" + accessToken;
return HttpUtil.postGeneralUrl(url, contentType, params, encoding);
}

/**
* 通用的 POST 请求方法
* @param generalUrl 请求的 URL
* @param contentType 请求内容类型
* @param params 请求的参数 (JSON 格式)
* @param encoding 编码方式
* @return 请求响应的字符串
* @throws Exception
*/
public static String postGeneralUrl(String generalUrl, String contentType, String params, String encoding)
throws Exception {
// 连接目标 URL
URL url = new URL(generalUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();

// 设置请求方法为 POST
connection.setRequestMethod("POST");

// 设置请求头,指定 Content-Type 为 application/json
connection.setRequestProperty("Content-Type", contentType);
connection.setRequestProperty("Connection", "Keep-Alive");
connection.setUseCaches(false);
connection.setDoOutput(true); // 允许输出内容
connection.setDoInput(true); // 允许输入内容

// 写入请求参数(JSON)
try (DataOutputStream out = new DataOutputStream(connection.getOutputStream())) {
out.write(params.getBytes(encoding)); // 发送请求参数
out.flush();
}

// 建立连接
connection.connect();

// 获取响应内容
StringBuilder result = new StringBuilder();
try (BufferedReader in = new BufferedReader(
new InputStreamReader(connection.getInputStream(), encoding))) {
String getLine;
while ((getLine = in.readLine()) != null) {
result.append(getLine);
}
}

return result.toString();
}
}
posted @   我也不想的  阅读(8)  评论(0编辑  收藏  举报
(评论功能已被禁用)
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示