Implementing an n-gram Language Model with MapReduce
Today, datasets routinely grow beyond what any single machine can handle, in both CPU power and memory capacity. Processing them requires distributed computing: pooling the resources of many machines to run the workload, whether offline batch processing or online real-time processing.
In the last meeting we went through the evolution of language models, from rule-based approaches to the later neural network language models (NNLM). The goal of this chapter is hands-on practice: with the principles understood, implement an n-gram language model ourselves using the MapReduce paradigm.
First, Maven is used to manage the package dependencies.
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.dingheng</groupId>
    <artifactId>nragmMR</artifactId>
    <version>1.0-SNAPSHOT</version>

    <packaging>jar</packaging>

    <dependencies>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.7.2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-core</artifactId>
            <version>1.2.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>2.7.2</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.12</version>
        </dependency>
    </dependencies>
</project>
Now straight to the code.
1. First the driver, which is the program's entry point.
package com.dingheng;

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.db.DBConfiguration;
import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;

public class Driver {

    public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {

        // args: inputDir outputDir numOfGram threshold topK
        String inputDir = args[0];
        String outputDir = args[1];
        String numOfGram = args[2];
        String threshold = args[3];
        String topK = args[4];

        // first MapReduce job: count raw n-grams
        Configuration configurationNGram = new Configuration();
        // split the input into records on "." so each sentence is one record
        configurationNGram.set("textinputformat.record.delimiter", ".");
        configurationNGram.set("numOfGram", numOfGram);

        Job jobNGram = Job.getInstance(configurationNGram);
        jobNGram.setJobName("NGram");
        jobNGram.setJarByClass(Driver.class);

        jobNGram.setMapperClass(NGram.NGramMapper.class);
        jobNGram.setReducerClass(NGram.NGramReducer.class);

        jobNGram.setOutputKeyClass(Text.class);
        jobNGram.setOutputValueClass(IntWritable.class);

        jobNGram.setInputFormatClass(TextInputFormat.class);
        jobNGram.setOutputFormatClass(TextOutputFormat.class);

        TextInputFormat.addInputPath(jobNGram, new Path(inputDir));
        TextOutputFormat.setOutputPath(jobNGram, new Path(outputDir));
        jobNGram.waitForCompletion(true);

        // second MapReduce job: build the language model and write it to MySQL
        Configuration configurationLanguage = new Configuration();
        configurationLanguage.set("threshold", threshold);
        configurationLanguage.set("topK", topK);

        DBConfiguration.configureDB(configurationLanguage,
                "com.mysql.jdbc.Driver",
                "jdbc:mysql://localhost:3306/test",
                "root",
                "123456");

        Job jobLanguage = Job.getInstance(configurationLanguage);
        jobLanguage.setJobName("LanguageModel");
        jobLanguage.setJarByClass(Driver.class);

        jobLanguage.setMapperClass(LanguageModel.Map.class);
        jobLanguage.setReducerClass(LanguageModel.Reduce.class);

        jobLanguage.setMapOutputKeyClass(Text.class);
        jobLanguage.setMapOutputValueClass(Text.class);
        jobLanguage.setOutputKeyClass(DBOutputWritable.class);
        jobLanguage.setOutputValueClass(NullWritable.class);

        jobLanguage.setInputFormatClass(TextInputFormat.class);
        jobLanguage.setOutputFormatClass(DBOutputFormat.class);

        DBOutputFormat.setOutput(
                jobLanguage,
                "output",
                new String[] { "starting_phrase", "following_word", "count" });

        // the second job reads the output directory of the first job
        TextInputFormat.setInputPaths(jobLanguage, new Path(args[1]));

        jobLanguage.waitForCompletion(true);
    }
}
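Note that DBOutputFormat only inserts rows; it does not create the table. The output table in the test database therefore has to exist before the second job runs. Below is a minimal one-off sketch that creates it over JDBC, assuming the same connection settings hard-coded in Driver; the class name and column types are my own choice.

package com.dingheng;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;

// One-off helper to create the MySQL table the second job writes into.
// Connection settings mirror the ones hard-coded in Driver; the column
// types are an assumption, adjust them to fit your data.
public class CreateOutputTable {

    public static void main(String[] args) throws Exception {
        try (Connection conn = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test", "root", "123456");
             Statement stmt = conn.createStatement()) {
            stmt.executeUpdate(
                    "CREATE TABLE IF NOT EXISTS output (" +
                    "  starting_phrase VARCHAR(250), " +
                    "  following_word  VARCHAR(250), " +
                    "  `count`         INT)");
        }
    }
}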
2. Next is the custom output class, which defines how a result row is written to the database.
package com.dingheng;

import org.apache.hadoop.mapreduce.lib.db.DBWritable;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

public class DBOutputWritable implements DBWritable {

    private String starting_phrase;
    private String following_word;
    private int count;

    public DBOutputWritable(String starting_phrase, String following_word, int count) {
        this.starting_phrase = starting_phrase;
        this.following_word = following_word;
        this.count = count;
    }

    public void write(PreparedStatement arg0) throws SQLException {
        arg0.setString(1, starting_phrase);
        arg0.setString(2, following_word);
        arg0.setInt(3, count);
    }

    public void readFields(ResultSet arg0) throws SQLException {
        this.starting_phrase = arg0.getString(1);
        this.following_word = arg0.getString(2);
        this.count = arg0.getInt(3);
    }
}
3. Finally, the mappers and reducers. I use two MapReduce passes, each written in its own file.
package com.dingheng;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;

public class NGram {

    public static class NGramMapper extends Mapper<LongWritable, Text, Text, IntWritable> {

        int numOfGram;

        @Override
        public void setup(Context context) {
            Configuration conf = context.getConfiguration();
            numOfGram = conf.getInt("numOfGram", 5);
        }

        @Override
        public void map(LongWritable key,
                        Text value,
                        Context context) throws IOException, InterruptedException {
            /*
               input: one sentence, e.g. "I love data", n = 3
               output:
                 I love      -> 1
                 I love data -> 1
                 love data   -> 1
            */

            // keep letters only, then trim so no empty tokens survive the split
            String line = value.toString().toLowerCase().replaceAll("[^a-z]+", " ").trim();
            String[] words = line.split("\\s+");

            if (words.length < 2) {
                return;
            }

            // emit every 2-gram up to numOfGram-gram starting at position i
            StringBuilder sb;
            for (int i = 0; i < words.length; i++) {
                sb = new StringBuilder();
                sb.append(words[i]);
                for (int j = 1; i + j < words.length && j < numOfGram; j++) {
                    sb.append(" ");
                    sb.append(words[i + j]);
                    context.write(new Text(sb.toString()), new IntWritable(1));
                }
            }
        }
    }

    public static class NGramReducer extends Reducer<Text, IntWritable, Text, IntWritable> {

        @Override
        public void reduce(Text key,
                           Iterable<IntWritable> values,
                           Context context) throws IOException, InterruptedException {
            // sum the counts of the same n-gram across all mappers
            int sum = 0;
            for (IntWritable value : values) {
                sum = sum + value.get();
            }
            context.write(key, new IntWritable(sum));
        }
    }
}
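To sanity-check what the mapper emits without starting a Hadoop job, the same n-gram loop can be run on a single sample sentence in a plain main method. This is only a local sketch for illustration (the class name and sentence are made up), not part of the job:

package com.dingheng;

import java.util.ArrayList;
import java.util.List;

// Quick local check of the n-gram loop used in NGramMapper.
public class NGramLoopDemo {

    public static void main(String[] args) {
        int numOfGram = 3;
        String line = "I love big data".toLowerCase().replaceAll("[^a-z]+", " ").trim();
        String[] words = line.split("\\s+");

        List<String> grams = new ArrayList<String>();
        for (int i = 0; i < words.length; i++) {
            StringBuilder sb = new StringBuilder(words[i]);
            for (int j = 1; i + j < words.length && j < numOfGram; j++) {
                sb.append(" ").append(words[i + j]);
                grams.add(sb.toString()); // each of these would be emitted with count 1
            }
        }
        // prints: [i love, i love big, love big, love big data, big data]
        System.out.println(grams);
    }
}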
package com.dingheng;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;
import java.util.*;

public class LanguageModel {

    public static class Map extends Mapper<LongWritable, Text, Text, Text> {

        // input:  "i love big data\t10"  (output of the first job)
        // output: key: "i love big"  value: "data=10"

        int threshold;

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration configuration = context.getConfiguration();
            threshold = configuration.getInt("threshold", 20);
        }

        @Override
        public void map(LongWritable key,
                        Text value,
                        Context context) throws IOException, InterruptedException {

            if ((value == null) || (value.toString().trim().length() == 0)) {
                return;
            }

            String line = value.toString().trim();

            String[] wordsPlusCount = line.split("\t");
            if (wordsPlusCount.length < 2) {
                return;
            }

            String[] words = wordsPlusCount[0].split("\\s+");
            int count = Integer.valueOf(wordsPlusCount[wordsPlusCount.length - 1]);

            // drop rare n-grams
            if (count < threshold) {
                return;
            }

            // everything except the last word becomes the starting phrase
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < words.length - 1; i++) {
                sb.append(words[i]);
                sb.append(" ");
            }

            String outputKey = sb.toString().trim();
            String outputValue = words[words.length - 1];
            if (!(outputKey.length() < 1)) {
                context.write(new Text(outputKey), new Text(outputValue + "=" + count));
            }
        }
    }

    public static class Reduce extends Reducer<Text, Text, DBOutputWritable, NullWritable> {

        int topK;

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration configuration = context.getConfiguration();
            topK = configuration.getInt("topK", 5);
        }

        @Override
        public void reduce(Text key,
                           Iterable<Text> values,
                           Context context) throws IOException, InterruptedException {
            // key:    "i love big"
            // values: <"data=10", "girl=100", "boy=1000", ...>
            // group the following words by count, highest count first
            TreeMap<Integer, List<String>> tm = new TreeMap<Integer, List<String>>(Collections.<Integer>reverseOrder());
            // e.g. <1000, <boy>>, <100, <girl>>, <10, <data, baby, ...>>

            for (Text val : values) {
                // val: "data=10"
                String value = val.toString().trim();
                String word = value.split("=")[0].trim();
                int count = Integer.parseInt(value.split("=")[1].trim());

                if (tm.containsKey(count)) {
                    tm.get(count).add(word);
                } else {
                    List<String> list = new ArrayList<String>();
                    list.add(word);
                    tm.put(count, list);
                }
            }

            // write the topK most frequent following words to MySQL
            Iterator<Integer> iter = tm.keySet().iterator();
            for (int j = 0; iter.hasNext() && j < topK; ) {
                int keyCount = iter.next();
                List<String> words = tm.get(keyCount);
                for (String curWord : words) {
                    context.write(new DBOutputWritable(key.toString(), curWord, keyCount), NullWritable.get());
                    j++;
                }
            }
        }
    }
}
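Once both jobs have finished, next-word prediction is just a lookup in the output table. Below is a small sketch of querying the top suggestions for a given starting phrase over JDBC; the class name and sample phrase are illustrative, and the connection settings mirror the ones in Driver:

package com.dingheng;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;

// Look up the most likely following words for a given starting phrase.
public class PredictNextWord {

    public static void main(String[] args) throws Exception {
        String phrase = "i love"; // sample prefix; the first job lowercases everything
        try (Connection conn = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/test", "root", "123456");
             PreparedStatement stmt = conn.prepareStatement(
                     "SELECT following_word, `count` FROM output " +
                     "WHERE starting_phrase = ? ORDER BY `count` DESC LIMIT 5")) {
            stmt.setString(1, phrase);
            try (ResultSet rs = stmt.executeQuery()) {
                while (rs.next()) {
                    System.out.println(rs.getString("following_word") + "\t" + rs.getInt("count"));
                }
            }
        }
    }
}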