使用 DL4J 训练中文词向量
使用 DL4J 训练中文词向量
1 预处理
对中文语料的预处理,主要包括:分词、去停用词以及一些根据实际场景制定的规则。
package ai.mole.test;
import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;
import org.nlpcn.commons.lang.tire.domain.Forest;
import org.nlpcn.commons.lang.tire.library.Library;
import java.io.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;
/**
 * Turns a tab-separated raw corpus into one line of space-separated tokens per record.
 *
 * <p>For every input line the first two tab-separated columns are lower-cased, joined with a
 * space and segmented with Ansj using a user dictionary loaded from the classpath. Tokens are
 * dropped when they are shorter than two characters, are listed in the stop-word file, consist
 * only of digits/dots, or consist only of lower-case ASCII letters. A record is written only
 * when more than five tokens survive the filtering.
 */
public class Preprocess {
    /** Matches tokens made up solely of digits and dots, e.g. {@code 3.14}. */
    private static final Pattern NUMERIC_PATTERN = Pattern.compile("^[.\\d]+$");
    /** Matches tokens made up solely of lower-case ASCII letters. */
    private static final Pattern ENGLISH_WORD_PATTERN = Pattern.compile("^[a-z]+$");
    /** Classpath location of the user dictionary handed to the segmenter. */
    private static final String USER_LIBRARY_RESOURCE = "/library/userLibrary.dic";

    /**
     * Runs the preprocessing with the hard-coded input, stop-word and output paths.
     *
     * <p>Terminates the JVM with exit code -1 when one of the input files, the output file or
     * the dictionary resource cannot be opened.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String inPath1 = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\test1.txt";
        String inPath2 = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\stop_words.txt";
        String outPath = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\result1.txt";
        String encoding = "utf-8";
        // try-with-resources closes every stream before any catch clause runs, so the
        // System.exit below can no longer skip the cleanup.
        try (InputStream dictIn = Preprocess.class.getResourceAsStream(USER_LIBRARY_RESOURCE);
             InputStream corpusIn = new FileInputStream(inPath1);
             InputStream stopWordIn = new FileInputStream(inPath2)) {
            if (dictIn == null) {
                // getResourceAsStream reports a missing resource by returning null.
                throw new FileNotFoundException(USER_LIBRARY_RESOURCE);
            }
            Forest forest = Library.makeForest(dictIn);
            List<String> lineList = IOUtil.readLines(corpusIn, encoding);
            // A hash set makes the per-token stop-word lookup constant time.
            Set<String> stopWords = new HashSet<>(IOUtil.readLines(stopWordIn, encoding));
            // The output is opened only after all inputs were read, so a missing input
            // does not leave behind a truncated result file.
            try (OutputStream out = new FileOutputStream(outPath);
                 PrintWriter writer = new PrintWriter(new OutputStreamWriter(out, encoding))) {
                for (String line : lineList) {
                    String[] cols = line.split("\\t", -1);
                    if (cols.length < 2) {
                        continue;
                    }
                    // Locale.ROOT keeps lower-casing independent of the machine's default
                    // locale; ENGLISH_WORD_PATTERN relies on plain ASCII a-z.
                    String text = cols[0].trim().toLowerCase(Locale.ROOT) + " "
                            + cols[1].trim().toLowerCase(Locale.ROOT);
                    // Word segmentation
                    List<Term> termList = ToAnalysis.parse(text, forest).getTerms();
                    List<String> wordList = filterWords(termList, stopWords);
                    if (wordList.size() > 5) {
                        writer.println(listToLine(wordList));
                    }
                }
                // PrintWriter never throws on write failures; surface them explicitly.
                if (writer.checkError()) {
                    throw new IOException("Failed to write " + outPath);
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println("The file does not exist or the path is not correct!!!");
            System.exit(-1);
        } catch (UnsupportedEncodingException e) {
            System.out.println("Does not support the current character set!!!");
        } catch (IOException e) {
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Returns the names of the terms that pass all token filters, in their original order. */
    private static List<String> filterWords(List<Term> termList, Set<String> stopWords) {
        List<String> wordList = new ArrayList<>();
        for (Term term : termList) {
            String word = term.getName();
            if (word.length() < 2) {
                continue;
            }
            if (stopWords.contains(word)) {
                continue;
            }
            if (isNumeric(word)) {
                continue;
            }
            if (isEnglishWord(word)) {
                continue;
            }
            wordList.add(word);
        }
        return wordList;
    }

    /** Returns {@code true} when {@code text} consists only of digits and dots. */
    private static boolean isNumeric(String text) {
        return NUMERIC_PATTERN.matcher(text).matches();
    }

    /** Returns {@code true} when {@code text} consists only of lower-case ASCII letters. */
    private static boolean isEnglishWord(String text) {
        return ENGLISH_WORD_PATTERN.matcher(text).matches();
    }

