A Hadoop Implementation of the KMeans Clustering Algorithm (Part 1)
Assistance.java (helper class)
package KMeans;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

import java.io.IOException;
import java.util.*;

public class Assistance {
    // Read the cluster center information: center ID and center coordinates.
    public static List<ArrayList<Float>> getCenters(String inputpath){
        List<ArrayList<Float>> result = new ArrayList<ArrayList<Float>>();
        Configuration conf = new Configuration();
        try {
            FileSystem hdfs = FileSystem.get(conf);
            Path in = new Path(inputpath);
            FSDataInputStream fsIn = hdfs.open(in);
            LineReader lineIn = new LineReader(fsIn, conf);
            Text line = new Text();
            while (lineIn.readLine(line) > 0){
                String record = line.toString();
                /*
                Hadoop inserts a tab between the key and the value when it
                writes key-value pairs, so replace the tab with a space here.
                */
                String[] fields = record.replace("\t", " ").split(" ");
                List<Float> tmplist = new ArrayList<Float>();
                for (int i = 0; i < fields.length; ++i){
                    tmplist.add(Float.parseFloat(fields[i]));
                }
                result.add((ArrayList<Float>) tmplist);
            }
            fsIn.close();
        } catch (IOException e){
            e.printStackTrace();
        }
        return result;
    }

    // Delete the output of the previous MapReduce job.
    public static void deleteLastResult(String path){
        Configuration conf = new Configuration();
        try {
            FileSystem hdfs = FileSystem.get(conf);
            Path path1 = new Path(path);
            hdfs.delete(path1, true);
        } catch (IOException e){
            e.printStackTrace();
        }
    }

    // Compute the distance between the cluster centers of two consecutive
    // iterations and decide whether the termination condition is met.
    public static boolean isFinished(String oldpath, String newpath, int k, float threshold)
        throws IOException{
        List<ArrayList<Float>> oldcenters = Assistance.getCenters(oldpath);
        List<ArrayList<Float>> newcenters = Assistance.getCenters(newpath);
        float distance = 0;
        for (int i = 0; i < k; ++i){
            // Index 0 holds the center ID, so the coordinates start at index 1.
            for (int j = 1; j < oldcenters.get(i).size(); ++j){
                float tmp = Math.abs(oldcenters.get(i).get(j) - newcenters.get(i).get(j));
                distance += Math.pow(tmp, 2);
            }
        }
        System.out.println("Distance = " + distance + " Threshold = " + threshold);
        if (distance < threshold)
            return true;
        /*
        If the termination condition is not met, replace the old cluster
        centers with the centers produced by this iteration.
        */
        Assistance.deleteLastResult(oldpath);
        Configuration conf = new Configuration();
        FileSystem hdfs = FileSystem.get(conf);
        hdfs.copyToLocalFile(new Path(newpath), new Path("/home/hadoop/class/oldcenter.data"));
        hdfs.delete(new Path(oldpath), true);
        hdfs.moveFromLocalFile(new Path("/home/hadoop/class/oldcenter.data"), new Path(oldpath));
        return false;
    }
}
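For reference, getCenters expects each line of the centers file to hold a center ID followed by the center's coordinates. A purely hypothetical file for k = 3 and two-dimensional samples might look like:

0 1.0 1.5
1 5.2 6.0
2 9.8 10.1

The initial oldcenter.data can be space-separated like this; in later iterations the reducer writes a tab between the ID and the coordinates, which getCenters normalizes to a space before splitting. Each line is parsed into a list such as [0.0, 1.0, 1.5], with the ID stored as the first float.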
KMeansDriver.java (job driver class)
package KMeans;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

import java.io.IOException;

public class KMeansDriver{
    public static void main(String[] args) throws Exception{
        int repeated = 0;

        /*
        Keep submitting MapReduce jobs until the distance between the cluster
        centers of two consecutive iterations falls below the threshold, or
        the maximum number of iterations is reached.
        */
        do {
            Configuration conf = new Configuration();
            String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
            if (otherArgs.length != 6){
                System.err.println("Usage: <in> <out> <oldcenters> <newcenters> <k> <threshold>");
                System.exit(2);
            }
            conf.set("centerpath", otherArgs[2]);
            conf.set("kpath", otherArgs[4]);
            Job job = new Job(conf, "KMeansCluster");  // create the MapReduce job
            job.setJarByClass(KMeansDriver.class);     // set the job's main class

            Path in = new Path(otherArgs[0]);
            Path out = new Path(otherArgs[1]);
            FileInputFormat.addInputPath(job, in);     // set the input path
            FileSystem fs = FileSystem.get(conf);
            if (fs.exists(out)){                       // delete the output path if it already exists
                fs.delete(out, true);
            }
            FileOutputFormat.setOutputPath(job, out);  // set the output path

            job.setMapperClass(KMeansMapper.class);    // set the Mapper class
            job.setReducerClass(KMeansReducer.class);  // set the Reducer class

            job.setOutputKeyClass(IntWritable.class);  // set the output key class
            job.setOutputValueClass(Text.class);       // set the output value class

            job.waitForCompletion(true);               // run the job and wait for it to finish

            ++repeated;
            System.out.println("We have repeated " + repeated + " times.");
        } while (repeated < 10 && !Assistance.isFinished(args[2], args[3], Integer.parseInt(args[4]), Float.parseFloat(args[5])));
        // Cluster the data set with the final cluster centers.
        Cluster(args);
    }

    public static void Cluster(String[] args)
        throws IOException, InterruptedException, ClassNotFoundException{
        Configuration conf = new Configuration();
        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
        if (otherArgs.length != 6){
            System.err.println("Usage: <in> <out> <oldcenters> <newcenters> <k> <threshold>");
            System.exit(2);
        }
        conf.set("centerpath", otherArgs[2]);
        conf.set("kpath", otherArgs[4]);
        Job job = new Job(conf, "KMeansCluster");
        job.setJarByClass(KMeansDriver.class);

        Path in = new Path(otherArgs[0]);
        Path out = new Path(otherArgs[1]);
        FileInputFormat.addInputPath(job, in);
        FileSystem fs = FileSystem.get(conf);
        if (fs.exists(out)){
            fs.delete(out, true);
        }
        FileOutputFormat.setOutputPath(job, out);

        // This pass only assigns each sample to a cluster, so no Reducer class is set.
        job.setMapperClass(KMeansMapper.class);

        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);

        job.waitForCompletion(true);
    }
}
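As a usage sketch only (the jar name and HDFS paths below are assumptions, not taken from the code), the driver takes six arguments: input data, output directory, old-centers file, new-centers file, k, and the convergence threshold. Because Assistance.getCenters opens the new-centers path as a single file, that argument has to point at the part file inside the output directory:

hadoop jar kmeans.jar KMeans.KMeansDriver \
    /user/hadoop/kmeans/data.txt \
    /user/hadoop/kmeans/out \
    /user/hadoop/kmeans/oldcenter.data \
    /user/hadoop/kmeans/out/part-r-00000 \
    3 0.0001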
KMeansMapper.java
package KMeans;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KMeansMapper extends Mapper<Object, Text, IntWritable, Text> {
    public void map(Object key, Text value, Context context)
        throws IOException, InterruptedException{
        String line = value.toString();
        String[] fields = line.split(" ");
        List<ArrayList<Float>> centers = Assistance.getCenters(context.getConfiguration().get("centerpath"));
        int k = Integer.parseInt(context.getConfiguration().get("kpath"));
        float minDist = Float.MAX_VALUE;
        int centerIndex = k;
        // Compute the distance from the sample to each center and assign the
        // sample to the nearest one.
        for (int i = 0; i < k; ++i){
            float currentDist = 0;
            for (int j = 0; j < fields.length; ++j){
                // centers.get(i).get(0) is the center ID, so coordinates start at index 1.
                float tmp = Math.abs(centers.get(i).get(j + 1) - Float.parseFloat(fields[j]));
                currentDist += Math.pow(tmp, 2);
            }
            if (minDist > currentDist){
                minDist = currentDist;
                centerIndex = i;
            }
        }
        context.write(new IntWritable(centerIndex), new Text(value));
    }
}
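As a quick worked example with the hypothetical centers listed earlier (0: 1.0 1.5, 1: 5.2 6.0, 2: 9.8 10.1), a sample line "5.0 5.5" gives squared Euclidean distances 32.0, 0.29, and 44.2 respectively, so the mapper emits the pair (1, "5.0 5.5"). The square root is never taken; comparing squared distances selects the same nearest center.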
KMeansReducer.java
package KMeans;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KMeansReducer extends Reducer<IntWritable, Text, IntWritable, Text> {
    public void reduce(IntWritable key, Iterable<Text> value, Context context)
        throws IOException, InterruptedException{
        List<ArrayList<Float>> assistList = new ArrayList<ArrayList<Float>>();
        String tmpResult = "";
        // Collect all samples assigned to this cluster.
        for (Text val : value){
            String line = val.toString();
            String[] fields = line.split(" ");
            List<Float> tmpList = new ArrayList<Float>();
            for (int i = 0; i < fields.length; ++i){
                tmpList.add(Float.parseFloat(fields[i]));
            }
            assistList.add((ArrayList<Float>) tmpList);
        }
        // Compute the new cluster center as the coordinate-wise mean of the samples.
        for (int i = 0; i < assistList.get(0).size(); ++i){
            float sum = 0;
            for (int j = 0; j < assistList.size(); ++j){
                sum += assistList.get(j).get(i);
            }
            float tmp = sum / assistList.size();
            if (i == 0){
                tmpResult += tmp;
            } else{
                tmpResult += " " + tmp;
            }
        }
        Text result = new Text(tmpResult);
        context.write(key, result);
    }
}
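Each reducer output line therefore has the form "<center ID>" followed by a tab and the space-separated coordinate means. Continuing the hypothetical example, if cluster 1 received the samples "5.0 5.5" and "5.4 6.5", the reducer would write "1" and "5.2 6.0" separated by a tab, which the next iteration's Assistance.getCenters reads back as [1.0, 5.2, 6.0] once the tab has been replaced with a space.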