alex_bn_lee

导航

< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

统计

【364】SVM 通过 sklearn 可视化实现

先看下效果图:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# 先调入需要的模块
 
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
import seaborn as sb
 
# 生成几个数据点
 
data = np.array([
    [0.1, 0.7],
    [0.3, 0.6],
    [0.4, 0.1],
    [0.5, 0.4],
    [0.8, 0.04],
    [0.42, 0.6],
    [0.9, 0.4],
    [0.6, 0.5],
    [0.7, 0.2],
    [0.7, 0.67],
    [0.27,0.8],
    [0.5, 0.72]
    ])
     
 
target = [1] * 6 + [0] * 6
 
x_line = np.linspace(0, 1, 100)
y_line = 1 - x_line
plt.scatter(data[:6, 0], data[:6, 1], marker='o', s=100, lw=3)
plt.scatter(data[6:, 0], data[6:, 1], marker='x', s=100, lw=3)
plt.plot(x_line, y_line)
 
# 定义计算域、文字说明等
 
C = 0.0001  # SVM regularization parameter, since Scikit-learn doesn't allow C=0
# linear_svc = svm.SVC(kernel='linear', C=C).fit(data, target)
 
# create a mesh to plot in
h = 0.002
x_min, x_max = data[:, 0].min() - 0.2, data[:, 0].max() + 0.2
y_min, y_max = data[:, 1].min() - 0.2, data[:, 1].max() + 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))
 
# title for the plots
titles = ['SVC with linear kernel',
          'SVC with RBF kernel',
          'SVC with polynomial (degree 3) kernel']
 
# RBF Kernel
 
plt.figure(figsize=(16, 15))
 
for i, gamma in enumerate([1, 5, 15, 35, 45, 55]):
    rbf_svc = svm.SVC(kernel='rbf', gamma=gamma, C=C).fit(data, target)
     
    # ravel - flatten
    # c_ - vstack
    # #把后面两个压扁之后变成了x1和x2,然后进行判断,得到结果在压缩成一个矩形
    Z = rbf_svc.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
     
    plt.subplot(3, 2, i + 1)
    plt.subplots_adjust(wspace=0.4, hspace=0.4)
    plt.contourf(xx, yy, Z, cmap=plt.cm.ocean, alpha=0.6)
 
    # Plot the training points
    plt.scatter(data[:6, 0], data[:6, 1], marker='o', color='r', s=100, lw=3)
    plt.scatter(data[6:, 0], data[6:, 1], marker='x', color='k', s=100, lw=3)
     
    plt.title('RBF SVM with $\gamma=$' + str(gamma))
     
plt.show()

 

posted on   McDelfino  阅读(3949)  评论(0编辑  收藏  举报

编辑推荐:
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
历史上的今天:
2018-01-29 【299】◀▶ IDL - LIST 函数
2018-01-29 【298】◀▶ IDL 系统过程&函数
点击右上角即可分享
微信分享提示