[StatLearn] kNN Algorithm Experiments in Statistical Learning (2)
This post continues from kNN Algorithm Experiments in Statistical Learning (1).
Problem:
- Explore the data before classification using summary statistics or visualization
- Pre-process the data (such as denoising, normalization, feature selection, …)
- Try other distance metrics or distance-based voting
- Try other dimensionality reduction methods
- How to set the k value, if not using cross validation? Verify your idea
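We visualize the data with a parallel coordinates plot. The data are normalized first, so that every value lies in [0, 1]. Note that the normalization is applied to each feature separately.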
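The per-feature min-max scaling described above can be sketched as follows. The toy matrix `X` is invented for illustration; in the experiment each column would be one of the 13 wine features. A constant column would make the range zero, which this sketch does not guard against.

```matlab
% Toy data: 4 samples (rows) x 2 features (columns); values are invented.
X = [1 10; 2 20; 3 40; 5 30];

colMin   = min(X, [], 1);            % minimum of every feature
colRange = max(X, [], 1) - colMin;   % dynamic range of every feature

% Scale each column independently to [0, 1].
% bsxfun keeps the sketch usable on releases without implicit expansion.
Xnorm = bsxfun(@rdivide, bsxfun(@minus, X, colMin), colRange);

disp(Xnorm);
```

Scaling the whole matrix by one global minimum and maximum would instead let the feature with the largest raw range dominate the distance computation in kNN.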
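By inspecting the plots closely, we pick the features whose classes overlap the least and use them for feature selection. Feature selection amounts to keeping the features that separate the classes most easily and passing only those to the classifier in the next step. We select features (1), (2), (5), (6), (7) and (10). In my view, feature selection is a simple and crude form of dimensionality reduction and denoising, but it can work very well.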
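As a small illustration of the selection step, the snippet below keeps the six chosen features from a normalized matrix. It assumes, as in the source code further down, that column 1 holds the class label, so feature (j) sits in column j+1. The random matrix merely stands in for the real data.

```matlab
% Stand-in for the normalized data set: label in column 1, 13 features after it.
labels  = [1; 1; 2; 2; 3; 3];
uniData = [labels, rand(numel(labels), 13)];

selected = [1 2 5 6 7 10];     % feature numbers picked from the plots
cols     = [1, selected + 1];  % keep the label, shift features by one column

FSData = uniData(:, cols);
disp(size(FSData));            % 6 rows, 7 columns (label + 6 features)
```

The resulting column list `[1 2 3 6 7 8 11]` is the same index vector that appears in the script.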
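Following the previous step, the parallel coordinates plots suggest that features (1), (2), (5), (6), (7) and (10) are well suited to classification. Running kNN on these features gives the following results:
With k=1 the accuracy reaches 85.38%. Compared with plain kNN or PCA+kNN, normalization and feature selection improve the accuracy substantially. Other feature combinations can also be tried experimentally to look for better results.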
MaxAccuracy= 0.8834 when k=17 (Normalization+FeatureSelection+KNN)
The denoising code is as follows:
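function [DNData] = DataDenoising(InputData, KillRange)
% Trim outliers. Column 1 is the label; for every feature column, drop the
% KillRange rows with the smallest values and the KillRange rows with the largest.
DNData = InputData;
for i = 2:size(InputData, 2)
    [~, DNIndex] = sort(DNData(:, i));
    DNData = DNData(DNIndex(1+KillRange:end-KillRange), :);
end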
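To see what the trimming does, here is the same loop inlined on a toy matrix (all numbers are invented). Every feature column removes the `KillRange` smallest and `KillRange` largest remaining rows, so with F feature columns the result has 2*KillRange*F fewer rows than the input.

```matlab
% Column 1 is the label, columns 2-3 are features.
% Rows 1 and 10 are extreme in column 2, rows 2 and 9 in column 3.
Data = [1 -50   0.5;
        1   0.1 -40;
        1   0.2  0.3;
        2   0.3  0.6;
        2   0.4  0.2;
        2   0.5  0.8;
        3   0.6  0.4;
        3   0.7  0.7;
        3   0.8  70;
        3   90   0.1];

KillRange = 1;
Trimmed = Data;
for c = 2:size(Data, 2)
    [~, order] = sort(Trimmed(:, c));                          % rank rows by this feature
    Trimmed = Trimmed(order(1+KillRange:end-KillRange), :);    % drop both tails
end

disp(size(Trimmed));   % 6 rows, 3 columns
```

Note that the trimming ignores the class label, so it can also discard legitimate samples of a class that naturally sits at the edge of a feature's range.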
We use LLE (locally linear embedding) for dimensionality reduction and compare it with the approaches above:
MaxAccuracy= 0.9376 when K=23 (LLE dimensionality reduction to 2)
For the LLE algorithm, see this paper:
- Nonlinear dimensionality reduction by locally linear embedding. Sam Roweis & Lawrence Saul. Science, v.290 no.5500, Dec. 22, 2000, pp. 2323-2326.
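For orientation, the three steps of LLE (find neighbours, solve for reconstruction weights, compute the embedding) can be sketched compactly. This is written from the description in the paper and is not the `lle` function called in the script below. The random input, the regularisation constant `tol` and all variable names are assumptions made for the sketch.

```matlab
X = rand(3, 30);        % D x N input, one sample per column (random stand-in)
K = 5;                  % number of neighbours
d = 2;                  % target dimension
[D, N] = size(X);
tol = 1e-3;             % assumed regulariser; needed here because K > D

% Step 1: K nearest neighbours by squared Euclidean distance.
X2 = sum(X.^2, 1);
dist = bsxfun(@plus, X2', X2) - 2 * (X' * X);
[~, idx] = sort(dist, 1);
nbrs = idx(2:K+1, :);   % row 1 is the point itself

% Step 2: weights that best reconstruct each point from its neighbours.
W = zeros(K, N);
for i = 1:N
    Z = bsxfun(@minus, X(:, nbrs(:, i)), X(:, i));  % centre neighbours on point i
    C = Z' * Z;                                     % local covariance
    C = C + eye(K) * tol * trace(C);                % regularise
    w = C \ ones(K, 1);
    W(:, i) = w / sum(w);                           % weights sum to one
end

% Step 3: bottom eigenvectors of M = (I - W)' * (I - W).
M = eye(N);
for i = 1:N
    w = W(:, i);
    j = nbrs(:, i);
    M(i, j) = M(i, j) - w';
    M(j, i) = M(j, i) - w;
    M(j, j) = M(j, j) + w * w';
end
[V, E] = eig(M);
[~, order] = sort(diag(E));
Y = V(:, order(2:d+1))' * sqrt(N);   % skip the constant eigenvector; Y is d x N

disp(size(Y));
```

`K = 5` and `d = 2` mirror the arguments passed to `lle` in the script, which likewise expects one sample per column.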
Source code:
StatLearnProj.m
clear;
data = load('wine.data.txt');   % column 1: class label, columns 2-14: the 13 features

% 5-fold kNN on the raw data
Accuracy = [];
for i = 1:5
    Test = data(i:5:end, :);                  % every 5th sample forms the test fold
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    Trainning = setdiff(data, Test, 'rows');  % remaining rows (sorted, duplicates removed)
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel));
end
AccuracyKNN = mean(Accuracy, 1);

% PCA + kNN, keeping the first 6 principal components
% (princomp was replaced by pca in newer Statistics Toolbox releases)
Accuracy = [];
[Coeff, Score, Latent] = princomp(data(:, 2:end));
dataPCA = [data(:, 1), Score(:, 1:6)];
Latent
for i = 1:5
    Test = dataPCA(i:5:end, :);
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    Trainning = setdiff(dataPCA, Test, 'rows');
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel));
end
AccuracyPCA = mean(Accuracy, 1);
BarData = [AccuracyKNN; AccuracyPCA];
bar(1:2:51, BarData');
% D(1) is the best mean accuracy; I(1) is its position in 1:2:51, i.e. k = 2*I(1)-1
[D, I] = sort(AccuracyKNN, 'descend');
D(1)
I(1)
[D, I] = sort(AccuracyPCA, 'descend');
D(1)
I(1)

% Pre-processing: min-max normalization of every feature to [0, 1]
labs1 = {'1)Alcohol', '2)Malic acid', '3)Ash', '4)Alcalinity of ash'};
labs2 = {'5)Magnesium', '6)Total phenols', '7)Flavanoids', '8)Nonflavanoid phenols'};
labs3 = {'9)Proanthocyanins', '10)Color intensity', '11)Hue', '12)OD280/OD315', '13)Proline'};
uniData = [];
for i = 2:size(data, 2)
    uniData = cat(2, uniData, (data(:, i) - min(data(:, i))) / (max(data(:, i)) - min(data(:, i))));
end
figure();
parallelcoords(uniData(:, 1:4), 'group', data(:, 1), 'labels', labs1);
figure();
parallelcoords(uniData(:, 5:8), 'group', data(:, 1), 'labels', labs2);
figure();
parallelcoords(uniData(:, 9:13), 'group', data(:, 1), 'labels', labs3);

% Normalization, all features
uniData = [data(:, 1), uniData];   % put the label back in column 1
Accuracy = [];   % reset, otherwise the mean would also include the PCA folds
for i = 1:5
    Test = uniData(i:5:end, :);
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    Trainning = setdiff(uniData, Test, 'rows');
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel));
end
AccuracyNorm = mean(Accuracy, 1);

% Compare kNN, PCA and Normalization
BarData = [AccuracyKNN; AccuracyPCA; AccuracyNorm];
figure();
bar(1:2:51, BarData');

% Normalization + feature selection: features 1 2 5 6 7 10
% (column 1 is the label, so feature j is column j+1)
FSData = uniData(:, [1 2 3 6 7 8 11]);
size(FSData)
Accuracy = [];
for i = 1:5
    Test = FSData(i:5:end, :);
    Trainning = setdiff(FSData, Test, 'rows');
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel));
end
AccuracyNormFS1 = mean(Accuracy, 1);

% Normalization + feature selection: features 1 6 7
FSData = uniData(:, [1 2 7 8]);
Accuracy = [];
for i = 1:5
    Test = FSData(i:5:end, :);
    Trainning = setdiff(FSData, Test, 'rows');
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel));
end
AccuracyNormFS2 = mean(Accuracy, 1);
figure();
BarData = [AccuracyNorm; AccuracyNormFS1; AccuracyNormFS2];
bar(1:2:51, BarData');
[D, I] = sort(AccuracyNorm, 'descend');
D(1)
I(1)
[D, I] = sort(AccuracyNormFS1, 'descend');
D(1)
I(1)
[D, I] = sort(AccuracyNormFS2, 'descend');
D(1)
I(1)

% Denoising: Normalization + features 1 6 7, training set trimmed
FSData = uniData(:, [1 2 7 8]);
Accuracy = [];
for i = 1:5
    Test = FSData(i:5:end, :);
    Trainning = setdiff(FSData, Test, 'rows');
    Trainning = DataDenoising(Trainning, 2);
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel));
end
AccuracyNormFSDN = mean(Accuracy, 1);
figure();
hold on
plot(1:2:51, AccuracyNormFSDN);
plot(1:2:51, AccuracyNormFS2, 'r');

% Other distance metrics: city block on all normalized features
Dist = 'cityblock';
Accuracy = [];
for i = 1:5
    Test = uniData(i:5:end, :);
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    Trainning = setdiff(uniData, Test, 'rows');
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracyPlus(TestData, TestLabel, TrainningData, TrainningLabel, Dist));
end
AccuracyNormCity = mean(Accuracy, 1);
BarData = [AccuracyNorm; AccuracyNormCity];
figure();
bar(1:2:51, BarData');
[D, I] = sort(AccuracyNormCity, 'descend');
D(1)
I(1)

% City block + denoising on features 1 6 7
FSData = uniData(:, [1 2 7 8]);
Dist = 'cityblock';
Accuracy = [];
for i = 1:5
    Test = FSData(i:5:end, :);
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    Trainning = setdiff(FSData, Test, 'rows');
    Trainning = DataDenoising(Trainning, 3);
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracyPlus(TestData, TestLabel, TrainningData, TrainningLabel, Dist));
end
AccuracyNormCityDN = mean(Accuracy, 1);
figure();
hold on
plot(1:2:51, AccuracyNormCityDN);
plot(1:2:51, AccuracyNormCity, 'r');

% LLE: reduce the normalized features to 2 dimensions
data = load('wine.data.txt');
uniData = [];
for i = 2:size(data, 2)
    uniData = cat(2, uniData, (data(:, i) - min(data(:, i))) / (max(data(:, i)) - min(data(:, i))));
end
uniData = [data(:, 1), uniData];
LLEData = lle(uniData(:, 2:end)', 5, 2);   % lle expects one sample per column
LLEData = LLEData';
LLEData = [data(:, 1), LLEData];
Accuracy = [];
for i = 1:5
    Test = LLEData(i:5:end, :);
    TestData = Test(:, 2:end);
    TestLabel = Test(:, 1);
    Trainning = setdiff(LLEData, Test, 'rows');
    Trainning = DataDenoising(Trainning, 2);
    TrainningData = Trainning(:, 2:end);
    TrainningLabel = Trainning(:, 1);
    Accuracy = cat(1, Accuracy, CalcAccuracyPlus(TestData, TestLabel, TrainningData, TrainningLabel, 'cityblock'));
end
AccuracyLLE = mean(Accuracy, 1);
[D, I] = sort(AccuracyLLE, 'descend');
D(1)
I(1)
BarData = [AccuracyNorm; AccuracyNormFS2; AccuracyNormFSDN; AccuracyLLE];
figure();
bar(1:2:51, BarData');
save('ProcessingData.mat');
CalcAccuracy.m
function Accuracy = CalcAccuracy(TestData, TestLabel, TrainningData, TrainningLabel)
% Calculate the classification accuracy of kNN for k = 1, 3, ..., 51.
% TestData:       M*D matrix, M samples of dimension D
% TrainningData:  T*D matrix
% TestLabel:      labels of TestData
% TrainningLabel: labels of TrainningData
% Accuracy:       row vector with one accuracy per value of k
CompareResult = [];
for k = 1:2:51
    ClassResult = knnclassify(TestData, TrainningData, TrainningLabel, k);
    CompareResult = cat(2, CompareResult, (ClassResult == TestLabel));
end
SumCompareResult = sum(CompareResult, 1);
Accuracy = SumCompareResult / length(CompareResult(:, 1));
CalcAccuracyPlus.m
function Accuracy = CalcAccuracyPlus(TestData, TestLabel, TrainningData, TrainningLabel, Dist)
% Same as CalcAccuracy, but with a selectable distance metric.
% Calculate the classification accuracy of kNN for k = 1, 3, ..., 51.
% TestData:       M*D matrix, M samples of dimension D
% TrainningData:  T*D matrix
% TestLabel:      labels of TestData
% TrainningLabel: labels of TrainningData
% Dist:           distance name passed to knnclassify, e.g. 'cityblock'
CompareResult = [];
for k = 1:2:51
    ClassResult = knnclassify(TestData, TrainningData, TrainningLabel, k, Dist);
    CompareResult = cat(2, CompareResult, (ClassResult == TestLabel));
end
SumCompareResult = sum(CompareResult, 1);
Accuracy = SumCompareResult / length(CompareResult(:, 1));
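The effect of the `Dist` argument can be illustrated with a small hand-written nearest-neighbour vote that uses only base functions. This is a simplified sketch, not a reimplementation of `knnclassify`: the toy points are invented, and ties in the vote are resolved by `mode`, which returns the smallest of the tied labels.

```matlab
% Toy training set: two features, two classes (values are invented).
Train      = [0 0; 0 1; 1 0; 5 5; 5 6; 6 5];
TrainLabel = [1; 1; 1; 2; 2; 2];
Test       = [0.5 0.5; 5.5 5.5];
k = 3;

Pred = zeros(size(Test, 1), 1);
for t = 1:size(Test, 1)
    delta = bsxfun(@minus, Train, Test(t, :));
    dCity = sum(abs(delta), 2);            % city block (L1) distance
    % dEuclid = sqrt(sum(delta.^2, 2));    % Euclidean alternative
    [~, order] = sort(dCity);
    Pred(t) = mode(TrainLabel(order(1:k)));   % majority vote among the k nearest
end

disp(Pred');   % 1 2
```

Swapping `dCity` for the commented Euclidean line changes only how neighbours are ranked; the voting step stays the same.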