基于字的文本相似度算法——余弦定理

一、算法原理

  基于字的文本相似度余弦定理算法的原理是:

(1)分别统计两个比较文本中所有字出现的频率,从而得出两个文本对应的向量
(2)利用余弦定理计算这两个向量的夹角余弦值

(3)根据自设置的阈值判断两个文本是否相似

 

二、算法的C++实现

 

这里引用的StringUtil.hpp文件引自:

https://github.com/yanyiwu/cppjieba/blob/master/deps/limonp/StringUtil.hpp


 

[cpp] view plain copy
 
  1. /* 
  2.  * CosineSimilarity.hpp 
  3.  * 
  4.  *  Created: 2016年10月2日 
  5.  *   Author: tang 
  6.  */  
  7.   
  8. #ifndef SRC_COSINE_SIMILARITY_HPP_  
  9. #define SRC_COSINE_SIMILARITY_HPP_  
  10. #include <iostream>  
  11. #include <vector>  
  12. #include <map>  
  13. #include <math.h>  
  14. #include "StringUtil.hpp"  
  15.   
  16. using namespace std;  
  17.   
  18. class CosineSimilarity  
  19. {  
  20. public:  
  21.     CosineSimilarity()  
  22.     {  
  23.     }  
  24.   
  25.     double CalculateTextSimilarity(string &str1,string &str2)  
  26.     {  
  27.         vector<uint16_t> words_for_str1;  
  28.         vector<uint16_t> words_for_str2;  
  29.         vector<uint16_t>::iterator it;  
  30.   
  31.         if(!utf8ToUnicode< vector<uint16_t> >(str1,words_for_str1) ||   
  32.             !utf8ToUnicode< vector<uint16_t> >(str2,words_for_str2 ) )  
  33.         {  
  34.             cout<<"TransCode Error"<<endl;  
  35.             return 0.;  
  36.         }  
  37.   
  38.         map< uint16_t,pair<int,int> >seq_map;  
  39.         map< uint16_t,pair<int,int> >::iterator map_it;  
  40.         for(it=words_for_str1.begin();it!=words_for_str1.end();++it)  
  41.         {  
  42.             if(isHanzi(*it))  
  43.             {  
  44.                 map_it=seq_map.find(*it);  
  45.                 if(map_it!=seq_map.end())  
  46.                 {  
  47.                     map_it->second.first++;  
  48.                 }  
  49.                 else  
  50.                 {  
  51.                     pair<int,int> seq;  
  52.                     seq.first=1;  
  53.                     seq.second=0;  
  54.                     seq_map[*it]=seq;  
  55.                 }  
  56.             }  
  57.         }  
  58.   
  59.         for(it=words_for_str2.begin();it!=words_for_str2.end();++it)  
  60.                 {  
  61.             if(isHanzi(*it))  
  62.                         {  
  63.                                 map_it=seq_map.find(*it);          
  64.                                 if(map_it!=seq_map.end())  
  65.                                 {  
  66.                                         map_it->second.second++;  
  67.                                 }  
  68.                                 else  
  69.                                 {  
  70.                     pair<int,int> seq;  
  71.                                         seq.first=0;  
  72.                                         seq.second=1;  
  73.                                         seq_map[*it]=seq;  
  74.                                 }  
  75.                         }  
  76.                 }  
  77.   
  78.         double sqdoc1 = 0.;    
  79.                 double sqdoc2 = 0.;    
  80.                 double denominator = 0.;  
  81.   
  82.         for(map_it=seq_map.begin();map_it!=seq_map.end();++map_it)  
  83.         {  
  84.             pair<int,int> c=map_it->second;  
  85.             denominator +=(c.first * c.second);  
  86.             sqdoc1+=(c.first * c.first);  
  87.             sqdoc2+=(c.second * c.second);  
  88.         }   
  89.   
  90.         if(0==sqdoc1 * sqdoc2)  
  91.             return -1.0;  
  92.   
  93.         return denominator/sqrt(sqdoc1 * sqdoc2);  
  94.     }  
  95.   
  96.     bool codeFilter(int code)   
  97.     {  
  98.             if ((code < 0x4e00 || code > 0x9fa5) &&   
  99.             !(code >= '0' && code <= '9') &&   
  100.             !(code >= 'a' && code <= 'z') &&   
  101.             !(code >= 'A' && code <= 'Z'))  
  102.                  return false;  
  103.           
  104.             return true;  
  105.     }  
  106.   
  107.     bool isHanzi(uint16_t ch)  
  108.     {  
  109.         return (ch >= 0x4E00 && ch <= 0x9FA5);  
  110.     }  
  111. };  

三、算法的Java实现

 

 

