KNN算法及java实现

 

/**
 * A neighbor candidate used by the KNN classifier: one training tuple's index,
 * its distance to the test tuple, and its class label.
 */
public class KNNNode {

    private int index; // index of the training tuple in the data set
    private double distance; // distance between this training tuple and the test tuple
    private String c; // class label of the training tuple

    /**
     * Creates a neighbor candidate.
     *
     * @param index    index of the training tuple in the data set
     * @param distance distance between the training tuple and the test tuple
     * @param c        class label of the training tuple
     */
    public KNNNode(int index, double distance, String c) {
        this.index = index;
        this.distance = distance;
        this.c = c;
    }

    /** @return index of the training tuple in the data set */
    public int getIndex() {
        return index;
    }

    /** @param index new index of the training tuple */
    public void setIndex(int index) {
        this.index = index;
    }

    /** @return distance between the training tuple and the test tuple */
    public double getDistance() {
        return distance;
    }

    /** @param distance new distance to the test tuple */
    public void setDistance(double distance) {
        this.distance = distance;
    }

    /** @return class label of the training tuple */
    public String getC() {
        return c;
    }

    /**
     * Sets the class label.
     *
     * @param c new class label
     */
    public void setC(String c) {
        this.c = c;
    }

    /**
     * Former label setter that took no argument and therefore only assigned the
     * field to itself.
     *
     * @deprecated This method never changed anything; use {@link #setC(String)}.
     *     It is kept only so existing callers still compile.
     */
    @Deprecated
    public void setC() {
        // Intentionally a no-op: identical to the original self-assignment.
    }
}
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * Main body of the k-nearest-neighbor (KNN) classifier.
 *
 * <p>Each training tuple is a list of doubles whose last element is the class
 * label. A test tuple holds the feature values that are compared against the
 * training tuples.
 */
public class KNN {

    /**
     * Orders nodes by descending distance, so the head of a PriorityQueue built
     * with this comparator is the farthest of the current candidates (a
     * max-heap on distance).
     */
    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
        @Override
        public int compare(KNNNode o1, KNNNode o2) {
            // Arguments are swapped for descending order. Double.compare returns 0
            // for equal distances, as the Comparator contract requires.
            return Double.compare(o2.getDistance(), o1.getDistance());
        }
    };

    /**
     * Returns k distinct random integers drawn from [0, max).
     *
     * @param k   number of values to draw
     * @param max exclusive upper bound of the range
     * @return list of k distinct random integers
     * @throws IllegalArgumentException if k is negative or larger than max
     *     (in that case k distinct values cannot exist)
     */
    public List<Integer> getRandKNum(int k, int max) {
        if (k < 0 || k > max) {
            throw new IllegalArgumentException(
                    "cannot draw " + k + " distinct values from [0, " + max + ")");
        }
        List<Integer> rand = new ArrayList<Integer>(k);
        while (rand.size() < k) {
            int temp = (int) (Math.random() * max);
            if (!rand.contains(temp)) {
                rand.add(temp);
            }
        }
        return rand;
    }

    /**
     * Computes the squared Euclidean distance between a test tuple and a
     * training tuple.
     *
     * <p>Only the first {@code d1.size()} components are compared. This means a
     * training tuple may carry its class label as an extra trailing element. The
     * square root is omitted because it does not change the neighbor ranking.
     *
     * @param d1 test tuple (feature values only)
     * @param d2 training tuple, with at least {@code d1.size()} elements
     * @return sum of squared component differences
     */
    public double calDistance(List<Double> d1, List<Double> d2) {
        double distance = 0.0;
        for (int i = 0; i < d1.size(); i++) {
            double diff = d1.get(i) - d2.get(i);
            distance += diff * diff;
        }
        return distance;
    }

    /**
     * Runs the KNN algorithm and returns the predicted class of the test tuple.
     *
     * <p>Every training tuple is scanned exactly once. A bounded max-heap keeps
     * the k closest tuples seen so far, so no tuple can be counted twice. If the
     * training set has fewer than k tuples, all of them vote.
     *
     * @param datas    training set; the last element of each tuple is its label
     * @param testData test tuple (feature values only)
     * @param k        number of neighbors; must be at least 1
     * @return label of the majority class among the k nearest neighbors
     * @throws IllegalArgumentException if k &lt; 1 or the training set is empty
     */
    public String knn(List<List<Double>> datas, List<Double> testData, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be >= 1, got " + k);
        }
        if (datas == null || datas.isEmpty()) {
            throw new IllegalArgumentException("training set must not be empty");
        }
        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
        for (int i = 0; i < datas.size(); i++) {
            List<Double> t = datas.get(i);
            double distance = calDistance(testData, t);
            if (pq.size() < k) {
                pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
            } else if (pq.peek().getDistance() > distance) {
                // Evict the farthest current neighbor in favor of this closer tuple.
                pq.remove();
                pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
            }
        }
        return getMostClass(pq);
    }

    /**
     * Returns the majority class among the neighbors in the queue. The queue is
     * drained as a side effect. On a tie, the class met first in HashMap
     * iteration order wins.
     *
     * @param pq priority queue holding the k nearest neighbors
     * @return name of the majority class
     * @throws IllegalArgumentException if the queue is empty
     */
    private String getMostClass(PriorityQueue<KNNNode> pq) {
        Map<String, Integer> classCount = new HashMap<String, Integer>();
        while (!pq.isEmpty()) {
            String c = pq.remove().getC();
            Integer count = classCount.get(c);
            classCount.put(c, count == null ? 1 : count + 1);
        }

        String best = null;
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : classCount.entrySet()) {
            if (entry.getValue() > maxCount) {
                best = entry.getKey();
                maxCount = entry.getValue();
            }
        }
        if (maxCount == 0) {
            throw new IllegalArgumentException("no neighbors to vote on");
        }
        return best;
    }
}
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Command-line driver for the KNN classifier. It loads training and test tuples
 * from whitespace-separated text files and prints the predicted class of each
 * test tuple.
 */