    /** Joins the elements of {@code list} with single spaces; an empty list yields "". */
    private static String listToLine(List<String> list) {
        StringBuilder sb = new StringBuilder();
        // Iterating instead of calling get(i) keeps this linear for any List implementation.
        for (String word : list) {
            if (sb.length() > 0) {
                sb.append(" ");
            }
            sb.append(word);
        }
        return sb.toString();
    }
}
2 训练
训练的代码非常简单,可以直接参考 DL4J 官网的 Word2Vec 教程;至于 word2vec 的原理,可以阅读皮果提的博文。
package ai.mole.test;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
/**
 * Trains a DL4J {@link Word2Vec} model on a pre-tokenized corpus and stores the word vectors.
 *
 * <p>The corpus file is read line by line and split into tokens with
 * {@link DefaultTokenizerFactory}. After training, the vectors are written to a text file and
 * the ten nearest neighbours of two sample words are printed as a quick sanity check.
 */
public class TrainWord2VecModel {
    private static final Logger log = LoggerFactory.getLogger(TrainWord2VecModel.class);

    /**
     * Trains the model with the hard-coded corpus and output paths.
     *
     * @param args ignored
     * @throws IOException if the corpus cannot be read or the vectors cannot be written
     */
    public static void main(String[] args) throws IOException {
        String corpusPath = "/data/analyze/xgp/words.txt";
        String vectorsPath = "/data/analyze/xgp/word_vectors.txt";

        log.info("Start Training...");
        long st = System.currentTimeMillis();

        log.info("Load & vectorize sentences...");
        SentenceIterator iter = new BasicLineIterator(new File(corpusPath));
        TokenizerFactory t = new DefaultTokenizerFactory();
        // No token pre-processor is installed; the line below shows how one would be set.
        // t.setTokenPreProcessor(new CommonPreprocessor());

        log.info("Building model...");
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(50)
                .iterations(1)
                .epochs(100)
                .layerSize(500)
                .seed(42)
                .windowSize(5)
                .iterate(iter)
                .tokenizerFactory(t)
                .build();

        log.info("Fitting word2vec model...");
        vec.fit();

        log.info("Writing word vectors to text file...");
        // Alternative serializer call, kept for reference:
        // WordVectorSerializer.writeWord2VecModel(vec, vectorsPath);
        WordVectorSerializer.writeWordVectors(vec, vectorsPath);

        log.info("Closest words:");
        Collection<String> bydWordList = vec.wordsNearest("比亚迪", 10);
        Collection<String> changanWordList = vec.wordsNearest("长安", 10);
        // println for both lists so each one ends up on its own console line.
        System.out.println(bydWordList);
        System.out.println(changanWordList);
        log.info("10 words closest to '比亚迪': {}", bydWordList);
        log.info("10 words closest to '长安': {}", changanWordList);

        long elapsedMs = System.currentTimeMillis() - st;
        log.info("Training is completed, and the time taken is {} ms.", elapsedMs);
        System.out.println("Training is completed, and the time taken is " + elapsedMs + " ms.");
    }
}
3 调用
调用训练好的词向量也非常简单:只需要调用 `WordVectorSerializer` 类的静态方法 `readWord2VecModel` 就可以了,输入参数就是训练阶段保存的词向量文件路径。
// Load the previously trained word vectors from disk.
String vectorsPath = "D:\\MyData\\XUGP3\\Desktop\\测试分词\\vectors.txt";
Word2Vec model = WordVectorSerializer.readWord2VecModel(vectorsPath);

// Query the 10 nearest neighbours of each sample word.
Collection<String> nearestToByd = model.wordsNearest("比亚迪", 10);
Collection<String> nearestToChangan = model.wordsNearest("长安", 10);

// Print one neighbour list per line.
System.out.println(nearestToByd);
System.out.println(nearestToChangan);
附录 - Maven 依赖
<dependencies>
    <!-- Ansj segmenter: provides the org.ansj.* classes imported by the preprocessing code.
         NOTE(review): the version is an assumption; confirm it against the release you build with
         and check that org.nlpcn.commons.lang.* (nlp-lang) is resolved transitively. -->
    <dependency>
        <groupId>org.ansj</groupId>
        <artifactId>ansj_seg</artifactId>
        <version>5.1.6</version>
    </dependency>
    <!-- "word" segmenter from org.apdplat; kept from the original dependency list. -->
    <dependency>
        <groupId>org.apdplat</groupId>
        <artifactId>word</artifactId>
        <version>1.3</version>
    </dependency>
    <!-- ND4J backend. You need one in every DL4J project. Normally define artifactId as either "nd4j-native-platform" or "nd4j-cuda-7.5-platform" -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>${nd4j.backend}</artifactId>
        <version>${nd4j.version}</version>
    </dependency>
    <!-- Core DL4J functionality -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- NLP module: contains the Word2Vec and WordVectorSerializer classes used for training and loading -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nlp</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-zoo</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- deeplearning4j-ui is used for visualization: see http://deeplearning4j.org/visualization -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- ParallelWrapper & ParallelInference live here -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-parallel-wrapper_${scala.binary.version}</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <!-- Next 2: used for MapFileConversion Example. Note you need *both* together -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-hadoop</artifactId>
        <version>${datavec.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-common</artifactId>
        <version>${hadoop.version}</version>
    </dependency>
    <!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-deeplearning4j</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>arbiter-ui_2.11</artifactId>
        <version>${arbiter.version}</version>
    </dependency>
    <!-- datavec-data-codec: used only in video example for loading video data -->
    <dependency>
        <artifactId>datavec-data-codec</artifactId>
        <groupId>org.datavec</groupId>
        <version>${datavec.version}</version>
    </dependency>
</dependencies>