SMO推导和代码-记录毕业论文4

通过阅读 Platt 的论文和这篇博客（http://www.cnblogs.com/jerrylead/archive/2011/03/18/1988419.html），我大致弄懂了 SMO 的数学推导。推导过程以后再写，这里先贴出自己实现的 SMO 代码（MATLAB）。

function [ model ] = smoSolver( designMatrix, targetGroup, C, tolerance, total_runtimes )
%SMOSOLVER Train a linear-kernel soft-margin SVM with Platt's SMO algorithm.
%   model = smoSolver(designMatrix, targetGroup) trains on the n samples in
%   the rows of designMatrix, labelled by the n entries of targetGroup
%   (expected to be +1/-1; a row or a column vector is accepted).
%
%   model = smoSolver(designMatrix, targetGroup, C, tolerance, total_runtimes)
%   overrides the defaults (pass [] to keep a default):
%     C              - box constraint 0 <= alpha <= C            (default 1)
%     tolerance      - KKT-violation tolerance, also used to decide whether
%                      an alpha sits at a bound                  (default 0.001)
%     total_runtimes - maximum number of passes of the outer loop (default 5000)
%
%   The returned struct has fields
%     targetGroup - labels of the support vectors (1-by-nsv)
%     alpha       - their Lagrange multipliers (1-by-nsv)
%     supVec      - the support vectors, one per row
%     b           - threshold; the decision value of a row vector x is
%                   (alpha .* targetGroup) * (supVec * x') - b

if nargin < 3 || isempty(C), C = 1; end
if nargin < 4 || isempty(tolerance), tolerance = 0.001; end
if nargin < 5 || isempty(total_runtimes), total_runtimes = 5000; end
% Minimum relative progress for a step, and the tie band used when
% comparing the objective at both ends of the segment in takeStep.
epsilon = 0.01;

% Keep labels as a row so (alphaArray .* targetGroup) is 1-by-n; a column
% vector would silently broadcast to an n-by-n matrix.
targetGroup = targetGroup(:)';
n_samps = size(designMatrix, 1);
% Linear kernel: Gram matrix of all pairwise inner products.
kernelMatrix = designMatrix * designMatrix';

% Start at alpha = 0, which satisfies sum(alpha .* y) = 0. Every SMO step
% preserves that sum, so a non-feasible (e.g. random) start could never
% reach a feasible solution.
alphaArray = zeros(1, n_samps);
b = 0;
% Error cache E_i = u_i - y_i, with outputs u = (alpha .* y) * K - b.
E = (alphaArray .* targetGroup) * kernelMatrix - b - targetGroup;

numChanged = 0;
examineAll = true;
iter = 0;
while (numChanged > 0 || examineAll) && iter < total_runtimes
    numChanged = 0;
    for idx = 1 : n_samps
        % Alternate between sweeps over all samples and sweeps over the
        % non-bound samples (0 < alpha < C) only.
        if examineAll || (abs(alphaArray(idx)) > tolerance && abs(alphaArray(idx) - C) > tolerance)
            numChanged = numChanged + examineExample(idx);
        end
    end
    if examineAll
        examineAll = false;
    elseif numChanged == 0
        examineAll = true;
    end
    iter = iter + 1;
end

alphaIdx = find(abs(alphaArray) > tolerance);
model.targetGroup = targetGroup(alphaIdx);
model.alpha = alphaArray(alphaIdx);
model.supVec = designMatrix(alphaIdx, :);
model.b = b;

    function changed = examineExample(i2)
        % Look for a partner for sample i2 and optimise the pair.
        % Returns 1 if some alpha pair was changed, 0 otherwise.
        changed = 0;
        y2 = targetGroup(i2);
        alpha2 = alphaArray(i2);
        E2 = E(i2);
        r2 = E2 * y2;   % equals y2*u2 - 1 for labels +1/-1
        % KKT conditions (up to tolerance): alpha = 0 needs r2 >= 0,
        % alpha = C needs r2 <= 0, and 0 < alpha < C needs r2 = 0.
        violatesKKT = (r2 < -tolerance && abs(alpha2) < tolerance) || ...
                      (r2 > tolerance && abs(alpha2 - C) < tolerance) || ...
                      (abs(r2) > tolerance && alpha2 < C - tolerance && alpha2 > tolerance);
        if ~violatesKKT
            return;
        end

        nonBound = find(abs(alphaArray) > tolerance & abs(alphaArray - C) > tolerance);

        % 1) Second-choice heuristic: the non-bound partner with the largest
        %    |E1 - E2|, which approximates the largest step size.
        if numel(nonBound) > 1
            [~, pos] = max(abs(E(nonBound) - E2));
            if takeStep(nonBound(pos), i2)
                changed = 1; return;
            end
        end
        % 2) Any non-bound partner.
        for k = nonBound
            if takeStep(k, i2)
                changed = 1; return;
            end
        end
        % 3) Any partner at all.
        for k = 1 : n_samps
            if takeStep(k, i2)
                changed = 1; return;
            end
        end
    end

    function tf = takeStep(i1, i2)
        % Jointly optimise alphaArray(i1) and alphaArray(i2), then update b
        % and the error cache E. Returns 1 on progress, 0 otherwise.
        tf = 0;
        if i1 == i2
            return;
        end
        alpha1 = alphaArray(i1); alpha2 = alphaArray(i2);
        y1 = targetGroup(i1);    y2 = targetGroup(i2);
        E1 = E(i1);              E2 = E(i2);
        s = y1 * y2;
        % Feasible segment [L, H] for the new alpha2 that keeps
        % y1*alpha1 + y2*alpha2 fixed and both alphas inside [0, C].
        if s > 0
            L = max(0, alpha1 + alpha2 - C);
            H = min(C, alpha1 + alpha2);
        else
            L = max(0, alpha2 - alpha1);
            H = min(C, C + alpha2 - alpha1);
        end
        if L == H
            return;
        end

        k11 = kernelMatrix(i1, i1);
        k12 = kernelMatrix(i1, i2);
        k22 = kernelMatrix(i2, i2);
        eta = k11 + k22 - 2 * k12;   % second derivative along the segment
        if eta > 0
            % Unconstrained minimum along the segment, clipped to [L, H].
            a2 = min(max(alpha2 + y2 * (E1 - E2) / eta, L), H);
        else
            % Non-positive curvature: pick the segment end with the lower
            % objective, or stay put if they are within epsilon.
            Lobj = dualObjective(i1, alpha1 + s * (alpha2 - L), i2, L);
            Hobj = dualObjective(i1, alpha1 + s * (alpha2 - H), i2, H);
            if Lobj < Hobj - epsilon
                a2 = L;
            elseif Lobj > Hobj + epsilon
                a2 = H;
            else
                a2 = alpha2;
            end
        end
        % Skip negligible steps.
        if abs(a2 - alpha2) < epsilon * (a2 + alpha2 + epsilon)
            return;
        end
        a1 = alpha1 + s * (alpha2 - a2);

        % Threshold update (Platt): b1 makes the new u1 equal y1 when a1 is
        % non-bound, b2 does the same for a2; otherwise average them.
        b1 = E1 + y1 * (a1 - alpha1) * k11 + y2 * (a2 - alpha2) * k12 + b;
        b2 = E2 + y1 * (a1 - alpha1) * k12 + y2 * (a2 - alpha2) * k22 + b;
        if a1 > 0 && a1 < C
            b = b1;
        elseif a2 > 0 && a2 < C
            b = b2;
        else
            b = (b1 + b2) / 2;
        end
        alphaArray(i1) = a1;
        alphaArray(i2) = a2;

        % Refresh the error cache for the new alphas and threshold.
        E = (alphaArray .* targetGroup) * kernelMatrix - b - targetGroup;
        tf = 1;
    end

    function obj = dualObjective(i1, a1, i2, a2)
        % Platt's dual objective Psi = 1/2*sum_ij y_i y_j K_ij a_i a_j - sum_i a_i,
        % evaluated at the current alphas with entries i1 and i2 replaced
        % by a1 and a2. The linear term sums the alphas themselves, not
        % alpha .* y.
        a = alphaArray;
        a(i1) = a1;
        a(i2) = a2;
        ya = a .* targetGroup;
        obj = 0.5 * (ya * kernelMatrix * ya') - sum(a);
    end

end

下面是预测函数 smoPredict：

function [ targetGroup ] = smoPredict( model, designMatrix )
%SMOPREDICT Classify samples with a linear-kernel SVM model from smoSolver.
%   targetGroup = smoPredict(model, designMatrix) returns a 1-by-m row of
%   predicted labels for the m samples in the rows of designMatrix: +1 or -1,
%   or 0 for a sample exactly on the decision boundary (sign(0) = 0).
%
%   model must provide supVec (nsv-by-d), alpha and targetGroup (nsv
%   elements each) and the threshold b, as returned by smoSolver.

% kernelMatrix(k, j) = <supVec(k,:), designMatrix(j,:)>, size nsv-by-m.
kernelMatrix = model.supVec * designMatrix';
% Decision values u_j = sum_k alpha_k * y_k * K(k, j) - b: contract the
% weights along the support-vector dimension with a matrix product.
% (sum() on the element-wise product reduces along dim 1, i.e. over the test
% samples, which is only correct when m == 1.)
weights = model.alpha(:)' .* model.targetGroup(:)';
u = weights * kernelMatrix - model.b;
targetGroup = sign(u);
end
posted @ 2016-01-09 20:28  Key_Ky  阅读(743)  评论(0编辑  收藏  举报