mapReduce编程之auto complete

1 n-gram模型与auto complete

  n-gram模型是假设文本中一个词出现的概率只与它前面的N-1个词相关。auto complete的原理就是,根据用户输入的词,将后续出现概率较大的词组显示出来。因此我们可以基于n-gram模型来对用户的输入作预测。

  我们的实现方法是:首先用mapreduce在offline对语料库中的数据进行n-gram建模,存到数据库中。然后用户在输入的时候向数据库中查询,获取之后出现的概率较大的词,通过前端php脚本刷新实时显示在界面上。如下所示:

2 mapReduce流程

  

2.1 MR1

  mapper负责按句读入语料库中的数据,分别作2~Ngram的切分(1-gram在这里没用),发送给reducer。

  reducer则统计所有N-gram出现的次数。(这里就是一个wordcount)

2.2 MR2

  mapper负责读入之前生成的N-gram及次数,将最后一个单词切分出来,以前面N-1个单词为key向reducer发送。

 

  reducer里面得到的就是N-gram概率模型,即已知前N-1个词组成的phrase,最后一个词出现的所有可能及其概率。这里我们不用计算概率,仍然沿用词频能达到相同的效果,因为auto complete关注的是概率之间的相对大小而不是概率值本身。这里我们选择出现概率最大的topk个词来存入数据库,可以用treemap或者priorityQueue来做。

    (注:这里的starting_word是1~n-1个词,following_word只能是一个词,因为这样才符合我们N-gram概率模型的意义。)

 2.3 如何预测后面n个单词

  数据库中的n-gram模型:

 

  如上所述,我们看出使用n-gram模型只能与预测下一个单词。为了预测结果的多样性,如果我们要预测之后的n个单词怎么做?

  使用sql语句,查询的时候查询匹配"input%"的所有starting_phrase,就可以实现。

3 代码

 NGramLibraryBuilder.java

 1 import java.io.IOException;
 2 
 3 import org.apache.hadoop.conf.Configuration;
 4 import org.apache.hadoop.fs.Path;
 5 import org.apache.hadoop.io.IntWritable;
 6 import org.apache.hadoop.io.LongWritable;
 7 import org.apache.hadoop.io.Text;
 8 import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
 9 import org.apache.hadoop.mapreduce.Job;
10 import org.apache.hadoop.mapreduce.Mapper;
11 import org.apache.hadoop.mapreduce.Reducer;
12 import org.apache.hadoop.mapreduce.Mapper.Context;
13 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
14 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
15 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
16 
17 public class NGramLibraryBuilder {
18     public static class NGramMapper extends Mapper<LongWritable, Text, Text, IntWritable> {
19 
20         int noGram;
21         @Override
22         public void setup(Context context) {
23             Configuration conf = context.getConfiguration();
24             noGram = conf.getInt("noGram", 5);
25         }
26 
27         // map method
28         @Override
29         public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
30             
31             String line = value.toString();
32             
33             line = line.trim().toLowerCase();
34             line = line.replaceAll("[^a-z]", " ");
35             
36             String[] words = line.split("\\s+"); //split by ' ', '\t'...ect
37             
38             if(words.length<2) {
39                 return;
40             }
41             
42             //I love big data
43             StringBuilder sb;
44             for(int i = 0; i < words.length-1; i++) {
45                 sb = new StringBuilder();
46                 sb.append(words[i]);
47                 for(int j=1; i+j<words.length && j<noGram; j++) {
48                     sb.append(" ");
49                     sb.append(words[i+j]);
50                     context.write(new Text(sb.toString().trim()), new IntWritable(1));
51                 }
52             }
53         }
54     }
55 
56     public static class NGramReducer extends Reducer<Text, IntWritable, Text, IntWritable> {
57         // reduce method
58         @Override
59         public void reduce(Text key, Iterable<IntWritable> values, Context context)
60                 throws IOException, InterruptedException {
61             int sum = 0;
62             for(IntWritable value: values) {
63                 sum += value.get();
64             }
65             context.write(key, new IntWritable(sum));
66         }
67     }
68 
69 }
View Code

