SMO推导和代码-记录毕业论文4
通过阅读 Platt 的论文和这篇博客(http://www.cnblogs.com/jerrylead/archive/2011/03/18/1988419.html),大概弄懂了 SMO 的数学公式。推导过程以后再写,这里先贴上一份自己写的 SMO 代码(MATLAB)。
function [ model ] = smoSolver( designMatrix, targetGroup )
%SMOSOLVER Train a soft-margin SVM with a linear kernel using Platt's SMO.
%
%   model = smoSolver(designMatrix, targetGroup)
%
%   designMatrix : n-by-d matrix, one training sample per row.
%   targetGroup  : n labels; the update rules assume values +1 / -1.
%                  A column vector is accepted and converted to a row.
%
%   model.alpha       : multipliers of the samples kept as support vectors
%                       (those with |alpha| > tolerance), 1-by-k.
%   model.targetGroup : labels of those samples, 1-by-k.
%   model.supVec      : the corresponding rows of designMatrix, k-by-d.
%   model.b           : threshold; the decision value is
%                       (alpha .* y) * (supVec * x') - b.
%
%   The box constraint C, the KKT tolerance, the progress threshold epsilon
%   and the cap on outer-loop passes are fixed constants defined below.

    tolerance      = 0.001;   % slack allowed when testing KKT conditions / bounds
    total_runtimes = 5000;    % hard cap on outer-loop passes
    epsilon        = 0.01;    % minimum relative change for a step to count
    C              = 1;       % upper bound of every multiplier

    n_samps = size(designMatrix, 1);

    % Everything below works on 1-by-n row vectors.
    targetGroup = reshape(targetGroup, 1, []);
    if numel(targetGroup) ~= n_samps
        error('smoSolver:sizeMismatch', ...
              'targetGroup must contain one label per row of designMatrix.');
    end

    % Linear-kernel Gram matrix: K(i,j) = <x_i, x_j>.
    kernelMatrix = designMatrix * designMatrix';

    % Start from alpha = 0. This satisfies sum(alpha .* y) = 0, and every
    % pair update in takeStep preserves that sum. A random start would
    % violate the constraint permanently.
    alphaArray = zeros(1, n_samps);
    b = 0;

    % E(i) = u(i) - y(i), where u is the current decision value.
    E = (alphaArray .* targetGroup) * kernelMatrix - b - targetGroup;

    numChanged = 0;
    examineAll = 1;
    iter = 1;
    while (numChanged > 0 || examineAll)
        numChanged = 0;
        if examineAll
            % Full pass over every sample.
            for sampleIdx = 1 : n_samps
                numChanged = numChanged + examineExample(sampleIdx);
            end
        else
            % Pass over the non-bound samples only (0 < alpha < C).
            for sampleIdx = 1 : n_samps
                if abs(alphaArray(sampleIdx)) > tolerance && ...
                   abs(alphaArray(sampleIdx) - C) > tolerance
                    numChanged = numChanged + examineExample(sampleIdx);
                end
            end
        end

        % Alternate: after a full pass switch to non-bound passes; go back
        % to a full pass once a non-bound pass changes nothing.
        if examineAll == 1
            examineAll = 0;
        elseif numChanged == 0
            examineAll = 1;
        end

        iter = iter + 1;
        if iter > total_runtimes
            break;
        end
    end

    % Keep only the samples with a non-negligible multiplier.
    alphaIdx = find(abs(alphaArray) > tolerance);
    model.targetGroup = targetGroup(alphaIdx);
    model.alpha       = alphaArray(alphaIdx);
    model.supVec      = designMatrix(alphaIdx, :);
    model.b           = b;

    % ---------------------------------------------------------------------
    function changed = examineExample(i2)
        % Check sample i2 against the KKT conditions. If it violates them,
        % try to find a partner so that takeStep makes progress.
        % Returns 1 when a pair of multipliers was updated, otherwise 0.
        changed = 0;

        y2     = targetGroup(i2);
        alpha2 = alphaArray(i2);
        r2     = E(i2) * y2;

        atLower  = abs(alpha2) < tolerance;
        atUpper  = abs(alpha2 - C) < tolerance;
        interior = alpha2 > tolerance && alpha2 < C - tolerance;

        violatesKKT = (r2 < -tolerance && atLower) || ...
                      (r2 >  tolerance && atUpper) || ...
                      (abs(r2) > tolerance && interior);
        if ~violatesKKT
            return;
        end

        nonBound = find(abs(alphaArray) > tolerance & ...
                        abs(alphaArray - C) > tolerance);

        % Heuristic 1: the partner with the largest |E1 - E2|.
        if numel(nonBound) > 1
            [~, bestIdx] = max(abs(E - E(i2)));
            if takeStep(bestIdx, i2)
                changed = 1;
                return;
            end
        end

        % Heuristic 2: every non-bound sample in turn.
        for k = 1 : numel(nonBound)
            if takeStep(nonBound(k), i2)
                changed = 1;
                return;
            end
        end

        % Heuristic 3: every sample in turn.
        for k = 1 : n_samps
            if takeStep(k, i2)
                changed = 1;
                return;
            end
        end
    end

    % ---------------------------------------------------------------------
    function tf = takeStep(i1, i2)
        % Jointly optimise alpha(i1) and alpha(i2). On success updates
        % alphaArray, b and E and returns 1; returns 0 when no progress.
        tf = 0;
        if i1 == i2
            return;
        end

        alpha1 = alphaArray(i1);
        alpha2 = alphaArray(i2);
        y1 = targetGroup(i1);
        y2 = targetGroup(i2);
        E1 = E(i1);
        E2 = E(i2);
        s  = y1 * y2;

        % Ends of the feasible segment for the new alpha2.
        if s > 0
            L = max([0, alpha1 + alpha2 - C]);
            H = min([C, alpha1 + alpha2]);
        else
            L = max([0, alpha2 - alpha1]);
            H = min([C, C + alpha2 - alpha1]);
        end
        if L == H
            return;
        end

        k11 = kernelMatrix(i1, i1);
        k12 = kernelMatrix(i1, i2);
        k22 = kernelMatrix(i2, i2);
        eta = k11 + k22 - 2 * k12;

        if eta > 0
            % Unconstrained optimum along the segment, then clip to [L, H].
            a2 = alpha2 + y2 * (E1 - E2) / eta;
            if a2 < L
                a2 = L;
            elseif a2 > H
                a2 = H;
            end
        else
            % Degenerate curvature: compare the objective at both ends.
            Lobj = objectiveAt(i1, alpha1 + s * (alpha2 - L), i2, L);
            Hobj = objectiveAt(i1, alpha1 + s * (alpha2 - H), i2, H);
            if Lobj < Hobj - epsilon
                a2 = L;
            elseif Lobj > Hobj + epsilon
                a2 = H;
            else
                a2 = alpha2;
            end
        end

        % Reject steps that barely move alpha2.
        if abs(a2 - alpha2) < epsilon * (a2 + alpha2 + epsilon)
            return;
        end

        a1 = alpha1 + s * (alpha2 - a2);

        % Threshold candidates that zero the error of sample i1 / i2.
        b1 = E1 + y1 * (a1 - alpha1) * k11 + y2 * (a2 - alpha2) * k12 + b;
        b2 = E2 + y1 * (a1 - alpha1) * k12 + y2 * (a2 - alpha2) * k22 + b;
        if a1 > 0 && a1 < C
            b = b1;
        elseif a2 > 0 && a2 < C
            b = b2;
        else
            b = (b1 + b2) / 2;
        end

        alphaArray(i1) = a1;
        alphaArray(i2) = a2;

        % Recompute the full error vector with the new multipliers and b.
        E = (alphaArray .* targetGroup) * kernelMatrix - b - targetGroup;
        tf = 1;
    end

    % ---------------------------------------------------------------------
    function val = objectiveAt(i1, a1, i2, a2)
        % Dual objective in minimisation form,
        %   0.5 * (alpha.*y) * K * (alpha.*y)' - sum(alpha),
        % evaluated with alpha(i1) = a1 and alpha(i2) = a2.
        trial = alphaArray;
        trial(i1) = a1;
        trial(i2) = a2;
        signedTrial = trial .* targetGroup;
        % The linear term sums the multipliers themselves, not the
        % label-weighted ones.
        val = 0.5 * signedTrial * kernelMatrix * signedTrial' - sum(trial);
    end
end
smoPredict:
function [ targetGroup ] = smoPredict( model, designMatrix )
%SMOPREDICT Classify samples with a model that has the fields read below.
%
%   targetGroup = smoPredict(model, designMatrix)
%
%   model        : struct with fields supVec (k-by-d), alpha (k values),
%                  targetGroup (k labels) and b (scalar threshold).
%   designMatrix : m-by-d matrix, one sample to classify per row.
%
%   targetGroup  : 1-by-m row vector with sign(u) for every sample, where
%                  u = sum_k alpha_k * y_k * <supVec_k, x> - b.
%                  A decision value of exactly 0 therefore yields 0.

    % Linear kernel between every support vector and every test sample.
    kernelMatrix = model.supVec * designMatrix';          % k-by-m

    % Label-weighted multipliers as a row vector, whatever their stored
    % orientation.
    coef = reshape(model.alpha, 1, []) .* reshape(model.targetGroup, 1, []);

    % Row-times-matrix sums over the support vectors, giving one decision
    % value per test sample. A plain sum() over the element-wise product
    % would instead sum over the test samples.
    u = coef * kernelMatrix - model.b;                    % 1-by-m

    targetGroup = sign(u);
end