使用贝叶斯做英文拼写检查(c#)
贝叶斯算法可以用来做拼写检查、文本分类、垃圾邮件过滤等工作。前面我们用贝叶斯做了文本分类,这次用它来做拼写检查,参考 Peter Norvig 的文章《How to Write a Spelling Corrector》。
拼写检查器的原理
给定一个单词, 我们的任务是选择和它最相似的拼写正确的单词.
对应的贝叶斯问题就是:给定一个词 w,在所有拼写正确的词中,找一个词 c,使得条件概率 P(c|w) 最大,也就是:
$$\arg\max_{c} P(c \mid w)$$
根据贝叶斯定理,上面的式子等价于:
$$\arg\max_{c} \frac{P(w \mid c)\, P(c)}{P(w)}$$
因为用户可以输错任何词, 因此对于任何 c 来讲, 出现 w 的概率 P(w) 都是一样的, 从而我们在上式中忽略它, 写成:
$$\arg\max_{c} P(w \mid c)\, P(c)$$
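其中 P(c) 是先验概率(词 c 本身出现的可能性,用训练语料中的词频来估计),P(w|c) 是误差模型(想写 c 却写成 w 的可能性)。在实现中,P(w|c) 用编辑距离来近似:编辑距离越小,P(w|c) 越大。因此求 $\arg\max_{c} P(w \mid c)\, P(c)$ 时,先按编辑距离从小到大选出候选词,再在候选词中选 P(c) 最大的。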
编辑距离:两个词之间的编辑距离,定义为从一个词变到另一个词所需的最少操作次数,这些操作包括插入(在词中插入一个字母)、删除(删除一个字母)、交换(交换相邻两个字母)和替换(把一个字母换成另一个)。
一般情况下,编辑距离不超过 2 就已经可以覆盖大部分拼写错误。
计算先验概率P(c)
为了尽量覆盖较多的词语,首先从词典中读入常见的英文单词
从 en-US.dic 读取词语(词语从 [Words] 段开始)
然后,用训练语料(big.txt,可在此下载)训练我们的词典,也就是语言模型:统计每个词语出现的次数来估计其概率 P(c),出现频率越高的词语越常见。
训练语料
/// <summary>
/// Trains the word-frequency dictionary from a plain-text corpus.
/// Each line is lower-cased (culture-invariantly) and apostrophes are turned into spaces.
/// Every maximal run of ASCII letters a-z then counts as one word occurrence.
/// </summary>
/// <param name="trainingFile">Path of the training corpus (e.g. big.txt).</param>
/// <param name="ht">
/// Dictionary updated in place: an unseen word is added with count 1,
/// and a word already present has its count incremented.
/// </param>
public static void TrainDic(string trainingFile, Dictionary<string, int> ht)
{
    // One word = a run of lowercase ASCII letters.
    Regex regex = new Regex(@"[a-z]+");

    // 'using' releases the file handle even if reading or matching throws.
    using (StreamReader reader = new StreamReader(trainingFile))
    {
        string sLine; // one line of the corpus
        while ((sLine = reader.ReadLine()) != null)
        {
            // ToLowerInvariant: culture-specific lowering (e.g. Turkish 'I' -> dotless 'ı')
            // would produce letters outside [a-z] and silently split or drop words.
            sLine = sLine.ToLowerInvariant().Replace("'", " ");
            foreach (Match match in regex.Matches(sLine))
            {
                string word = match.Value;
                int current;
                if (ht.TryGetValue(word, out current))
                {
                    ht[word] = current + 1;
                }
                else
                {
                    ht.Add(word, 1);
                }
            }
        }
    }
}
为了复用,可以将训练好的词典保存起来:
// Cache the trained dictionary as one "word<TAB>count" line per entry,
// so a later run can load it instead of retraining.
StringBuilder buffer = new StringBuilder();
foreach (KeyValuePair<string, int> entry in Dic)
{
    buffer.Append(entry.Key).Append('\t').Append(entry.Value).AppendLine();
}
File.WriteAllText(dicFile, buffer.ToString());
获取建议词语
我们定义优先级:编辑距离为 1 的词语优先于编辑距离为 2 的词语。
首先,找到编辑距离为1的词语
编辑距离为1的词语
/// <summary>
/// Generates every distinct string that is one edit away from <paramref name="word"/>.
/// An edit is one of:
/// deleting one character,
/// transposing two adjacent characters,
/// replacing one character with a different letter a-z,
/// or inserting one letter a-z at any position.
/// </summary>
/// <param name="word">The (possibly misspelled) input word.</param>
/// <returns>
/// The candidates in generation order (deletes, transposes, replaces, inserts), each listed once.
/// The list can contain the input itself (transposing two equal adjacent letters) and,
/// for a one-letter input, the empty string.
/// </returns>
public static List<string> GetEdits1(string word)
{
    int n = word.Length;
    var editsWords = new List<string>();

    // The HashSet makes each duplicate check O(1); List.Contains made generation quadratic
    // in the ~54n candidates, and GetEdits2 repeats this for every candidate.
    var seen = new HashSet<string>();
    Action<string> addIfNew = candidate =>
    {
        if (seen.Add(candidate))
        {
            editsWords.Add(candidate);
        }
    };

    for (int i = 0; i < n; i++) // delete one character
    {
        addIfNew(word.Substring(0, i) + word.Substring(i + 1));
    }

    for (int i = 0; i < n - 1; i++) // transpose characters i and i+1
    {
        addIfNew(word.Substring(0, i) + word[i + 1] + word[i] + word.Substring(i + 2));
    }

    for (int i = 0; i < n; i++) // replace character i with a different letter
    {
        char current = word[i];
        for (char ch = 'a'; ch <= 'z'; ch++)
        {
            if (ch != current)
            {
                addIfNew(word.Substring(0, i) + ch + word.Substring(i + 1));
            }
        }
    }

    for (int i = 0; i <= n; i++) // insert a letter before position i (i == n appends)
    {
        for (char ch = 'a'; ch <= 'z'; ch++)
        {
            addIfNew(word.Substring(0, i) + ch + word.Substring(i));
        }
    }

    return editsWords;
}
如果编辑距离为 1 的词语中没有正确的词语,就继续寻找编辑距离为 2 的词语;为了控制规模,只保留词典中存在的(即拼写正确的)词语。
获取编辑距离为2的单词
/// <summary>
/// Returns the dictionary words within edit distance 2 of <paramref name="word"/>.
/// Only words present in <c>Dic</c> are kept, which keeps the result small.
/// Each word appears once, in order of first discovery:
/// known one-edit words first, then known words reached by applying a second edit.
/// </summary>
/// <param name="word">The word to correct.</param>
/// <returns>Distinct known words reachable by one or two edits; empty if there are none.</returns>
public static List<string> GetEdits2(string word)
{
    var known = new List<string>();
    var seen = new HashSet<string>();
    List<string> edits1 = GetEdits1(word);

    // Filter to dictionary words and deduplicate. The same word is usually reachable through
    // several edit paths, and unfiltered strings would crowd real words out of the suggestions.
    foreach (string candidate in edits1)
    {
        if (Dic.ContainsKey(candidate) && seen.Add(candidate))
        {
            known.Add(candidate);
        }
    }

    foreach (string edit in edits1)
    {
        foreach (string candidate in GetEdits1(edit))
        {
            if (Dic.ContainsKey(candidate) && seen.Add(candidate))
            {
                known.Add(candidate);
            }
        }
    }

    return known;
}
最后是获取建议词语的代码:候选词按概率从大到小排序,取前 5 个。
获取建议词语
/// <summary>
/// Suggests up to five corrections for <paramref name="word"/>.
/// Known words at edit distance 1 are preferred; only if there are none are known words
/// within edit distance 2 considered.
/// Candidates are ranked by their count in <c>Dic</c> (the prior P(c)), highest first.
/// </summary>
/// <param name="word">The word to correct.</param>
/// <returns>
/// At most five distinct dictionary words, most frequent first; or a single-element list
/// containing <paramref name="word"/> itself when no known candidate exists.
/// </returns>
public static List<string> GetSuggestWords(string word)
{
    const int maxSuggestions = 5;

    List<string> candidates = GetEdits1(word).Where(w => Dic.ContainsKey(w)).Distinct().ToList();
    if (candidates.Count == 0)
    {
        // Filter and dedupe here too. GetEdits2's result is not guaranteed to be distinct or
        // limited to dictionary words, and either would pollute the top five.
        candidates = GetEdits2(word).Where(w => Dic.ContainsKey(w)).Distinct().ToList();
    }

    if (candidates.Count == 0)
    {
        // No known candidate at all: echo the input back.
        return new List<string> { word };
    }

    // OrderByDescending is a stable sort, so equal counts keep their generation order.
    return candidates.OrderByDescending(w => Dic[w]).Take(maxSuggestions).ToList();
}
测试代码
View Code
static Dictionary<string, int> Dic;
static string dicFile = "dic.txt";
static string trainingFile = "training.txt";

/// <summary>
/// Entry point. Loads the cached dictionary from dicFile if it exists.
/// Otherwise it builds the dictionary from en-US.dic plus the training corpus and caches it.
/// It then suggests corrections for each word read from the console,
/// until "exit" is entered or input ends.
/// </summary>
static void Main(string[] args)
{
    if (File.Exists(dicFile))
    {
        Console.WriteLine("加载词典中...");
        LoadDic();
        Console.WriteLine("加载词典完成");
    }
    else
    {
        Console.WriteLine("训练词典中...");
        Dic = LoadUSDic();
        TrainDic(trainingFile, Dic);

        // Cache as "word<TAB>count" lines, the format LoadDic reads back.
        StringBuilder dicBuilder = new StringBuilder();
        foreach (var item in Dic)
        {
            dicBuilder.AppendLine(item.Key + "\t" + item.Value);
        }
        File.WriteAllText(dicFile, dicBuilder.ToString());
        Console.WriteLine("训练完成...");
    }

    Console.WriteLine("请输入词语...");
    string inputWord = Console.ReadLine();
    // ReadLine returns null at end of input (Ctrl+Z / Ctrl+D or redirected stdin):
    // stop the loop instead of throwing a NullReferenceException.
    while (inputWord != null && inputWord != "exit")
    {
        if (Dic.ContainsKey(inputWord))
        {
            Console.WriteLine("你输入的词语 【" + inputWord + "】 是正确的!");
        }
        else
        {
            var suggestWords = GetSuggestWords(inputWord);
            Console.WriteLine("候选词语: ");
            foreach (var word in suggestWords)
            {
                Console.WriteLine("\t\t\t " + word);
            }
        }
        Console.WriteLine("请输入词语....");
        inputWord = Console.ReadLine();
    }
}

/// <summary>
/// Loads the cached dictionary from dicFile into Dic.
/// Each line is expected to be "word&lt;TAB&gt;count".
/// Lines that do not split into exactly two tab-separated fields (including empty lines) are skipped.
/// </summary>
public static void LoadDic()
{
    Dic = new Dictionary<string, int>();
    foreach (string line in File.ReadAllLines(dicFile))
    {
        string[] dicItem = line.Split('\t');
        if (dicItem.Length != 2)
        {
            continue;
        }
        // Indexer instead of Add: a repeated word (e.g. in a hand-edited file) overwrites
        // the earlier value instead of aborting the whole load.
        Dic[dicItem[0]] = int.Parse(dicItem[1]);
    }
}
运行效果
完整代码
完整代码
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Collections;
using System.IO;
using System.Text.RegularExpressions;
using System.Diagnostics;

namespace SpellCheck
{
    class Program
    {
        /// <summary>
        /// Word -> count. Words from en-US.dic start at 1, and each training occurrence adds 1.
        /// Used as the prior P(c) when ranking suggestions.
        /// </summary>
        static Dictionary<string, int> Dic;
        static string dicFile = "dic.txt";
        static string trainingFile = "training.txt";

        /// <summary>Maximum number of suggestions returned by GetSuggestWords.</summary>
        private const int MaxSuggestions = 5;

        /// <summary>
        /// Entry point. Loads the cached dictionary if it exists.
        /// Otherwise it builds the dictionary from en-US.dic plus the training corpus and caches it.
        /// It then suggests corrections for each word read from the console,
        /// until "exit" is entered or input ends.
        /// </summary>
        static void Main(string[] args)
        {
            if (File.Exists(dicFile))
            {
                Console.WriteLine("加载词典中...");
                LoadDic();
                Console.WriteLine("加载词典完成");
            }
            else
            {
                Console.WriteLine("训练词典中...");
                Dic = LoadUSDic();
                TrainDic(trainingFile, Dic);

                // Cache as "word<TAB>count" lines, the format LoadDic reads back.
                StringBuilder dicBuilder = new StringBuilder();
                foreach (var item in Dic)
                {
                    dicBuilder.AppendLine(item.Key + "\t" + item.Value);
                }
                File.WriteAllText(dicFile, dicBuilder.ToString());
                Console.WriteLine("训练完成...");
            }

            Console.WriteLine("请输入词语...");
            string inputWord = Console.ReadLine();
            // ReadLine returns null at end of input; stop instead of throwing.
            while (inputWord != null && inputWord != "exit")
            {
                if (Dic.ContainsKey(inputWord))
                {
                    Console.WriteLine("你输入的词语 【" + inputWord + "】 是正确的!");
                }
                else
                {
                    var suggestWords = GetSuggestWords(inputWord);
                    Console.WriteLine("候选词语: ");
                    foreach (var word in suggestWords)
                    {
                        Console.WriteLine("\t\t\t " + word);
                    }
                }
                Console.WriteLine("请输入词语....");
                inputWord = Console.ReadLine();
            }
        }

        /// <summary>
        /// Loads the cached dictionary from dicFile into Dic ("word&lt;TAB&gt;count" per line).
        /// Lines that do not split into exactly two tab-separated fields are skipped.
        /// </summary>
        public static void LoadDic()
        {
            Dic = new Dictionary<string, int>();
            foreach (string line in File.ReadAllLines(dicFile))
            {
                string[] dicItem = line.Split('\t');
                if (dicItem.Length != 2)
                {
                    continue;
                }
                // Indexer: a repeated word overwrites instead of aborting the load.
                Dic[dicItem[0]] = int.Parse(dicItem[1]);
            }
        }

        /// <summary>
        /// Trains the word-frequency dictionary from a plain-text corpus.
        /// Each line is lower-cased (culture-invariantly) and apostrophes become spaces.
        /// Every run of letters a-z then counts as one occurrence.
        /// </summary>
        /// <param name="trainingFile">Path of the training corpus.</param>
        /// <param name="ht">Dictionary updated in place (new word -> 1, known word -> +1).</param>
        public static void TrainDic(string trainingFile, Dictionary<string, int> ht)
        {
            Regex regex = new Regex(@"[a-z]+");

            // 'using' releases the file even if reading throws.
            using (StreamReader reader = new StreamReader(trainingFile))
            {
                string sLine;
                while ((sLine = reader.ReadLine()) != null)
                {
                    // Culture-specific ToLower (e.g. Turkish 'I' -> 'ı') would produce letters
                    // the [a-z] pattern cannot match.
                    sLine = sLine.ToLowerInvariant().Replace("'", " ");
                    foreach (Match match in regex.Matches(sLine))
                    {
                        string word = match.Value;
                        int current;
                        if (ht.TryGetValue(word, out current))
                        {
                            ht[word] = current + 1;
                        }
                        else
                        {
                            ht.Add(word, 1);
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Reads the word list from en-US.dic.
        /// Every non-empty line inside the "[Words]" section is an entry, and the text before its
        /// first '/' is kept as the word with an initial count of 1.
        /// </summary>
        /// <returns>A new dictionary mapping each listed word to 1.</returns>
        public static Dictionary<string, int> LoadUSDic()
        {
            var dic = new Dictionary<string, int>();
            bool inWordsSection = false;

            // Disposing the StreamReader also closes the underlying FileStream, even on error.
            using (StreamReader sr = new StreamReader(
                new FileStream("en-US.dic", FileMode.Open, FileAccess.Read, FileShare.Read),
                Encoding.UTF8))
            {
                string rawLine;
                while ((rawLine = sr.ReadLine()) != null)
                {
                    string tempLine = rawLine.Trim();
                    if (tempLine.Length == 0)
                    {
                        continue;
                    }

                    // Any "[...]" line is a section header. Tracking every header (not just
                    // "[Words]") keeps a later section's header and entries out of the word list.
                    if (tempLine[0] == '[' && tempLine[tempLine.Length - 1] == ']')
                    {
                        inWordsSection = tempLine == "[Words]";
                        continue;
                    }

                    if (!inWordsSection)
                    {
                        continue;
                    }

                    string entry = tempLine.Split('/')[0];
                    if (entry.Length > 0)
                    {
                        // Indexer: a word listed twice must not abort loading.
                        dic[entry] = 1;
                    }
                }
            }
            return dic;
        }

        /// <summary>
        /// Generates every distinct string one edit away from <paramref name="word"/>.
        /// An edit is a delete, an adjacent transpose, a replace with a different letter a-z,
        /// or an insert of a letter a-z.
        /// </summary>
        /// <returns>Candidates in generation order (deletes, transposes, replaces, inserts), each once.</returns>
        public static List<string> GetEdits1(string word)
        {
            int n = word.Length;
            var editsWords = new List<string>();

            // O(1) duplicate checks; List.Contains made generation quadratic.
            var seen = new HashSet<string>();
            Action<string> addIfNew = candidate =>
            {
                if (seen.Add(candidate))
                {
                    editsWords.Add(candidate);
                }
            };

            for (int i = 0; i < n; i++) // delete
            {
                addIfNew(word.Substring(0, i) + word.Substring(i + 1));
            }

            for (int i = 0; i < n - 1; i++) // transpose i and i+1
            {
                addIfNew(word.Substring(0, i) + word[i + 1] + word[i] + word.Substring(i + 2));
            }

            for (int i = 0; i < n; i++) // replace
            {
                char current = word[i];
                for (char ch = 'a'; ch <= 'z'; ch++)
                {
                    if (ch != current)
                    {
                        addIfNew(word.Substring(0, i) + ch + word.Substring(i + 1));
                    }
                }
            }

            for (int i = 0; i <= n; i++) // insert
            {
                for (char ch = 'a'; ch <= 'z'; ch++)
                {
                    addIfNew(word.Substring(0, i) + ch + word.Substring(i));
                }
            }

            return editsWords;
        }

        /// <summary>
        /// Returns the distinct dictionary words within edit distance 2 of <paramref name="word"/>.
        /// They are listed in order of first discovery: known one-edit words first, then known
        /// words reached by a second edit.
        /// </summary>
        /// <returns>Distinct known words reachable by one or two edits; empty if none.</returns>
        public static List<string> GetEdits2(string word)
        {
            var known = new List<string>();
            var seen = new HashSet<string>();
            List<string> edits1 = GetEdits1(word);

            // Keep only dictionary words, and each once: the same word is usually reachable
            // through several edit paths.
            foreach (string candidate in edits1)
            {
                if (Dic.ContainsKey(candidate) && seen.Add(candidate))
                {
                    known.Add(candidate);
                }
            }

            foreach (string edit in edits1)
            {
                foreach (string candidate in GetEdits1(edit))
                {
                    if (Dic.ContainsKey(candidate) && seen.Add(candidate))
                    {
                        known.Add(candidate);
                    }
                }
            }

            return known;
        }

        /// <summary>
        /// Suggests up to MaxSuggestions corrections for <paramref name="word"/>.
        /// Known words at edit distance 1 are preferred over those at distance 2.
        /// Candidates are ranked by their count in Dic, highest first.
        /// </summary>
        /// <returns>Distinct dictionary words, or a one-element list with the input when none is known.</returns>
        public static List<string> GetSuggestWords(string word)
        {
            List<string> candidates = GetEdits1(word).Where(w => Dic.ContainsKey(w)).ToList();
            if (candidates.Count == 0)
            {
                // GetEdits2 already returns distinct dictionary words only.
                candidates = GetEdits2(word);
            }

            if (candidates.Count == 0)
            {
                return new List<string> { word };
            }

            // Stable sort: equal counts keep their generation order.
            return candidates.OrderByDescending(w => Dic[w]).Take(MaxSuggestions).ToList();
        }
    }
}