public class TestCNN {

    /**
     * Reads a data file in which each non-blank line is one tuple of
     * whitespace-separated numbers, and appends every parsed tuple to
     * {@code datas}.
     *
     * <p>On an I/O error the stack trace is printed and reading stops. Tuples
     * already read stay in {@code datas}. The file is always closed.
     *
     * @param datas collection that receives the parsed tuples
     * @param path  path of the data file
     * @throws NumberFormatException if a token is not a valid number
     */
    public void read(List<List<Double>> datas, String path) {
        BufferedReader bReader = null;
        try {
            bReader = new BufferedReader(new FileReader(new File(path)));
            String line;
            while ((line = bReader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    continue; // skip blank lines, e.g. a trailing newline at EOF
                }
                // Split on runs of whitespace so repeated spaces or tabs do not
                // produce empty tokens.
                String[] t = trimmed.split("\\s+");
                List<Double> list = new ArrayList<Double>(t.length);
                for (int i = 0; i < t.length; i++) {
                    list.add(Double.parseDouble(t[i]));
                }
                datas.add(list);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (bReader != null) {
                try {
                    bReader.close();
                } catch (IOException ignored) {
                    // Best-effort close: the data has already been read.
                }
            }
        }
    }

    /**
     * Program entry point. It reads cqudata/datafile.txt (training set) and
     * cqudata/testfile.txt (test set) under the working directory, then
     * classifies each test tuple with k = 3.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        TestCNN testCNN = new TestCNN();
        // Use File.separator throughout so the paths also work outside Windows.
        String baseDir = new File("").getAbsolutePath() + File.separator
                + "cqudata" + File.separator;
        String datafile = baseDir + "datafile.txt";
        String testfile = baseDir + "testfile.txt";
        List<List<Double>> datas = new ArrayList<List<Double>>();
        List<List<Double>> testDatas = new ArrayList<List<Double>>();
        testCNN.read(datas, datafile);
        testCNN.read(testDatas, testfile);

        KNN knn = new KNN();
        for (int i = 0; i < testDatas.size(); i++) {
            List<Double> test = testDatas.get(i);
            System.out.println("测试元组为:");
            // Print the whole tuple on one line, values separated by spaces.
            for (int j = 0; j < test.size(); j++) {
                System.out.print(test.get(j) + " ");
            }
            System.out.println();
            System.out.println("类别为: ");
            // Labels are stored as doubles (e.g. "1.0"); round back to an integer class id.
            System.out.println(Math.round(Float.parseFloat(knn.knn(datas, test, 3))));
        }
    }
}

训练数据（文件 cqudata/datafile.txt，每行前 8 列为特征值，最后一列为类别标签）:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
测试数据（文件 cqudata/testfile.txt，每行只有 8 列特征值，不含类别标签）:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
程序运行结果:
测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0

posted @ 2013-09-08 21:40  ilxx1988  阅读(1344)  评论(0编辑  收藏  举报