Apriori算法在Hadoop MapReduce上的实现

输入格式:

每行为一个Bucket(即一条事务),由空白分隔的整数item组成:

1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 12 13 15 17 19 21 23 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 12 13 16 17 19 21 23 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 20 21 23 25 27 29 31 34 36 38 40 42 44 47 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 34 36 38 40 42 44 46 48 51 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 34 36 38 40 42 44 46 48 51 52 54 56 58 60 63 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 20 21 23 25 27 29 31 34 36 38 40 42 44 47 48 51 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 12 13 15 17 19 21 24 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 19 21 24 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 65 66 68 70 72 74 
1 3 5 7 9 11 13 16 17 19 21 24 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 12 13 16 17 19 21 24 25 27 29 31 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 20 21 24 25 27 29 31 34 36 38 40 42 44 47 48 50 52 54 56 58 60 62 64 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 20 21 24 25 27 29 31 34 36 38 40 42 44 47 48 50 52 54 56 58 60 62 65 66 68 70 72 74 
1 3 5 7 9 11 13 15 17 20 21 24 25 27 29 31 34 36 38 40 43 44 47 48 50 52 54 56 58 60 62 65 66 68 70 72 74 

 

输出格式:

item1,item2,...,itemK<TAB>frequency(每行一个频繁itemset及其出现次数;下面是第1次pass的输出示例,每个itemset只含一个item)

25    2860
29    3181
3    2839
34    3040
36    3099
40    3170
48    3013
5    2971
52    3185
56    3021

 

代码:

  1 package apriori;
  2 