[java] view plain copy
 
  1. import java.io.UnsupportedEncodingException;    
  2. import java.util.Date;    
  3. import java.util.HashMap;    
  4. import java.util.Iterator;    
  5. import java.util.Map;   
  6.   
  7.   
  8. public class CosineSimilarity{  
  9.       
  10.     /** 
  11.      * 输入两段文本利用孜频率的余弦定理判断二者间的相似度 
  12.      *  
  13.      * @param doc1,文本1 
  14.      * @param doc2,文本2 
  15.      * @return 相似度值 
  16.      */  
  17.     public double CalculateTextSim(String doc1, String doc2) {  
  18.         if (doc1 != null && doc1.trim().length() > 0 && doc2 != null  
  19.                 && doc2.trim().length() > 0) {  
  20.               
  21.             Map<Integer, int[]> AlgorithmMap = new HashMap<Integer, int[]>();  
  22.               
  23.             //将两个字符串中的中文字符以及出现的总数封装到,AlgorithmMap中  
  24.             for (int i = 0; i < doc1.length(); i++) {  
  25.                 char d1 = doc1.charAt(i);  
  26.                 if(isHanZi(d1)){  
  27.                     int charIndex = getGB2312Id(d1);  
  28.                     if(charIndex != -1){  
  29.                         int[] fq = AlgorithmMap.get(charIndex);  
  30.                         if(fq != null && fq.length == 2){  
  31.                             fq[0]++;  
  32.                         }else {  
  33.                             fq = new int[2];  
  34.                             fq[0] = 1;  
  35.                             fq[1] = 0;  
  36.                             AlgorithmMap.put(charIndex, fq);  
  37.                         }  
  38.                     }  
  39.                 }  
  40.             }  
  41.   
  42.             for (int i = 0; i < doc2.length(); i++) {  
  43.                 char d2 = doc2.charAt(i);  
  44.                 if(isHanZi(d2)){  
  45.                     int charIndex = getGB2312Id(d2);  
  46.                     if(charIndex != -1){  
  47.                         int[] fq = AlgorithmMap.get(charIndex);  
  48.                         if(fq != null && fq.length == 2){  
  49.                             fq[1]++;  
  50.                         }else {  
  51.                             fq = new int[2];  
  52.                             fq[0] = 0;  
  53.                             fq[1] = 1;  
  54.                             AlgorithmMap.put(charIndex, fq);  
  55.                         }  
  56.                     }  
  57.                 }  
  58.             }  
  59.               
  60.             Iterator<Integer> iterator = AlgorithmMap.keySet().iterator();  
  61.             double sqdoc1 = 0;  
  62.             double sqdoc2 = 0;  
  63.             double denominator = 0;   
  64.             while(iterator.hasNext()){  
  65.                 int[] c = AlgorithmMap.get(iterator.next());  
  66.                 denominator += c[0]*c[1];  
  67.                 sqdoc1 += c[0]*c[0];  
  68.                 sqdoc2 += c[1]*c[1];  
  69.             }  
  70.               
  71.             return denominator / Math.sqrt(sqdoc1*sqdoc2);  
  72.         } else {  
  73.             throw new NullPointerException("the Document is null or have not cahrs!!");  
  74.         }  
  75.     }  
  76.   
  77.     /** 
  78.      * 输入一个字符判断是否为中文汉字 
  79.      *  
  80.      * @param ch,字符 
  81.      * @return true为中文汉字,否则为false 
  82.      */   
  83.     public boolean isHanZi(char ch) {  
  84.     return (ch >= 0x4E00 && ch <= 0x9FA5);      
  85.     }     
  86.       
  87.     /** 
  88.      * 根据输入的Unicode字符,获取它的GB2312编码或者ascii编码, 
  89.      *  
  90.      * @param ch,输入的GB2312中文字符或者ASCII字符(128个) 
  91.      * @return ch在GB2312中的位置,-1表示该字符不认识 
  92.      */  
  93.     public static short getGB2312Id(char ch) {  
  94.         try {  
  95.             byte[] buffer = Character.toString(ch).getBytes("GB2312");  
  96.             if (buffer.length != 2) {  
  97.                 // 正常情况下buffer应该是两个字节,否则说明ch不属于GB2312编码,故返回'?',此时说明不认识该字符  
  98.                 return -1;  
  99.             }  
  100.             int b0 = (int) (buffer[0] & 0x0FF) - 161; // 编码从A1开始,因此减去0xA1=161  
  101.             int b1 = (int) (buffer[1] & 0x0FF) - 161; // 第一个字符和最后一个字符没有汉字,因此每个区只收16*6-2=94个汉字  
  102.             return (short) (b0 * 94 + b1);  
  103.         } catch (UnsupportedEncodingException e) {  
  104.             e.printStackTrace();  
  105.         }  
  106.         return -1;  
  107.     }  
  108. }  


 

posted @ 2017-11-28 13:46  Histring  阅读(1377)  评论(0编辑  收藏  举报