/**
 * Command-line driver: loads training rows from {@code data.txt} and query rows from
 * {@code test.txt}, then prints each query followed by its predicted (rounded) category.
 */
public class KnnTest
{
    /**
     * Reads a whitespace-separated numeric file and appends one row per non-blank line to
     * {@code list}.
     *
     * <p>Tokens may be separated by any run of whitespace (spaces or tabs); leading and
     * trailing whitespace is ignored. An I/O failure is printed and reading stops; rows read
     * before the failure stay in {@code list}. A non-numeric token propagates as
     * {@link NumberFormatException}.
     *
     * @param path file to read
     * @param list destination; each parsed line is appended as a {@code List<Double>}
     */
    public static void readFileToList(String path, List<List<Double>> list)
    {
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(path));
            String line;
            // readLine() == null is the reliable EOF signal; ready() only reports whether a
            // read would block, which is not the same thing.
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                // "\\s+" tolerates repeated spaces/tabs, which split(" ") turned into empty
                // tokens that Double.parseDouble rejects.
                String[] tokens = trimmed.split("\\s+");
                List<Double> box = new ArrayList<Double>(tokens.length);
                for (String num : tokens) {
                    box.add(Double.parseDouble(num));
                }
                list.add(box);
            }
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
        finally {
            // Always release the file handle, even when parsing throws.
            if (br != null) {
                try {
                    br.close();
                }
                catch (IOException ex) {
                    ex.printStackTrace();
                }
            }
        }
    }

    /**
     * Classifies every row of the test file against the data file and prints the result.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        int k = 2; // number of nearest neighbours that vote on each prediction
        String dataFile = "data.txt";
        String testFile = "test.txt";
        KNN knn = new KNN();
        try {
            List<List<Double>> dataList = new ArrayList<List<Double>>();
            List<List<Double>> testList = new ArrayList<List<Double>>();
            readFileToList(dataFile, dataList);
            readFileToList(testFile, testList);
            for (List<Double> test : testList) {
                for (Double d : test) {
                    System.out.print(d + " ");
                }
                String category = knn.knn(dataList, test, k);
                // The category is a Double.toString value; parse it as a double so no
                // precision is lost before rounding.
                System.out.println(Math.round(Double.parseDouble(category)));
            }
        }
        catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}
/**
 * Minimal k-nearest-neighbours classifier over numeric rows.
 *
 * <p>Each training row holds its feature values followed by the category in the last
 * position. Distances are computed over the first {@code test.size()} columns only.
 */
class KNN
{
    /**
     * Orders nodes farthest-first, so the head of a {@link PriorityQueue} is the current
     * worst neighbour and can be evicted when a closer one is found. Double.compare keeps
     * the Comparator contract (negative/zero/positive), which the old 1/0-only comparator
     * violated.
     */
    private static final Comparator<Node> comparator = new Comparator<Node>()
    {
        @Override
        public int compare(Node n1, Node n2)
        {
            return Double.compare(n2.getDistans(), n1.getDistans());
        }
    };

    /**
     * Predicts the category of {@code test} by majority vote of its {@code k} nearest rows.
     *
     * <p>If {@code k} exceeds the number of rows, every row votes. Ties between categories
     * with the same vote count go to the category of the closest tied neighbour. The result
     * is deterministic for a given input.
     *
     * @param example training rows; the last element of each row is its category
     * @param test    feature values of the query
     * @param k       number of neighbours that vote; must be at least 1
     * @return the winning category, as the {@code toString()} of the stored value
     * @throws IllegalArgumentException if {@code k < 1} or {@code example} is null or empty
     */
    public String knn(List<List<Double>> example, List<Double> test, int k)
    {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        if (example == null || example.isEmpty()) {
            throw new IllegalArgumentException("example set must not be empty");
        }
        // Bounded by example.size() so a huge k does not allocate a huge backing array.
        PriorityQueue<Node> pq =
                new PriorityQueue<Node>(Math.min(k, example.size()), comparator);
        // Single deterministic pass: fill up to k, then replace the farthest kept neighbour
        // whenever a strictly closer row appears. Each row is considered exactly once, so
        // no neighbour can be counted twice.
        for (int i = 0; i < example.size(); i++) {
            List<Double> row = example.get(i);
            double distans = calDistans(test, row);
            if (pq.size() < k) {
                pq.add(new Node(i, distans, categoryOf(row)));
            }
            else if (pq.peek().getDistans() > distans) {
                pq.remove();
                pq.add(new Node(i, distans, categoryOf(row)));
            }
        }
        return getMostCategory(pq);
    }

    /** Returns the category label stored in the last column of a training row. */
    private static String categoryOf(List<Double> row)
    {
        return row.get(row.size() - 1).toString();
    }

    /**
     * Drains {@code pq} and returns the category with the most votes.
     * Returns {@code null} if {@code pq} is empty; {@link #knn} never passes an empty queue.
     */
    private String getMostCategory(PriorityQueue<Node> pq)
    {
        // Draining the farthest-first heap yields neighbours from farthest to nearest.
        // Loop on isEmpty(): the old "i < pq.size()" loop shrank its own bound and counted
        // only about half of the neighbours.
        List<Node> farthestFirst = new ArrayList<Node>(pq.size());
        while (!pq.isEmpty()) {
            farthestFirst.add(pq.remove());
        }
        Map<String, Integer> rankMapping = new HashMap<String, Integer>();
        for (Node node : farthestFirst) {
            Integer seen = rankMapping.get(node.getCategory());
            rankMapping.put(node.getCategory(), seen == null ? 1 : seen + 1);
        }
        // Walk nearest-first and require a strictly larger count: among tied categories the
        // one owning the closest neighbour wins, instead of depending on HashMap order.
        String best = null;
        int bestCount = 0;
        for (int i = farthestFirst.size() - 1; i >= 0; i--) {
            String category = farthestFirst.get(i).getCategory();
            int count = rankMapping.get(category);
            if (count > bestCount) {
                best = category;
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * Returns the squared Euclidean distance over the first {@code list1.size()} positions.
     * Extra trailing elements of {@code list2} (such as a category column) are ignored.
     *
     * @param list1 query features; its size determines how many positions are compared
     * @param list2 row to compare against; must have at least {@code list1.size()} elements
     * @return sum of squared per-position differences
     */
    public double calDistans(List<Double> list1, List<Double> list2)
    {
        double result = 0.00;
        for (int i = 0; i < list1.size(); i++) {
            double diff = list1.get(i) - list2.get(i);
            result += diff * diff;
        }
        return result;
    }

    /** A candidate neighbour: its row index, distance to the query and category label. */
    static class Node
    {
        private int index;
        private double distans;
        private String category;

        public Node(int index, double distans, String category)
        {
            this.index = index;
            this.distans = distans;
            this.category = category;
        }

        public int getIndex()
        {
            return index;
        }

        public void setIndex(int index)
        {
            this.index = index;
        }

        public double getDistans()
        {
            return distans;
        }

        public void setDistans(double distans)
        {
            this.distans = distans;
        }

        public String getCategory()
        {
            return category;
        }

        public void setCategory(String category)
        {
            this.category = category;
        }
    }
}