Codeforces Gym 100338B Spam Filter 字符串哈希+贝叶斯公式
原题链接:http://codeforces.com/gym/100338/attachments/download/2136/20062007-winter-petrozavodsk-camp-andrew-stankevich-contest-22-asc-22-en.pdf
题意
这是一个过滤垃圾邮件的算法,叫贝叶斯算法。这个算法的第一步是训练过程,通过人工给定的邮件,来确定每个词语在垃圾邮件中的概率和在普通邮件的概率。然后通过贝叶斯公式来计算每个邮件是否为垃圾邮件。具体过程可以看题,或者维基百科。
题解
模拟题目的过程即可。不过要注意,为了避免超时,必须对字符串做哈希:先把每个单词统一转换为小写,再用其哈希值来记录该单词。
代码
// Codeforces Gym 100338B "Spam Filter": naive Bayes spam classification.
//
// Input (spam.in): s g n t, followed by s spam training messages, g good
// training messages and n messages to classify, one message per line.
// Output (spam.out): "spam" or "good" for each of the n messages.
//
// A word is a maximal run of ASCII letters, compared case-insensitively.
// Words are stored as polynomial hashes of their lowercase form; two
// different words that collide modulo `mod` are treated as the same word.
#include <algorithm>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

using namespace std;
typedef long long ll;

const double eps = 1e-10;    // tolerance for floating-point comparisons
const ll Pr = 131;           // base of the polynomial hash
const ll mod = 1000000009;   // modulus of the polynomial hash

int s, g, n, t;                  // #spam samples, #good samples, #queries, threshold in percent
map<ll, double> wordIsSpam;      // word hash -> fraction of spam samples containing the word
map<ll, double> wordIsGood;      // word hash -> fraction of good samples containing the word
double pSpam;                    // prior probability that a message is spam
double pGood;                    // prior probability that a message is good

// Polynomial hash of `ss`: sum of ss[k] * Pr^(k+1), taken modulo `mod`.
ll Hash(const string& ss) {
    ll power = Pr;
    ll res = 0;
    for (char c : ss) {
        // Go through unsigned char so a byte >= 0x80 cannot make the hash negative.
        res = (res + static_cast<unsigned char>(c) * power) % mod;
        power = (power * Pr) % mod;
    }
    return res;
}

// Returns `ss` with every ASCII uppercase letter replaced by its lowercase form.
string changeToSmall(string ss) {
    for (char& c : ss) {
        if (c >= 'A' && c <= 'Z') c = static_cast<char>(c - 'A' + 'a');
    }
    return ss;
}

// True when `c` is an ASCII letter.
bool isAl(char c) {
    return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}

// Splits each of the `x` lines in `v` into words and inserts the hash of every
// lowercased word into the matching set G[i]. A word that ends at the end of
// the line is emitted as well, so no trailing separator is required.
void divi(const string v[], set<ll> G[], int x) {
    for (int i = 0; i < x; i++) {
        const string& line = v[i];
        const size_t len = line.size();
        size_t pos = 0;
        while (pos < len) {
            // Bounds are checked before every character access.
            while (pos < len && !isAl(line[pos])) ++pos;
            const size_t start = pos;
            while (pos < len && isAl(line[pos])) ++pos;
            if (pos > start) {
                G[i].insert(Hash(changeToSmall(line.substr(start, pos - start))));
            }
        }
    }
}

// a / b, or 0 when b is (numerically) zero.
double divide(double a, double b) {
    if (fabs(b) < eps) return 0;
    return a / b;
}

// Bayes' rule: probability that a message is spam given that it contains
// `word`. Words never seen in training have likelihood 0 on both sides, which
// yields 0 through divide().
double P(ll word) {
    const map<ll, double>::const_iterator spamIt = wordIsSpam.find(word);
    const map<ll, double>::const_iterator goodIt = wordIsGood.find(word);
    const double ws = (spamIt == wordIsSpam.end()) ? 0 : spamIt->second;
    const double wg = (goodIt == wordIsGood.end()) ? 0 : goodIt->second;
    return divide(ws * pSpam, ws * pSpam + wg * pGood);
}

// Reads exactly `count` lines from `in`; lines missing from the input stay empty.
static vector<string> readLines(istream& in, int count) {
    vector<string> lines(count);
    for (int i = 0; i < count; i++) getline(in, lines[i]);
    return lines;
}

// Tokenizes every line into its set of word hashes.
static vector<set<ll> > tokenize(const vector<string>& lines) {
    vector<set<ll> > words(lines.size());
    divi(lines.data(), words.data(), static_cast<int>(lines.size()));
    return words;
}

// Fills `freq` with, for every word occurring in `messages`, the fraction of
// messages that contain it. One pass over the messages is enough.
static void wordFrequencies(const vector<set<ll> >& messages, map<ll, double>& freq) {
    map<ll, int> counts;
    for (const set<ll>& words : messages) {
        for (ll w : words) ++counts[w];
    }
    freq.clear();
    const int total = static_cast<int>(messages.size());
    for (const pair<const ll, int>& kv : counts) {
        freq[kv.first] = divide(kv.second, total);
    }
}

// A message is spam when at least `threshold` percent of its distinct words
// have P(word) >= 1/2. The comparison is done in integers, so a message with
// no words never divides by zero (it is reported as spam, as before).
static bool isSpamMessage(const set<ll>& words, int threshold) {
    ll indicating = 0;
    for (ll w : words) {
        const double p = P(w);
        if (p > 0.5 || fabs(p - 0.5) < eps) ++indicating;
    }
    return indicating * 100 >= static_cast<ll>(threshold) * static_cast<ll>(words.size());
}

// Runs the whole task: reads the problem input from `in`, trains on the
// samples and writes one verdict per query line to `out`.
void solve(istream& in, ostream& out) {
    if (!(in >> s >> g >> n >> t) || s < 0 || g < 0 || n < 0) return;
    string bankLine;
    getline(in, bankLine);  // discard the remainder of the header line

    const vector<set<ll> > spam = tokenize(readLines(in, s));
    const vector<set<ll> > good = tokenize(readLines(in, g));
    const vector<set<ll> > mail = tokenize(readLines(in, n));

    pSpam = divide(s, s + g);
    pGood = divide(g, s + g);
    wordFrequencies(spam, wordIsSpam);
    wordFrequencies(good, wordIsGood);

    for (const set<ll>& words : mail) {
        out << (isSpamMessage(words, t) ? "spam" : "good") << '\n';
    }
}

// Define SPAM_FILTER_NO_MAIN to build the helpers without the entry point
// (e.g. to link them into a unit test).
#ifndef SPAM_FILTER_NO_MAIN
int main() {
    ifstream in("spam.in");
    ofstream out("spam.out");
    solve(in, out);
    return 0;
}
#endif