K-means 聚类的 Java 实现

      今天把自己编写的机器学习算法库中的 K-means 算法整理了一下。由于该算法与其他算法相比较为独立，可以单独贴出来，不会引用太多其他类（虽然仍有少量引用，但都是些简单的功能，看类名就能知道用途）。

基本功能和规则为:

1. 使用 K-means 算法对数据集进行 N 维聚类（这里用二维数组表示数据集：行数为数据总数，列数为数据维度）；

2. 可以指定收敛阈值（convergenceDis，默认为 0.0001）：当某个中心在一次迭代中各维坐标的变化都不超过该值时，认为该中心已收敛；

3. 为避免陷入局部最小，可以通过设置 replicates 指定重复运行的次数。默认为 0，即只执行一次聚类过程（设置为 0 或 1 效果相同）；

4. 测试数据的格式为每行代表一个输入，各个维度之间用空格分隔。为避免计算结果出现较大偏差，建议先对原始数据进行归一化。

首先上骨架代码:

View Code
package org.tadoo.ml.cluster.kmeans;

import java.util.Random;

import org.tadoo.ml.exception.ClusterException;
import org.tadoo.ml.util.ArrayCompute;
import org.tadoo.ml.util.Utils;

/**
 * K-means clustering of a data set given as {@code double[rows][cols]}: one row per
 * sample, one column per dimension.
 *
 * <p>Typical use: construct with the data and the number of clusters {@code k},
 * optionally call {@link #setConvergenceDis(double)} and {@link #setReplicates(int)},
 * then call {@link #train()} and read {@link #getCenters()},
 * {@link #getTotalSumOfdistances()} and {@link #getIter()}.</p>
 *
 * <p>time:2011-6-1</p>
 *
 * @author T. QIN
 */
public class KmeansCluster
{
    /** Samples to cluster; row i is sample i. */
    private double[][] dataSet = null;

    /** Number of clusters. */
    private int k = 0;

    /** Current cluster centers: k rows, each as wide as a sample. */
    private double[][] centers = null;

    /**
     * Sum over all samples of the Euclidean distance to the center each sample was
     * assigned to in the last iteration.
     */
    private double totalSumOfdistances = 0;

    /** Whether the most recent run has converged. */
    private boolean convergence = false;

    /** Number of iterations performed by the most recent run. */
    private int iter;

    /** A center counts as converged when no coordinate moves by more than this in one iteration. */
    private double convergenceDis = 0.0001;

    /** Number of independent runs performed by {@link #train()}; values <= 1 mean a single run. */
    private int replicates = 0;

    /** Per-run results of the last {@link #train()} call made with replicates > 1. */
    private KMCResult[] kmcresults = null;

    /**
     * Creates a clusterer for the given data.
     *
     * @param x data set, one row per sample; must be non-null and non-empty
     * @param k number of clusters; must be positive
     * @throws ClusterException if {@code x} is null/empty or {@code k <= 0}
     */
    public KmeansCluster(double[][] x, int k) throws ClusterException
    {
        if (x == null || x.length == 0)
        {
            throw new ClusterException("输入数据不可为空。");
        }
        if (k <= 0)
        {
            throw new ClusterException("聚类数k必须大于0: " + k);
        }
        this.dataSet = x;
        this.k = k;
        this.centers = new double[k][dataSet[0].length];
    }

    /**
     * Seeds the k centers with copies of randomly chosen samples.
     *
     * <p>Picks distinct rows (partial Fisher-Yates shuffle) while possible. Sampling with
     * replacement could seed two identical centers. Because ties go to the lower index,
     * the duplicate never receives a point and stays an empty cluster forever. Only when
     * k exceeds the number of samples are repeats unavoidable.</p>
     *
     * <p>Reallocates {@code centers}, so a changed {@code k} or data set (via the setters)
     * is picked up and earlier results handed out are not overwritten.</p>
     */
    private void initKCenters()
    {
        Random random = new Random();
        int rows = dataSet.length;
        int cols = dataSet[0].length;
        int[] order = new int[rows];
        for (int i = 0; i < rows; i++)
        {
            order[i] = i;
        }
        centers = new double[k][];
        for (int i = 0; i < k; i++)
        {
            int pick;
            if (i < rows)
            {
                int j = i + random.nextInt(rows - i);
                int swap = order[i];
                order[i] = order[j];
                order[j] = swap;
                pick = order[i];
            }
            else
            {
                pick = random.nextInt(rows); // more clusters than samples
            }
            double[] center = new double[cols];
            System.arraycopy(dataSet[pick], 0, center, 0, cols);
            centers[i] = center;
        }
    }

