Implementing an n-gram Language Model with MapReduce

In today's big-data world, no single machine can process data at this scale, whether in terms of CPU power or memory capacity. We have to go distributed, pooling the resources of many machines to get the work done, for both offline batch processing and online real-time processing.

 

In our last meeting we covered the evolution of language models, from rule-based approaches to the later neural network language models (NNLM). The goal of this post is hands-on practice: with the principles understood, implement an n-gram language model ourselves using the MapReduce paradigm.

 

First, the dependencies are managed with Maven.

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.dingheng</groupId>
    <artifactId>nragmMR</artifactId>
    <version>1.0-SNAPSHOT</version>

    <packaging>jar</packaging>

    <dependencies>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.7.2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-core</artifactId>
            <version>1.2.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>2.7.2</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.12</version>
        </dependency>
    </dependencies>
</project>

 

Now, straight to the code:

1. First the Driver, which is the program's entry point.

  

package com.dingheng;

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.db.DBConfiguration;
import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;

public class Driver {

    public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {

        // args: inputDir outputDir numOfGram threshold topK
        String inputDir = args[0];
        String outputDir = args[1];
        String numOfGram = args[2];
        String threshold = args[3];
        String topK = args[4];

        // First MapReduce job: count all n-grams in the corpus.
        Configuration configurationNGram = new Configuration();
        // Split the input into records on "." so that each sentence becomes one record.
        configurationNGram.set("textinputformat.record.delimiter", ".");
        configurationNGram.set("numOfGram", numOfGram);

        Job jobNGram = Job.getInstance(configurationNGram);
        jobNGram.setJobName("NGram");
        jobNGram.setJarByClass(Driver.class);

        jobNGram.setMapperClass(NGram.NGramMapper.class);
        jobNGram.setReducerClass(NGram.NGramReducer.class);

        jobNGram.setOutputKeyClass(Text.class);
        jobNGram.setOutputValueClass(IntWritable.class);

        jobNGram.setInputFormatClass(TextInputFormat.class);
        jobNGram.setOutputFormatClass(TextOutputFormat.class);

        TextInputFormat.addInputPath(jobNGram, new Path(inputDir));
        TextOutputFormat.setOutputPath(jobNGram, new Path(outputDir));
        jobNGram.waitForCompletion(true);

        // Second MapReduce job: build the language model and write it to MySQL.
        Configuration configurationLanguage = new Configuration();
        configurationLanguage.set("threshold", threshold);
        configurationLanguage.set("topK", topK);

        DBConfiguration.configureDB(configurationLanguage,
                "com.mysql.cj.jdbc.Driver", // driver class for Connector/J 8.x
                "jdbc:mysql://localhost:3306/test",
                "root",
                "123456");

        Job jobLanguage = Job.getInstance(configurationLanguage);
        jobLanguage.setJobName("LanguageModel");
        jobLanguage.setJarByClass(Driver.class);

        jobLanguage.setMapperClass(LanguageModel.Map.class);
        jobLanguage.setReducerClass(LanguageModel.Reduce.class);

        jobLanguage.setMapOutputKeyClass(Text.class);
        jobLanguage.setMapOutputValueClass(Text.class);
        jobLanguage.setOutputKeyClass(DBOutputWritable.class);
        jobLanguage.setOutputValueClass(NullWritable.class);

        jobLanguage.setInputFormatClass(TextInputFormat.class);
        jobLanguage.setOutputFormatClass(DBOutputFormat.class);

        // Write into table "output" with these three columns.
        DBOutputFormat.setOutput(
                jobLanguage,
                "output",
                new String[] { "starting_phrase", "following_word", "count" });

        // The output of the first job is the input of the second.
        TextInputFormat.setInputPaths(jobLanguage, new Path(args[1]));

        jobLanguage.waitForCompletion(true);
    }
}
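One practical note on re-running the pipeline: Hadoop refuses to start a job whose output directory already exists. Below is a minimal sketch of an optional cleanup helper using the HDFS FileSystem API; it is not part of the original driver, and the class and method names (OutputCleaner, deleteIfExists) are only illustrative.

package com.dingheng;

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

// Hypothetical helper: remove a stale output directory so the job can be resubmitted.
public class OutputCleaner {
    public static void deleteIfExists(Configuration conf, String dir) throws IOException {
        FileSystem fs = FileSystem.get(conf);
        Path path = new Path(dir);
        if (fs.exists(path)) {
            fs.delete(path, true); // recursive delete of the previous run's output
        }
    }
}

In main(), one would call OutputCleaner.deleteIfExists(configurationNGram, outputDir) before jobNGram.waitForCompletion(true).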

 

2. Next, the custom output class, DBOutputWritable, which defines how each record is written to the database.

 

package com.dingheng;

import org.apache.hadoop.mapreduce.lib.db.DBWritable;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

// One row of the language model: (starting_phrase, following_word, count).
public class DBOutputWritable implements DBWritable {

    private String starting_phrase;
    private String following_word;
    private int count;

    public DBOutputWritable(String starting_phrase, String following_word, int count) {
        this.starting_phrase = starting_phrase;
        this.following_word = following_word;
        this.count = count;
    }

    // Fill the PreparedStatement that DBOutputFormat uses for the INSERT.
    public void write(PreparedStatement statement) throws SQLException {
        statement.setString(1, starting_phrase);
        statement.setString(2, following_word);
        statement.setInt(3, count);
    }

    // Populate the fields when reading a row back (not used by this job).
    public void readFields(ResultSet resultSet) throws SQLException {
        this.starting_phrase = resultSet.getString(1);
        this.following_word = resultSet.getString(2);
        this.count = resultSet.getInt(3);
    }
}
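DBOutputFormat only inserts into an existing table, so the output table referenced in the driver has to be created beforehand. Here is a minimal sketch, assuming the same local test database and credentials used in the driver; the class name and the column types are my guesses, not from the original post.

package com.dingheng;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;

// One-off helper to create the table that DBOutputFormat writes into.
public class CreateOutputTable {
    public static void main(String[] args) throws Exception {
        try (Connection conn = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test", "root", "123456");
             Statement stmt = conn.createStatement()) {
            stmt.executeUpdate(
                "CREATE TABLE IF NOT EXISTS output ("
                + "starting_phrase VARCHAR(250), "
                + "following_word VARCHAR(250), "
                + "count INT)");
        }
    }
}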

 

3. Finally, the mappers and reducers. I used two chained MapReduce jobs, each written in its own file: NGram.java counts the n-grams, and LanguageModel.java turns those counts into the model.

package com.dingheng;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;

public class NGram {

    public static class NGramMapper extends Mapper<LongWritable, Text, Text, IntWritable> {

        int numOfGram;

        @Override
        public void setup(Context context) {
            Configuration conf = context.getConfiguration();
            numOfGram = conf.getInt("numOfGram", 5);
        }

        @Override
        public void map(LongWritable key,
                        Text value,
                        Context context) throws IOException, InterruptedException {
            /*
            input: read sentence
            I love data, n=3
            I love -> 1
            love data -> 1
            I love data -> 1
            */

            String line = value.toString().trim().toLowerCase().replaceAll("[^a-z]", " ");
            String[] words = line.split("\\s+");

            if (words.length < 2) {
                return;
            }

            // Emit every n-gram of length 2..numOfGram starting at position i.
            StringBuilder sb;
            for (int i = 0; i < words.length; i++) {
                sb = new StringBuilder();
                sb.append(words[i]);
                for (int j = 1; i + j < words.length && j < numOfGram; j++) {
                    sb.append(" ");
                    sb.append(words[i + j]);
                    context.write(new Text(sb.toString()), new IntWritable(1));
                }
            }
        }
    }

    public static class NGramReducer extends Reducer<Text, IntWritable, Text, IntWritable> {

        @Override
        public void reduce(Text key,
                           Iterable<IntWritable> values,
                           Context context) throws IOException, InterruptedException {
            // Sum the counts of each distinct n-gram.
            int sum = 0;
            for (IntWritable value: values) {
                sum = sum + value.get();
            }
            context.write(key, new IntWritable(sum));
        }
    }
}
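To sanity-check the mapper logic without a cluster, here is a small standalone sketch (my own test scaffold, not part of the original job) that runs the same sliding-window loop over a single sentence and prints what the mapper would emit:

package com.dingheng;

// Expected output for "I love big data" with numOfGram = 3:
// "i love", "i love big", "love big", "love big data", "big data" (each with count 1)
public class NGramDemo {
    public static void main(String[] args) {
        String sentence = "I love big data";
        int numOfGram = 3;

        String line = sentence.trim().toLowerCase().replaceAll("[^a-z]", " ");
        String[] words = line.split("\\s+");

        for (int i = 0; i < words.length; i++) {
            StringBuilder sb = new StringBuilder(words[i]);
            for (int j = 1; i + j < words.length && j < numOfGram; j++) {
                sb.append(" ").append(words[i + j]);
                System.out.println(sb + "\t1");
            }
        }
    }
}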
The second job, LanguageModel.java, reads the n-gram counts produced above and, for each starting phrase, keeps the top-K most frequent following words:
package com.dingheng;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;
import java.util.*;


public class LanguageModel {

    public static class Map extends Mapper<LongWritable, Text, Text, Text> {

        // input:  "I love big data\t10"
        // output: key: "I love big"   value: "data=10"

        int threshold;

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration configuration = context.getConfiguration();
            threshold = configuration.getInt("threshold", 20);
        }

        @Override
        public void map(LongWritable key,
                        Text value,
                        Context context) throws IOException, InterruptedException {

            if ((value == null) || (value.toString().trim().length() == 0)) {
                return;
            }

            String line = value.toString().trim();

            String[] wordsPlusCount = line.split("\t");
            if (wordsPlusCount.length < 2) {
                return;
            }

            String[] words = wordsPlusCount[0].split("\\s+");
            int count = Integer.valueOf(wordsPlusCount[wordsPlusCount.length - 1]);

            // Drop rare n-grams below the threshold.
            if (count < threshold) {
                return;
            }

            // Everything except the last word is the starting phrase.
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < words.length - 1; i++) {
                sb.append(words[i]);
                sb.append(" ");
            }

            String outputKey = sb.toString().trim();
            String outputValue = words[words.length - 1];
            if (!(outputKey.length() < 1)) {
                context.write(new Text(outputKey), new Text(outputValue + "=" + count));
            }
        }
    }

    public static class Reduce extends Reducer<Text, Text, DBOutputWritable, NullWritable> {

        int topK;

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration configuration = context.getConfiguration();
            topK = configuration.getInt("topK", 5);
        }

        @Override
        public void reduce(Text key,
                           Iterable<Text> values,
                           Context context) throws IOException, InterruptedException {
            // key:    "I love big"
            // values: <data=10, girl=100, boy=1000 ...>
            // Group the following words by count, in descending order of count:
            // <1000, <boy>>, <100, <girl>>, <10, <data, baby...>>
            TreeMap<Integer, List<String>> tm = new TreeMap<Integer, List<String>>(Collections.<Integer>reverseOrder());

            for (Text val : values) {
                // val: "data=10"
                String value = val.toString().trim();
                String word = value.split("=")[0].trim();
                int count = Integer.parseInt(value.split("=")[1].trim());

                if (tm.containsKey(count)) {
                    tm.get(count).add(word);
                } else {
                    List<String> list = new ArrayList<String>();
                    list.add(word);
                    tm.put(count, list);
                }
            }

            // Emit at most topK (starting phrase, following word, count) rows per key.
            Iterator<Integer> iter = tm.keySet().iterator();
            int j = 0;
            while (iter.hasNext() && j < topK) {
                int keyCount = iter.next();
                for (String curWord : tm.get(keyCount)) {
                    if (j >= topK) {
                        break;
                    }
                    context.write(new DBOutputWritable(key.toString(), curWord, keyCount), NullWritable.get());
                    j++;
                }
            }
        }
    }
}
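Once both jobs have finished, the model can be queried for next-word suggestions straight from the output table. Here is a minimal sketch, again assuming the local test database and credentials from the driver; the class name Suggest and the LIMIT of 5 are illustrative, not from the original post.

package com.dingheng;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;

// Look up the most likely following words for a given starting phrase.
public class Suggest {
    public static void main(String[] args) throws Exception {
        String phrase = args.length > 0 ? args[0] : "i love";
        try (Connection conn = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test", "root", "123456");
             PreparedStatement ps = conn.prepareStatement(
                 "SELECT following_word, count FROM output "
                 + "WHERE starting_phrase = ? ORDER BY count DESC LIMIT 5")) {
            ps.setString(1, phrase);
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    System.out.println(rs.getString("following_word") + "\t" + rs.getInt("count"));
                }
            }
        }
    }
}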

 
