11.1

实验四:smo算法实现与测试

一、实验目的

深入理解支持向量机(SVM)的算法原理,能够使用 Python 语言实现支持向量机的训练与测试,并且使用五折交叉验证算法进行模型训练与评估。

二、实验内容

(1)从 scikit-learn 库中加载 iris 数据集,使用留出法留出 1/3 的样本作为测试集(注

意同分布取样);

(2)使用训练集训练支持向量机—SMO 分类算法;

(3)使用五折交叉验证对模型性能(准确度、精度、召回率和 F1 值)进行评估和选

择;

(4)使用测试集,测试模型的性能,对测试结果进行分析,完成实验报告中实验四的

部分。

三、算法步骤、代码、及结果

   1. 算法伪代码

函数 SMO(训练集 D, 容忍度 ε, 常数 C):

    初始化 α[i] = 0, b = 0, 对所有数据点 i

    循环直到满足停止条件:

        # 在训练集上选择两个不同的样本

        选择一个样本 i,使其违反KKT条件

        如果没有找到合适的样本,返回当前模型

        

        选择第二个样本 j,使其与样本 i 不相同

        计算 Ei = f(xi) - yi Ej = f(xj) - yj

 

        记录 αi 和 αj 的当前值

        

        计算边界条件 L H 以确保 αi 和 αj 在合法范围内

        计算学习率 eta

        计算 αj 的更新值 Δαj

        计算 αi 的更新值 Δαi

 

        更新 αi 和 αj,并确保它们位于 [0, C] 范围内

        

        计算 b 的更新值 (如果需要)

        重新计算 α[i] 和 α[j] 对应的误差 Ei Ej

 

        如果 αi 或 αj 有足够大的变化,继续迭代

        否则,如果两者变化都太小,停止

    返回最终的 α 和 b 参数

 

函数 f(x)

    计算模型的预测值,f(x) = Σ(αi * yi * K(xi, x)) + b

 

   2. 算法主要代码

完整源代码\调用库方法(函数参数说明)

 

复制代码
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
import numpy as np

# 1. 数据加载和预处理
# 使用 scikit-learn 中自带的鸢尾花数据集
data = datasets.load_iris()
X = data.data  # 特征
y = data.target  # 标签

# 数据标准化(SVM对特征缩放敏感)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 2. 构建支持向量机模型
svm_model = SVC(kernel='linear', random_state=42)

# 3. 使用五折交叉验证进行模型训练与评估
# cross_val_score 会自动进行五折交叉验证
cv_scores = cross_val_score(svm_model, X_scaled, y, cv=5)

# 输出每一折的准确率
print("每一折的准确率:", cv_scores)

# 输出平均准确率
print("五折交叉验证的平均准确率:", np.mean(cv_scores))

# 4. 训练与测试
# 划分训练集和测试集(80% 训练,20% 测试)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# 训练 SVM 模型
svm_model.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = svm_model.predict(X_test)

# 计算准确率
test_accuracy = accuracy_score(y_test, y_pred)
print("在测试集上的准确率:", test_accuracy)
复制代码

 

3. 训练结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1

 

 

posted @   奶油冰激凌  阅读(7)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 零经验选手,Compose 一天开发一款小游戏!
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!
点击右上角即可分享
微信分享提示