    /**
     * Runs the clustering.
     *
     * <p>With {@code replicates > 1} the algorithm is run that many times from different
     * random seeds. Each run is recorded in {@link #getKmcresults()}. Afterwards the
     * centers, total distance and iteration count of the run with the smallest total
     * distance are exposed through the getters. Previously the getters returned an
     * all-zero center matrix after replicated training.</p>
     */
    public void train()
    {
        if (replicates > 1)
        {
            kmcresults = new KMCResult[replicates];
            int best = 0;
            for (int i = 0; i < replicates; i++)
            {
                beginTrain();
                KMCResult result = new KMCResult();
                result.centers = this.centers;
                result.sum = this.totalSumOfdistances;
                result.iters = this.iter;
                kmcresults[i] = result;
                if (result.sum < kmcresults[best].sum)
                {
                    best = i;
                }
            }
            this.centers = kmcresults[best].centers;
            this.totalSumOfdistances = kmcresults[best].sum;
            this.iter = kmcresults[best].iters;
        }
        else
        {
            beginTrain();
        }
    }

    /** Performs one complete k-means run (Lloyd iterations) from fresh random seeds. */
    private void beginTrain()
    {
        int[] assignment = new int[dataSet.length]; // cluster index of each sample
        iter = 0;
        initKCenters();
        convergence = false;
        while (!convergence)
        {
            totalSumOfdistances = assignToNearestCenters(assignment);
            convergence = updateCenters(assignment);
            iter++;
        }
    }

    /**
     * Assigns every sample to its nearest center (ties go to the lower index).
     *
     * @param assignment output: cluster index per sample; an entry is left unchanged if
     *                   no distance is smaller than {@code Double.MAX_VALUE}
     * @return sum of the distances from each sample to its nearest center
     */
    private double assignToNearestCenters(int[] assignment)
    {
        double sum = 0.0;
        for (int i = 0; i < dataSet.length; i++)
        {
            double minDistance = Double.MAX_VALUE;
            for (int j = 0; j < k; j++)
            {
                double distance = Utils.distance(dataSet[i], centers[j]);
                if (distance < minDistance)
                {
                    minDistance = distance;
                    assignment[i] = j;
                }
            }
            sum += minDistance;
        }
        return sum;
    }

    /**
     * Moves each non-empty cluster's center to the mean of its members, in a single pass
     * over the data.
     *
     * <p>Empty clusters keep their old center and take no part in the convergence
     * decision.</p>
     *
     * @param assignment cluster index of each sample
     * @return true if every moved center moved by at most {@code convergenceDis} per coordinate
     */
    private boolean updateCenters(int[] assignment)
    {
        int cols = dataSet[0].length;
        double[][] sums = new double[k][cols];
        int[] counts = new int[k];
        for (int i = 0; i < assignment.length; i++)
        {
            int cluster = assignment[i];
            counts[cluster]++;
            double[] row = dataSet[i];
            for (int d = 0; d < cols; d++)
            {
                sums[cluster][d] += row[d];
            }
        }
        boolean allConverged = true;
        for (int c = 0; c < k; c++)
        {
            if (counts[c] == 0)
            {
                continue;
            }
            double[] mean = new double[cols];
            for (int d = 0; d < cols; d++)
            {
                mean[d] = sums[c][d] / counts[c];
            }
            if (!isCenterConvergence(centers[c], mean))
            {
                allConverged = false;
            }
            centers[c] = mean;
        }
        return allConverged;
    }