LanguageModel.java

  1 import java.io.IOException;
  2 import java.util.ArrayList;
  3 import java.util.Collections;
  4 import java.util.Iterator;
  5 import java.util.List;
  6 import java.util.Set;
  7 import java.util.TreeMap;
  8 
  9 import org.apache.hadoop.conf.Configuration;
 10 import org.apache.hadoop.fs.Path;
 11 import org.apache.hadoop.io.LongWritable;
 12 import org.apache.hadoop.io.NullWritable;
 13 import org.apache.hadoop.io.Text;
 14 import org.apache.hadoop.mapreduce.Job;
 15 import org.apache.hadoop.mapreduce.Mapper;
 16 import org.apache.hadoop.mapreduce.Reducer;
 17 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 18 
 19 public class LanguageModel {
 20     public static class Map extends Mapper<LongWritable, Text, Text, Text> {
 21 
 22         int threashold;
 23         // get the threashold parameter from the configuration
 24         @Override
 25         public void setup(Context context) {
 26             Configuration conf = context.getConfiguration();
 27             threashold = conf.getInt("threashold", 20);
 28         }
 29 
 30         
 31         @Override
 32         public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
 33             if((value == null) || (value.toString().trim()).length() == 0) {
 34                 return;
 35             }
 36             //this is cool\t20
 37             String line = value.toString().trim();
 38             
 39             String[] wordsPlusCount = line.split("\t");
 40             if(wordsPlusCount.length < 2) {
 41                 return;
 42             }
 43             
 44             String[] words = wordsPlusCount[0].split("\\s+");
 45             int count = Integer.valueOf(wordsPlusCount[1]);
 46             
 47             if(count < threashold) {
 48                 return;
 49             }
 50             
 51             //this is --> cool = 20
 52             StringBuilder sb = new StringBuilder();
 53             for(int i = 0; i < words.length-1; i++) {
 54                 sb.append(words[i]).append(" ");
 55             }
 56             String outputKey = sb.toString().trim();
 57             String outputValue = words[words.length - 1];
 58             
 59             if(!((outputKey == null) || (outputKey.length() <1))) {
 60                 context.write(new Text(outputKey), new Text(outputValue + "=" + count));
 61             }
 62         }
 63     }
 64 
 65     public static class Reduce extends Reducer<Text, Text, DBOutputWritable, NullWritable> {
 66 
 67         int n;
 68         // get the n parameter from the configuration
 69         @Override
 70         public void setup(Context context) {
 71             Configuration conf = context.getConfiguration();
 72             n = conf.getInt("n", 5);
 73         }
 74 
 75         @Override
 76         public void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
 77             
 78             //this is, <girl = 50, boy = 60>
 79             TreeMap<Integer, List<String>> tm = new TreeMap<Integer, List<String>>(Collections.reverseOrder());
 80             for(Text val: values) {
 81                 String curValue = val.toString().trim();
 82                 String word = curValue.split("=")[0].trim();
 83                 int count = Integer.parseInt(curValue.split("=")[1].trim());
 84                 if(tm.containsKey(count)) {
 85                     tm.get(count).add(word);
 86                 }
 87                 else {
 88                     List<String> list = new ArrayList<String>();
 89                     list.add(word);
 90                     tm.put(count, list);
 91                 }
 92             }
 93             //<50, <girl, bird>> <60, <boy...>>
 94             Iterator<Integer> iter = tm.keySet().iterator();
 95             for(int j=0; iter.hasNext() && j<n; j++) {
 96                 int keyCount = iter.next();
 97                 List<String> words = tm.get(keyCount);
 98                 for(String curWord: words) {
 99                     context.write(new DBOutputWritable(key.toString(), curWord, keyCount),NullWritable.get());
100                     j++;
101                 }
102             }
103         }
104     }
105 }
View Code

