数学之路(3)-机器学习(3)-机器学习算法-SVM[7]

 

本博客所有内容是原创,未经书面许可,严禁任何形式的转载

http://blog.csdn.net/u010255642

 

 

根据SMO算法的描述,用Python实现,部分代码如下。代码定义了一个svm_pmcp类,所有运算都在svm_pmcp类中完成,这样便于封装和实际应用。

 

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#麦好:myhaspl@qq.com
#http://blog.csdn.net/u010255642
#svm算法
import math

import matplotlib.pyplot as plt
import numpy as np

#内积线性核函数
def arraydot(x, y):
    '''Linear (inner-product) kernel: return the scalar dot product x . y.

    Both arguments are flattened first, so 1-D arrays, lists/tuples and
    (n, 1) column vectors all yield a plain scalar.  The previous ``x.T*y``
    was an element-wise product for 1-D ndarrays (and an outer product via
    broadcasting for (n, 1) arrays), not an inner product.
    '''
    return np.dot(np.ravel(x), np.ravel(y))

#svm参数与计算类
class svm_pmcp:
    '''SVM trainer based on Platt's Sequential Minimal Optimization (SMO).

    Decision function: u(x) = sum_j alpha_j * y_j * K(x_j, x) - b, where K is
    the kernel set with kernel_init().  Before calling mainroutine() the caller
    must set the samples (samples_init), the kernel (kernel_init), the KKT
    tolerance (tol_init), the numerical epsilon (eps_init / eps_inbit) and the
    box constraint C (c_init).  Labels are expected to be +1 / -1.
    '''

    def __init__(self):
        '''Initialise empty sample/label/multiplier containers.'''
        self.alpha = []        # Lagrange multipliers, one per sample
        self.samples = []      # training inputs
        self.labels = []       # training labels (+1 / -1)
        self.boundalpha = []   # indices with 0 < alpha < C, refreshed by alpha_nozero_noc()
        self.b = 0             # threshold in u(x) = w.x - b
        self.w = None          # sum_i alpha_i y_i x_i, filled by update_w()

    def samples_init(self, samples):
        '''Append (sample, label) pairs and give every new sample alpha = 0.

        Also resets the threshold b to 0.
        '''
        added = 0
        for mysp, mylb in samples:
            self.samples.append(mysp)
            self.labels.append(mylb)
            added += 1
        # Only the newly added samples get a multiplier; appending one zero per
        # *total* sample would desynchronise alpha on a second call.
        self.alpha.extend([0.0] * added)
        self.b = 0

    def kernel_init(self, func):
        '''Set the kernel function K(x, y); it must return a scalar.'''
        self.kernel_func = func

    def _kernel(self, i, j):
        '''Return K(x_i, x_j) as a Python float (accepts 1x1 array results too).'''
        return float(self.kernel_func(self.samples[i], self.samples[j]))

    def lagrange_multiplier(self, i):
        '''Return the current Lagrange multiplier alpha_i.'''
        return self.alpha[i]

    def svmoutput(self, i):
        '''Return the SVM output u_i = sum_j alpha_j y_j K(x_j, x_i) - b.'''
        total = 0.0
        for j, aj in enumerate(self.alpha):
            if aj != 0:  # zero multipliers contribute nothing; skip the kernel call
                total += aj * self.labels[j] * self._kernel(j, i)
        return total - self.b

    def tol_init(self, mytol):
        '''Set the KKT-violation tolerance.'''
        self.tol = mytol

    def eps_inbit(self, myeps):
        '''Set the numerical epsilon used to detect negligible alpha changes.'''
        self.eps = myeps

    def eps_init(self, myeps):
        '''Alias of eps_inbit, named consistently with the other *_init setters.'''
        self.eps_inbit(myeps)

    def c_init(self, myc):
        '''Set the box constraint C (0 <= alpha_i <= C).'''
        self.c = myc

    def choicesecond_max(self, nte):
        '''Second-choice heuristic: non-bound index j maximising |E_j - nte|.

        nte is the prediction error E = u - y of the first-chosen example.
        Returns None when no multiplier is strictly between 0 and C.
        '''
        best, bestgap = None, -1.0
        for j in self.alpha_nozero_noc():
            gap = abs(self.svmoutput(j) - self.labels[j] - nte)
            if gap > bestgap:
                best, bestgap = j, gap
        return best

    def choicesecond_random(self, i=None):
        '''Return a random sample index different from i, or None if none exists.'''
        import random
        candidates = [j for j in range(len(self.samples)) if j != i]
        if not candidates:
            return None
        return random.choice(candidates)

    def get_lh(self, i, j):
        '''Return the bounds (L, H) for the new alpha[i] when optimised with alpha[j].

        They keep both multipliers in [0, C] while preserving
        y_i*alpha_i + y_j*alpha_j.
        '''
        ai, aj = self.alpha[i], self.alpha[j]
        if self.labels[i] != self.labels[j]:
            return max(0.0, ai - aj), min(self.c, self.c + ai - aj)
        return max(0.0, ai + aj - self.c), min(self.c, ai + aj)

    def update_b(self, a1=None, a2=None, b1=None, b2=None):
        '''Update the threshold b after a joint step (Platt's rule) and return it.

        a1, a2 are the step's new multipliers; b1, b2 the thresholds that make
        the first / second example satisfy u == y exactly.  Without b1/b2 the
        threshold is left unchanged.
        '''
        if b1 is None or b2 is None:
            return self.b
        if 0 < a1 < self.c:
            self.b = b1
        elif 0 < a2 < self.c:
            self.b = b2
        else:
            # Both at a bound: any b between b1 and b2 is consistent with KKT.
            self.b = 0.5 * (b1 + b2)
        return self.b

    def update_w(self):
        '''Recompute and return w = sum_i alpha_i y_i x_i.

        Only meaningful for the linear kernel; assumes numeric sample vectors.
        '''
        w = None
        for a, y, x in zip(self.alpha, self.labels, self.samples):
            term = a * y * np.asarray(x, dtype=float)
            w = term if w is None else w + term
        self.w = w
        return w

    def alpha_nozero_noc(self):
        '''Return (and cache in boundalpha) indices i with 0 < alpha_i < C.'''
        self.boundalpha = [i for i, a in enumerate(self.alpha) if 0 < a < self.c]
        return self.boundalpha

    def store_alpha(self, i1, a1, i2, a2):
        '''Store the two new multipliers and refresh the non-bound index list.'''
        self.alpha[i1] = a1
        self.alpha[i2] = a2
        self.alpha_nozero_noc()

    def _objective_at(self, a2, alpha1, alpha2, y1, y2, e1, e2, k11, k12, k22):
        '''Platt's dual objective Psi (minimised) restricted to the pair, at alpha2 = a2.'''
        s = y1 * y2
        f1 = y1 * (e1 + self.b) - alpha1 * k11 - s * alpha2 * k12
        f2 = y2 * (e2 + self.b) - s * alpha1 * k12 - alpha2 * k22
        a1 = alpha1 + s * (alpha2 - a2)
        return (a1 * f1 + a2 * f2 + 0.5 * a1 * a1 * k11
                + 0.5 * a2 * a2 * k22 + s * a2 * a1 * k12)

    def takestep(self, i1, i2, e2=None, alpha2=None):
        '''Jointly optimise alpha[i1] and alpha[i2]; return True if they changed.

        e2 / alpha2 are the (already computed) error and multiplier of i2;
        they are recomputed when omitted.
        '''
        if i1 == i2:
            return False
        alpha1 = self.lagrange_multiplier(i1)
        if alpha2 is None:
            alpha2 = self.lagrange_multiplier(i2)
        y1 = self.labels[i1]
        y2 = self.labels[i2]
        e1 = self.svmoutput(i1) - y1
        if e2 is None:
            e2 = self.svmoutput(i2) - y2
        s = y1 * y2
        l, h = self.get_lh(i2, i1)
        if l == h:
            return False
        k11 = self._kernel(i1, i1)
        k12 = self._kernel(i1, i2)
        k22 = self._kernel(i2, i2)
        eta = 2 * k12 - k11 - k22
        if eta < 0:
            # Unconstrained optimum along the constraint line, clipped to [L, H].
            a2 = alpha2 - y2 * (e1 - e2) / eta
            a2 = min(max(a2, l), h)
        else:
            # Degenerate curvature: move to whichever end has the lower objective.
            lobj = self._objective_at(l, alpha1, alpha2, y1, y2, e1, e2, k11, k12, k22)
            hobj = self._objective_at(h, alpha1, alpha2, y1, y2, e1, e2, k11, k12, k22)
            if lobj < hobj - self.eps:
                a2 = l
            elif lobj > hobj + self.eps:
                a2 = h
            else:
                a2 = alpha2
        if abs(a2 - alpha2) < self.eps * (a2 + alpha2 + self.eps):
            return False
        a1 = alpha1 + s * (alpha2 - a2)
        a1 = min(max(a1, 0.0), self.c)  # guard against floating-point round-off
        b1 = e1 + y1 * (a1 - alpha1) * k11 + y2 * (a2 - alpha2) * k12 + self.b
        b2 = e2 + y1 * (a1 - alpha1) * k12 + y2 * (a2 - alpha2) * k22 + self.b
        self.update_b(a1, a2, b1, b2)
        self.store_alpha(i1, a1, i2, a2)
        self.update_w()
        return True

    def examineexample(self, myi):
        '''Try to make progress on sample myi; return 1 if a step was taken, else 0.'''
        y2 = self.labels[myi]
        alpha2 = self.lagrange_multiplier(myi)
        e2 = self.svmoutput(myi) - y2
        r2 = e2 * y2
        # Only work on samples that violate the KKT conditions within tol.
        if not ((r2 < -self.tol and alpha2 < self.c) or (r2 > self.tol and alpha2 > 0)):
            return 0
        nonbound = self.alpha_nozero_noc()
        if len(nonbound) > 1:
            myj = self.choicesecond_max(e2)
            if myj is not None and self.takestep(myj, myi, e2, alpha2):
                return 1
        start = self.choicesecond_random(myi)
        if start is None:
            return 0
        # Fall back to non-bound samples, then all samples, each from a random start.
        for pool in (nonbound, list(range(len(self.samples)))):
            if not pool:
                continue
            k = start % len(pool)
            for myj in pool[k:] + pool[:k]:
                if self.takestep(myj, myi, e2, alpha2):
                    return 1
        return 0

    def loop1(self, nc):
        '''Examine every sample; return nc plus the number of successful steps.'''
        for i in range(len(self.samples)):
            nc += self.examineexample(i)
        return nc

    def loop2(self, nc):
        '''Examine only non-bound samples; return nc plus the number of successful steps.'''
        for i in self.alpha_nozero_noc():
            nc += self.examineexample(i)
        return nc

    def mainroutine(self):
        '''Run SMO until a full pass over all samples changes nothing.'''
        numchanged = 0
        examineall = True
        while numchanged > 0 or examineall:
            if examineall:
                numchanged = self.loop1(0)
            else:
                numchanged = self.loop2(0)
            # Platt's rule: after a full pass, switch to non-bound passes; go back
            # to a full pass only once a non-bound pass makes no change.
            if examineall:
                examineall = False
            elif numchanged == 0:
                examineall = True










def mainsvm(mysamples, kernel=None, tol=0.001, eps=0.00001, c=1):
    '''Train an SVM on (sample, label) pairs and return the trained svm_pmcp.

    mysamples   -- iterable of (sample, label) pairs; labels +1 / -1.
    kernel      -- kernel function K(x, y); defaults to arraydot.
    tol, eps, c -- KKT tolerance, numerical epsilon and box constraint C;
                   the defaults are the values this function always used.
    '''
    if kernel is None:
        kernel = arraydot  # looked up at call time, as before
    mysvm = svm_pmcp()
    mysvm.samples_init(mysamples)
    mysvm.kernel_init(kernel)
    mysvm.tol_init(tol)
    # eps_inbit is the epsilon setter defined on svm_pmcp; calling eps_init
    # raised AttributeError.
    mysvm.eps_inbit(eps)
    mysvm.c_init(c)
    mysvm.mainroutine()
    return mysvm









后面关于SVM的章节将提供该类的下载地址及调用代码。

 

 

posted @ 2013-07-22 19:34  爱生活,爱编程  阅读(238)  评论(0)  编辑  收藏  举报