    /**
     * Tells whether a center has converged.
     *
     * @param center  the previous center
     * @param pCenter the newly computed center
     * @return true if no coordinate differs by more than {@code convergenceDis}
     */
    private boolean isCenterConvergence(double[] center, double[] pCenter)
    {
        for (int i = 0; i < center.length; i++)
        {
            if (Math.abs(center[i] - pCenter[i]) > convergenceDis)
            {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the data set.
     *
     * @return the data set (not copied)
     */
    public double[][] getDataSet()
    {
        return dataSet;
    }

    /**
     * Replaces the data set; takes effect on the next {@link #train()}.
     *
     * @param dataSet the new data set, one row per sample
     */
    public void setDataSet(double[][] dataSet)
    {
        this.dataSet = dataSet;
    }

    /**
     * Returns the number of clusters.
     *
     * @return k
     */
    public int getK()
    {
        return k;
    }

    /**
     * Sets the number of clusters; takes effect on the next {@link #train()}.
     *
     * @param k the number of clusters
     */
    public void setK(int k)
    {
        this.k = k;
    }

    /**
     * Returns the cluster centers (the best run's centers after replicated training).
     *
     * @return the centers array (not copied)
     */
    public double[][] getCenters()
    {
        return centers;
    }

    /**
     * Overrides the current centers.
     *
     * @param centers the centers to set
     */
    public void setCenters(double[][] centers)
    {
        this.centers = centers;
    }

    /**
     * Returns the total sample-to-assigned-center distance of the last (or best) run.
     *
     * @return the total distance
     */
    public double getTotalSumOfdistances()
    {
        return totalSumOfdistances;
    }

    /**
     * Overrides the stored total distance.
     *
     * @param totalSumOfdistances the value to set
     */
    public void setTotalSumOfdistances(double totalSumOfdistances)
    {
        this.totalSumOfdistances = totalSumOfdistances;
    }

    /**
     * Returns the number of iterations of the last (or best) run.
     *
     * @return the iteration count
     */
    public int getIter()
    {
        return iter;
    }

    /**
     * Returns the per-coordinate convergence threshold.
     *
     * @return the threshold
     */
    public double getConvergenceDis()
    {
        return convergenceDis;
    }

    /**
     * Sets the per-coordinate convergence threshold (default 0.0001).
     *
     * @param convergenceDis the threshold
     */
    public void setConvergenceDis(double convergenceDis)
    {
        this.convergenceDis = convergenceDis;
    }

    /**
     * Returns the configured number of runs.
     *
     * @return replicates
     */
    public int getReplicates()
    {
        return replicates;
    }

    /**
     * Sets the number of independent runs; values <= 1 mean a single run.
     *
     * @param replicates the number of runs
     */
    public void setReplicates(int replicates)
    {
        this.replicates = replicates;
    }

    /**
     * Returns the per-run results of the last replicated {@link #train()}.
     *
     * @return the results, or null if no replicated training has been done
     */
    public KMCResult[] getKmcresults()
    {
        return kmcresults;
    }

    /**
     * Overrides the stored per-run results.
     *
     * @param kmcresults the results to set
     */
    public void setKmcresults(KMCResult[] kmcresults)
    {
        this.kmcresults = kmcresults;
    }

    /**
     * Outcome of one clustering run.
     *
     * <p>time:2011-6-2</p>
     *
     * @author T. QIN
     */
    public class KMCResult
    {
        /** Centers found by the run. */
        public double[][] centers;

        /** Total sample-to-assigned-center distance in the run's last iteration. */
        public double sum;

        /** Number of iterations the run took. */
        public int iters;
    }
}
然后是相关的类：
View Code
package org.tadoo.ml.exception;

/**
 * Unchecked exception signalling invalid clustering input or configuration
 * (for example empty data).
 *
 * <p>time:2011-5-25</p>
 *
 * @author T. QIN
 */
public class ClusterException extends RuntimeException
{
    /** Creates an exception with no detail message. */
    public ClusterException()
    {
        super();
    }

    /**
     * Creates an exception with a detail message.
     *
     * @param s the detail message
     */
    public ClusterException(String s)
    {
        super(s);
    }

    /**
     * Creates an exception that wraps an underlying cause, so the original stack trace
     * is not lost.
     *
     * @param s     the detail message
     * @param cause the underlying cause
     */
    public ClusterException(String s, Throwable cause)
    {
        super(s, cause);
    }

    /**
     * Creates an exception that wraps an underlying cause.
     *
     * @param cause the underlying cause
     */
    public ClusterException(Throwable cause)
    {
        super(cause);
    }
}
View Code
package org.tadoo.ml.util;

/**
* 简单数组计算
*
* <p>time:2011-5-27</p>
*
@author T. QIN
*/
public class ArrayCompute
{
/**
* 数组相加
*
*
@param x1
*
@param x2
*
@return
*
@see:
*/
public static double[] add(final double[] x1, final double[] x2)
{
if (x1.length != x2.length)
{
System.err.print(
"向量长度不等不能相加!");
System.exit(
0);
}
double[] result = new double[x1.length];
for (int i = 0; i < result.length; i++)
{
result[i]
= x1[i] + x2[i];
}
return result;
}

/**
* 数组相减
*
*
@param x1
*
@param x2
*
@return
*
@see:
*/
public static double[] minus(final double[] x1, final double[] x2)
{
if (x1.length != x2.length)
{
System.err.print(
"向量长度不等不能相减!");
System.exit(
0);
}
double[] result = new double[x1.length];
for (int i = 0; i < result.length; i++)
{
result[i]
= x1[i] - x2[i];
}
return result;
}

/**
* 数组乘以一个常数
*
*
@param x1
*
@param c
*
@return
*
@see:
*/
public static double[] multiplyC(final double[] x1, final double c)
{
double[] ret = new double[x1.length];
for (int i = 0; i < x1.length; i++)
{
ret[i]
= x1[i] * c;
}
return ret;
}

/**
* 数组除以一个常数
*
*
@param x1
*
@param c
*
@return
*
@see:
*/
public static double[] devideC(final double[] x1, final double c)
{
double[] ret = new double[x1.length];
for (int i = 0; i < x1.length; i++)
{
ret[i]
= x1[i] / c;
}
return ret;
}
}
View Code
package org.tadoo.ml.util;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

import org.tadoo.ml.Matrix;

/**
 * Miscellaneous numeric helpers used by the learning algorithms.
 *
 * <p>time:2011-3-23, 2011-3-28</p>
 *
 * @author T. QIN
 */
public class Utils
{

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param x1 first point
     * @param x2 second point, same dimension as {@code x1}
     * @return {@code sqrt(sum((x1[i] - x2[i])^2))}
     * @throws IllegalArgumentException if the dimensions differ (previously a longer
     *         {@code x2} was silently truncated and a shorter one threw an AIOOBE)
     */
    public static double distance(double[] x1, double[] x2)
    {
        if (x1.length != x2.length)
        {
            throw new IllegalArgumentException("维度不等: " + x1.length + " != " + x2.length);
        }
        double sum = 0.0;
        for (int i = 0; i < x1.length; i++)
        {
            double diff = x1[i] - x2[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Adds two vectors element by element.
     *
     * @param x1 first vector
     * @param x2 second vector, same length as {@code x1}
     * @return a new array holding {@code x1[i] + x2[i]}
     * @throws IllegalArgumentException if the lengths differ (previously this called
     *         {@code System.exit(0)}, killing the JVM with a success status)
     */
    public static double[] add(final double[] x1, final double[] x2)
    {
        if (x1.length != x2.length)
        {
            throw new IllegalArgumentException("向量长度不等不能相加! (" + x1.length + " != " + x2.length + ")");
        }
        double[] result = new double[x1.length];
        for (int i = 0; i < result.length; i++)
        {
            result[i] = x1[i] + x2[i];
        }
        return result;
    }
}
Matrix
package org.tadoo.ml;

import java.io.PrintStream;

import org.tadoo.ml.exception.MatrixComputeException;

/**
 * Dense matrix of {@code double} values stored as {@code value[row][col]}.
 *
 * <p>time:2011-3-23</p>
 *
 * @author T. QIN
 */
public class Matrix
{
    /** Number of rows. */
    private int rowNum;

    /** Number of columns. */
    private int colNum;

    /** Backing storage; null when constructed with {@code isInitialMemory == false}. */
    private double value[][];

    /**
     * Creates a {@code rows x cols} matrix filled with zeros.
     *
     * @param rows number of rows
     * @param cols number of columns
     */
    public Matrix(int rows, int cols)
    {
        this.rowNum = rows;
        this.colNum = cols;
        this.value = new double[rows][cols];
    }

    /**
     * Creates a {@code rows x cols} matrix, optionally without allocating storage.
     *
     * @param rows            number of rows
     * @param cols            number of columns
     * @param isInitialMemory whether to allocate a zero-filled backing array now
     */
    public Matrix(int rows, int cols, boolean isInitialMemory)
    {
        this.rowNum = rows;
        this.colNum = cols;
        if (isInitialMemory)
        {
            this.value = new double[rows][cols];
        }
    }

    /**
     * Replaces the whole backing array (the array is stored, not copied).
     *
     * @param v new values; must have exactly {@code rowNum} rows of {@code colNum} columns
     * @throws MatrixComputeException if {@code v} is null or any dimension does not match
     */
    public void changeWholeValue(double v[][]) throws MatrixComputeException
    {
        // The original test used '&&', which accepted v when only ONE dimension was wrong,
        // and it dereferenced v[0] without checking for null or empty input.
        if (v == null || v.length != this.rowNum)
        {
            throw new MatrixComputeException("矩阵大小不拟合");
        }
        for (int i = 0; i < v.length; i++)
        {
            if (v[i] == null || v[i].length != this.colNum)
            {
                throw new MatrixComputeException("矩阵大小不拟合");
            }
        }
        this.value = v;
    }

    /**
     * Prints the matrix row by row, with tab-separated values.
     *
     * @param ps target stream; {@code System.out} when null
     */
    public void print(PrintStream ps)
    {
        PrintStream out = (ps == null) ? System.out : ps;
        for (int i = 0; i < rowNum; i++)
        {
            for (int j = 0; j < colNum; j++)
            {
                out.print(value[i][j] + "\t");
            }
            out.println();
        }
    }

    /**
     * Renders the matrix as tab-separated values, one line per row.
     *
     * @return the textual form
     */
    @Override
    public String toString()
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rowNum; i++)
        {
            for (int j = 0; j < colNum; j++)
            {
                sb.append(value[i][j]).append('\t');
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Returns the number of rows.
     *
     * @return rowNum
     */
    public int getRowNum()
    {
        return rowNum;
    }

    /**
     * Sets the row count (does not resize the backing array).
     *
     * @param rowNum the row count
     */
    public void setRowNum(int rowNum)
    {
        this.rowNum = rowNum;
    }

    /**
     * Returns the number of columns.
     *
     * @return colNum
     */
    public int getColNum()
    {
        return colNum;
    }

    /**
     * Sets the column count (does not resize the backing array).
     *
     * @param colNum the column count
     */
    public void setColNum(int colNum)
    {
        this.colNum = colNum;
    }

    /**
     * Returns the backing array.
     *
     * @return the values (not copied)
     */
    public double[][] getValue()
    {
        return value;
    }

    /**
     * Replaces the backing array without any dimension check.
     *
     * @param value the values to set
     */
    public void setValue(double[][] value)
    {
        this.value = value;
    }

}
MatrixComputeException
package org.tadoo.ml.exception;

/**
 * Checked exception raised when a matrix operation gets operands of incompatible size.
 *
 * <p>time:2011-3-23</p>
 *
 * @author T. QIN
 */
public class MatrixComputeException extends Exception
{
    /**
     * Creates an exception with a detail message.
     *
     * @param s the detail message
     */
    public MatrixComputeException(String s)
    {
        super(s);
    }

    /**
     * Creates an exception that wraps an underlying cause, so the original stack trace
     * is preserved.
     *
     * @param s     the detail message
     * @param cause the underlying cause
     */
    public MatrixComputeException(String s, Throwable cause)
    {
        super(s, cause);
    }
}
DataUtil
package org.tadoo.ml.util;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Loads numeric data from, and saves it to, whitespace-separated text files.
 *
 * <p>time:2011-5-31</p>
 *
 * @author T. QIN
 */
public class DataUtil
{
    /**
     * Loads a whitespace-separated numeric file: one sample per non-blank line, one
     * value per token.
     *
     * <p>The column count is taken from the first non-blank line. Shorter lines are
     * padded with zeros. Blank lines are skipped: previously a trailing newline reached
     * {@code Double.parseDouble("")} and crashed.</p>
     *
     * @param filePath path of the file to read
     * @return the parsed data; an empty array for a file with no data lines; or
     *         {@code null} if the file cannot be read (the error is printed)
     * @throws NumberFormatException    if a token is not a number
     * @throws IllegalArgumentException if a line has more values than the first line
     */
    public static double[][] load(String filePath)
    {
        List<String[]> container = new ArrayList<String[]>();
        BufferedReader reader = null;
        try
        {
            reader = new BufferedReader(new FileReader(new File(filePath)));
            String line;
            while ((line = reader.readLine()) != null)
            {
                String trimmed = line.trim();
                if (trimmed.length() == 0)
                {
                    continue;
                }
                container.add(trimmed.split("[\\s]+"));
            }
        }
        catch (IOException e)
        {
            // Keeps the original contract: report the failure and return null.
            e.printStackTrace();
            return null;
        }
        finally
        {
            // The original never closed the reader, leaking a file handle per call.
            if (reader != null)
            {
                try
                {
                    reader.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }

        if (container.isEmpty())
        {
            return new double[0][0];
        }
        int cols = container.get(0).length;
        double[][] result = new double[container.size()][cols];
        for (int i = 0; i < container.size(); i++)
        {
            String[] tokens = container.get(i);
            if (tokens.length > cols)
            {
                throw new IllegalArgumentException("Data row " + (i + 1) + " of " + filePath + " has "
                        + tokens.length + " values but the first row has " + cols);
            }
            for (int j = 0; j < tokens.length; j++)
            {
                result[i][j] = Double.parseDouble(tokens[j]);
            }
        }
        return result;
    }

    /**
     * Writes data to a file, one row per line, with each value followed by a space.
     *
     * <p>Only the selected columns are written, in ascending column order. The caller's
     * {@code columns} array is no longer sorted in place.</p>
     *
     * @param data         the rows to write
     * @param saveFilename the output file path (overwritten)
     * @param columns      indices of the columns to write; {@code null} writes every column
     */
    public static void save(double[][] data, String saveFilename, int[] columns)
    {
        int[] sortedColumns = null;
        if (columns != null)
        {
            sortedColumns = columns.clone();
            Arrays.sort(sortedColumns);
        }
        BufferedWriter writer = null;
        try
        {
            writer = new BufferedWriter(new FileWriter(saveFilename));
            for (int i = 0; i < data.length; i++)
            {
                double[] row = data[i];
                if (sortedColumns == null)
                {
                    for (int j = 0; j < row.length; j++)
                    {
                        writer.write(String.valueOf(row[j]) + " ");
                    }
                }
                else
                {
                    for (int j = 0; j < sortedColumns.length; j++)
                    {
                        writer.write(String.valueOf(row[sortedColumns[j]]) + " ");
                    }
                }
                writer.write("\n");
            }
            writer.flush();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            // writer is null when opening the file failed; the original then threw an NPE here.
            if (writer != null)
            {
                try
                {
                    writer.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
    }
}

然后是测试:

View Code
package org.tadoo.ml.test;

import junit.framework.TestCase;

import org.tadoo.ml.Matrix;
import org.tadoo.ml.cluster.kmeans.KmeansCluster;
import org.tadoo.ml.util.DataUtil;
import org.tadoo.ml.util.Utils;

/**
 * Tests for the K-means clusterer (JUnit 3 style).
 *
 * <p>NOTE(review): the input files are machine-specific absolute paths, and
 * {@code Utils.uniformFileInputIntoFeatures} is not defined in the Utils class listed
 * in this file. Confirm that it exists in the full library.</p>
 *
 * <p>time:2011-6-2</p>
 *
 * @author T. QIN
 */
public class TestKmeansCluster extends TestCase
{
    Matrix dataSet = null;

    double[][] ds = null;

    protected void setUp()
    {
        dataSet = Utils.uniformFileInputIntoFeatures("D:\\test.s.txt");
        ds = DataUtil.load("D:\\data1.txt");
    }

    /**
     * Clusters into 2 groups and checks the shape of the result. The original test only
     * printed values and could never fail on a wrong result.
     */
    public void testKmeansCenters()
    {
        KmeansCluster kmc = new KmeansCluster(dataSet.getValue(), 2);
        kmc.train();
        System.out.println(kmc.getTotalSumOfdistances());
        System.out.println(kmc.getIter());
        assertTrue(kmc.getIter() >= 1);
        assertTrue(kmc.getTotalSumOfdistances() >= 0);

        double[][] centers = kmc.getCenters();
        assertEquals(2, centers.length);
        int dims = dataSet.getValue()[0].length;
        for (int i = 0; i < centers.length; i++)
        {
            assertEquals(dims, centers[i].length);
            for (int j = 0; j < centers[i].length; j++)
            {
                System.out.print(centers[i][j] + "\t");
            }
            System.out.println();
        }
    }

    /** Runs 12 replicates with k = 11 and checks that every run was recorded. */
    public void testKmeansReplicate()
    {
        KmeansCluster kmc = new KmeansCluster(dataSet.getValue(), 11);
        kmc.setReplicates(12);
        kmc.train();
        KmeansCluster.KMCResult[] kmcr = kmc.getKmcresults();
        assertNotNull(kmcr);
        assertEquals(12, kmcr.length);
        for (int i = 0; i < kmcr.length; i++)
        {
            System.out.println("iters:" + kmcr[i].iters + "\tSum:" + kmcr[i].sum);
            assertEquals(11, kmcr[i].centers.length);
            assertTrue(kmcr[i].iters >= 1);
        }
    }
}
因为是初学者，代码写得比较冗长，不过总算能算出结果了。
posted @ 2011-06-02 15:47  tadoo  阅读(2832)  评论(5)  编辑  收藏  举报