DBOutputWritable.java

 1 import java.sql.PreparedStatement;
 2 import java.sql.ResultSet;
 3 import java.sql.SQLException;
 4 
 5 import org.apache.hadoop.mapreduce.lib.db.DBWritable;
 6 
 7 public class DBOutputWritable implements DBWritable{
 8 
 9     private String starting_phrase;
10     private String following_word;
11     private int count;
12     
13     public DBOutputWritable(String starting_prhase, String following_word, int count) {
14         this.starting_phrase = starting_prhase;
15         this.following_word = following_word;
16         this.count= count;
17     }
18 
19     public void readFields(ResultSet arg0) throws SQLException {
20         this.starting_phrase = arg0.getString(1);
21         this.following_word = arg0.getString(2);
22         this.count = arg0.getInt(3);
23         
24     }
25 
26     public void write(PreparedStatement arg0) throws SQLException {
27         arg0.setString(1, starting_phrase);
28         arg0.setString(2, following_word);
29         arg0.setInt(3, count);
30         
31     }
32 
33 }
View Code

Driver.java

 1 import java.io.IOException;
 2 
 3 import org.apache.hadoop.conf.Configuration;
 4 import org.apache.hadoop.fs.Path;
 5 import org.apache.hadoop.io.IntWritable;
 6 import org.apache.hadoop.io.NullWritable;
 7 import org.apache.hadoop.io.Text;
 8 import org.apache.hadoop.mapreduce.Job;
 9 import org.apache.hadoop.mapreduce.lib.db.DBConfiguration;
10 import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat;
11 import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
12 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
13 
14 
15 public class Driver {
16 
17     public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
18         //job1
19         Configuration conf1 = new Configuration();
20         conf1.set("textinputformat.record.delimiter", ".");
21         conf1.set("noGram", args[2]);
22         
23         Job job1 = Job.getInstance();
24         job1.setJobName("NGram");
25         job1.setJarByClass(Driver.class);
26         
27         job1.setMapperClass(NGramLibraryBuilder.NGramMapper.class);
28         job1.setReducerClass(NGramLibraryBuilder.NGramReducer.class);
29         
30         job1.setOutputKeyClass(Text.class);
31         job1.setOutputValueClass(IntWritable.class);
32         
33         job1.setInputFormatClass(TextInputFormat.class);
34         job1.setOutputFormatClass(TextOutputFormat.class);
35         
36         TextInputFormat.setInputPaths(job1, new Path(args[0]));
37         TextOutputFormat.setOutputPath(job1, new Path(args[1]));
38         job1.waitForCompletion(true);
39         
40         //how to connect two jobs?
41         // last output is second input
42         
43         //2nd job
44         Configuration conf2 = new Configuration();
45         conf2.set("threashold", args[3]);
46         conf2.set("n", args[4]);
47         
48         DBConfiguration.configureDB(conf2, 
49                 "com.mysql.jdbc.Driver",
50                 "jdbc:mysql://ip_address:port/test",
51                 "root",
52                 "password");
53         
54         Job job2 = Job.getInstance(conf2);
55         job2.setJobName("Model");
56         job2.setJarByClass(Driver.class);
57         
58         job2.addArchiveToClassPath(new Path("path_to_ur_connector"));
59         job2.setMapOutputKeyClass(Text.class);
60         job2.setMapOutputValueClass(Text.class);
61         job2.setOutputKeyClass(DBOutputWritable.class);
62         job2.setOutputValueClass(NullWritable.class);
63         
64         job2.setMapperClass(LanguageModel.Map.class);
65         job2.setReducerClass(LanguageModel.Reduce.class);
66         
67         job2.setInputFormatClass(TextInputFormat.class);
68         job2.setOutputFormatClass(DBOutputFormat.class);
69         
70         DBOutputFormat.setOutput(job2, "output", 
71                 new String[] {"starting_phrase", "following_word", "count"});
72 
73         TextInputFormat.setInputPaths(job2, args[1]);
74         job2.waitForCompletion(true);
75     }
76 
77 }
View Code

 

posted @ 2016-11-20 00:51  coldyan  阅读(621)  评论(0编辑  收藏  举报