import java.io.*;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Mapper.Context;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.jobcontrol.ControlledJob;
import org.apache.hadoop.mapreduce.lib.jobcontrol.JobControl;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
 29 
 30 class AprioriPass1Mapper extends Mapper<Object,Text,Text,IntWritable>{
 31     private final static IntWritable one = new IntWritable(1);
 32     private Text number = new Text();
 33 
 34     //第一次pass的Mapper只要把每个item映射为1
 35     public void map(Object key,Text value,Context context) throws IOException,InterruptedException{
 36 
 37         String[] ids = value.toString().split("[\\s\\t]+");
 38         for(int i = 0;i < ids.length;i++){
 39             context.write(new Text(ids[i]),one);
 40         }
 41     }
 42 }
 43 
 44 class AprioriReducer extends Reducer<Text,IntWritable,Text,IntWritable>{
 45     private IntWritable result = new IntWritable();
 46 
 47     //所有Pass的job共用一个reducer,即统计一种itemset的个数,并筛选除大于s的
 48     public void reduce(Text key,Iterable<IntWritable> values,Context context) throws IOException,InterruptedException{
 49         int sum = 0;
 50 
 51         int minSup = context.getConfiguration().getInt("minSup",5);
 52         for(IntWritable val : values){
 53             sum += val.get();
 54         }
 55         result.set(sum);
 56 
 57         if(sum > minSup){
 58             context.write(key,result);
 59         }
 60     }
 61 }
 62 
 63 class AprioriPassKMapper extends Mapper<Object,Text,Text,IntWritable>{
 64     private final static IntWritable one = new IntWritable(1);
 65     private Text item = new Text();
 66 
 67     private List< List<Integer> > prevItemsets = new ArrayList< List<Integer> >();
 68     private List< List<Integer> > candidateItemsets = new ArrayList< List<Integer> >();
 69     private Map<String,Boolean> candidateItemsetsMap = new HashMap<String,Boolean>();
 70 
 71 
 72     //第一个以后的pass使用该Mapper,在map函数执行前会执行setup来从k-1次pass的输出中构建候选itemsets,对应于apriori算法
 73     @Override
 74     public void setup(Context context) throws IOException, InterruptedException{
 75         int passNum = context.getConfiguration().getInt("passNum",2);
 76         String prefix = context.getConfiguration().get("hdfsOutputDirPrefix","");
 77         String lastPass1 = context.getConfiguration().get("fs.default.name") + "/user/hadoop/chess-" + (passNum - 1) + "/part-r-00000";
 78         String lastPass = context.getConfiguration().get("fs.default.name") + prefix + (passNum - 1) + "/part-r-00000";
 79 
 80         try{
 81             Path path = new Path(lastPass);
 82             FileSystem fs = FileSystem.get(context.getConfiguration());
 83             BufferedReader fis = new BufferedReader(new InputStreamReader(fs.open(path)));
 84             String line = null;
 85 
 86             while((line = fis.readLine()) != null){
 87                 
 88                 List<Integer> itemset = new ArrayList<Integer>();
 89 
 90                 String itemsStr = line.split("[\\s\\t]+")[0];
 91                 for(String itemStr : itemsStr.split(",")){
 92                     itemset.add(Integer.parseInt(itemStr));
 93                 }
 94 
 95                 prevItemsets.add(itemset);
 96             }
 97         }catch (Exception e){
 98             e.printStackTrace();
 99         }
100 
101         //get candidate itemsets from the prev itemsets
102         candidateItemsets = getCandidateItemsets(prevItemsets,passNum - 1);
103     }
104 
105 
106     public void map(Object key,Text value,Context context) throws IOException,InterruptedException{
107         String[] ids = value.toString().split("[\\s\\t]+");
108 
109         List<Integer> itemset = new ArrayList<Integer>();
110         for(String id : ids){ 
111             itemset.add(Integer.parseInt(id));
112         }
113 
114         //遍历所有候选集合
115         for(List<Integer> candidateItemset : candidateItemsets){
116             //如果输入的一行中包含该候选集合,则映射1,这样来统计候选集合被包括的次数 
117             //子集合,消耗掉了大部分时间
118             if(contains(candidateItemset,itemset)){
119                 String outputKey = "";
120                 for(int i = 0;i < candidateItemset.size();i++){
121                     outputKey += candidateItemset.get(i) + ",";
122                 }
123                 outputKey = outputKey.substring(0,outputKey.length() - 1);
124                 context.write(new Text(outputKey),one);
125             }
126         }
127     }
128 
129     //返回items是否是allItems的子集
130     private boolean contains(List<Integer> items,List<Integer> allItems){
131         
132         int i = 0;
133         int j = 0;
134         while(i < items.size() && j < allItems.size()){
135             if(allItems.get(j) > items.get(i)){
136                 return false;
137             }else if(allItems.get(j) == items.get(i)){
138                 j++;
139                 i++;
140             }else{
141                 j++;
142             }    
143         }
144 
145         if(i != items.size()){
146             return false;
147         }
148         return true;
149     }
150 
151     //获取所有候选集合,参考apriori算法
152     private List< List<Integer> > getCandidateItemsets(List< List<Integer> > prevItemsets, int passNum){
153 
154         List< List<Integer> > candidateItemsets = new ArrayList<List<Integer> >();
155         
156         //上次pass的输出中选取连个itemset构造大小为k + 1的候选集合
157         for(int i = 0;i < prevItemsets.size();i++){
158             for(int j = i + 1;j < prevItemsets.size();j++){
159                 List<Integer> outerItems = prevItemsets.get(i);
160                 List<Integer> innerItems = prevItemsets.get(j);
161 
162                 List<Integer> newItems = null;
163                 if(passNum == 1){
164                     newItems = new ArrayList<Integer>();
165                     newItems.add(outerItems.get(0));        
166                     newItems.add(innerItems.get(0));        
167                 }
168                 else{    
169                     int nDifferent = 0;
170                     int index = -1;
171                     for(int k = 0; k < passNum && nDifferent < 2;k++){
172                         if(!innerItems.contains(outerItems.get(k))){
173                             nDifferent++;
174                             index = k;
175                         }
176                     }
177 
178                     if(nDifferent == 1){
179                         //System.out.println("inner " + innerItems + " outer : " + outerItems);
180                         newItems = new ArrayList<Integer>();
181                         newItems.addAll(innerItems);
182                         newItems.add(outerItems.get(index));
183                     }
184                 }
185                 if(newItems == null){continue;}
186 
187                 Collections.sort(newItems);
188 
189                 //候选集合必须满足所有的子集都在上次pass的输出中,调用isCandidate进行检测,通过后加入到候选子集和列表
190                 if(isCandidate(newItems,prevItemsets) && !candidateItemsets.contains(newItems)){
191                     candidateItemsets.add(newItems);    
192                     //System.out.println(newItems);
193                 }
194             }
195         }
196 
197         return candidateItemsets;
198     }
199 
200     private boolean isCandidate(List<Integer> newItems,List< List<Integer> > prevItemsets){
201     
202         List<List<Integer>> subsets = getSubsets(newItems);     
203         
204         for(List<Integer> subset : subsets){
205             if(!prevItemsets.contains(subset)){
206                 return false;
207             }
208         }
209 
210         return true;
211     }
212 
213     private List<List<Integer>> getSubsets(List<Integer> items){
214     
215         List<List<Integer>> subsets = new ArrayList<List<Integer>>();
216         for(int i = 0;i < items.size();i++){
217             List<Integer> subset = new ArrayList<Integer>(items);
218             subset.remove(i);
219             subsets.add(subset);
220         }
221 
222         return subsets;
223     }
224 }
225 
226 public class Apriori extends Configured implements Tool{
227 
228     public static int s;
229     public static int k;
230 
231     public int run(String[] args)throws IOException,InterruptedException,ClassNotFoundException{
232         long startTime = System.currentTimeMillis();
233 
234         String hdfsInputDir = args[0];        //从参数1中读取输入数据
235         String hdfsOutputDirPrefix = args[1];    //参数2为输出数据前缀,和第pass次组成输出目录
236         s = Integer.parseInt(args[2]);        //阈值
237         k = Integer.parseInt(args[3]);        //k次pass
238         
239         //循环执行K次pass
240         for(int pass = 1; pass <= k;pass++){
241             long passStartTime = System.currentTimeMillis();
242             
243             //配置执行该job
244             if(!runPassKMRJob(hdfsInputDir,hdfsOutputDirPrefix,pass)){
245                 return -1;    
246             }
247         
248             long passEndTime = System.currentTimeMillis();
249             System.out.println("pass " + pass + " time : " + (passEndTime - passStartTime));
250         }
251 
252         long endTime = System.currentTimeMillis();
253         System.out.println("total time : " + (endTime - startTime));
254 
255         return 0;
256     }
257 
258     private static boolean runPassKMRJob(String hdfsInputDir,String hdfsOutputDirPrefix,int passNum)
259             throws IOException,InterruptedException,ClassNotFoundException{
260 
261             Configuration passNumMRConf = new Configuration();
262             passNumMRConf.setInt("passNum",passNum);
263             passNumMRConf.set("hdfsOutputDirPrefix",hdfsOutputDirPrefix);
264             passNumMRConf.setInt("minSup",s);
265 
266             Job passNumMRJob = new Job(passNumMRConf,"" + passNum);
267             passNumMRJob.setJarByClass(Apriori.class);
268             if(passNum == 1){
269                 //第一次pass的Mapper类特殊对待,不许要构造候选itemsets
270                 passNumMRJob.setMapperClass(AprioriPass1Mapper.class);
271             }
272             else{
273                 //第一次之后的pass的Mapper类特殊对待,不许要构造候选itemsets
274                 passNumMRJob.setMapperClass(AprioriPassKMapper.class);
275             }
276             passNumMRJob.setReducerClass(AprioriReducer.class);
277             passNumMRJob.setOutputKeyClass(Text.class);
278             passNumMRJob.setOutputValueClass(IntWritable.class);
279 
280             FileInputFormat.addInputPath(passNumMRJob,new Path(hdfsInputDir));
281             FileOutputFormat.setOutputPath(passNumMRJob,new Path(hdfsOutputDirPrefix + passNum));
282 
283             return passNumMRJob.waitForCompletion(true);
284     }
285 
286     public static void main(String[] args) throws Exception{
287         int exitCode = ToolRunner.run(new Apriori(),args);
288         System.exit(exitCode);
289     }
290 }

 

 posted on 2016-09-28 00:55  莫扎特的代码  阅读(4493)  评论(1)  编辑  收藏  举报