Codeforces Gym 100338B Spam Filter 字符串哈希+贝叶斯公式

原题链接:http://codeforces.com/gym/100338/attachments/download/2136/20062007-winter-petrozavodsk-camp-andrew-stankevich-contest-22-asc-22-en.pdf

题意

这是一个过滤垃圾邮件的算法,叫贝叶斯算法。这个算法的第一步是训练过程,通过人工给定的邮件,来确定每个词语在垃圾邮件中的概率和在普通邮件的概率。然后通过贝叶斯公式来计算每个邮件是否为垃圾邮件。具体过程可以看题,或者维基百科。

题解

模拟题目的过程即可:先统计每个单词在垃圾邮件和正常邮件中出现的比例,再用贝叶斯公式算出每个单词是垃圾词的概率,最后按垃圾词所占的百分比判断每封邮件。需要注意的是,为了避免超时,需要对单词做字符串哈希,用一个整数(哈希值)代替字符串来记录和比较单词。

代码

//#include<iostream>
#include<cstring>
#include<fstream>
#include<vector>
#include<map>
#include<algorithm>
#include<string>
#include<set>
#include<cmath>
#include<queue>
// Tolerance used for floating-point comparisons.
constexpr double eps = 1e-10;
// Capacity of each fixed-size per-message array below.
constexpr int MAX_N = 1234;
// Base and modulus of the polynomial string hash used by Hash().
constexpr long long Pr = 131;
constexpr long long mod = 1000000009;
using namespace std;

typedef long long ll;

// Header values read by main(): number of spam training messages (s),
// number of good training messages (g), number of messages to classify (n),
// and the spam threshold t, compared against a percentage of spam words.
int s = 0;
int g = 0;
int n = 0;
int t = 0;

// Scratch buffer for the rest of the header line.
string bankLine;

// Raw message lines (main appends one trailing space to each):
// spam training (sp), good training (go), messages to classify (ma).
// NOTE(review): main reads s/g/n lines without checking them against MAX_N.
string sp[MAX_N];
string go[MAX_N];
string ma[MAX_N];

// Hashed word sets filled by divi(): spam[i] for sp[i], good[i] for go[i],
// mail[i] for ma[i].
set<ll> spam[MAX_N];
set<ll> good[MAX_N];
set<ll> mail[MAX_N];

// Hash of every word seen in any message (training or to classify).
set<ll> allWord;

// Per-word likelihoods: fraction of spam (resp. good) training messages
// that contain the word.
map<ll, double> wordIsSpam;
map<ll, double> wordIsGood;

// Class priors: fraction of training messages that are spam / good.
double pSpam = 0;
double pGood = 0;

// Polynomial hash of `ss`: sum over i of ss[i] * Pr^(i+1), reduced modulo `mod`.
// Words are identified only by this value elsewhere in the program, so two
// distinct words that collide are treated as the same word.
// NOTE(review): a single ~1e9 modulus makes collisions plausible once tens of
// thousands of distinct words are hashed.
// The word is taken by const reference to avoid copying it on every call.
ll Hash(const string &ss) {
    ll res = 0;
    ll power = Pr;  // Pr^(i+1) mod `mod` for the current character index i
    for (char c : ss) {
        // c * power < 128 * mod, which fits comfortably in a 64-bit ll.
        res = (res + c * power) % mod;
        power = (power * Pr) % mod;
    }
    return res;
}

// Returns `ss` with ASCII uppercase letters 'A'-'Z' mapped to lowercase;
// every other character is returned unchanged.
// The parameter is already a private copy, so it is lowered in place. This
// keeps the function linear in the word length instead of re-copying the
// growing result on every character.
std::string changeToSmall(std::string ss) {
    for (char &c : ss) {
        if (c >= 'A' && c <= 'Z') c = static_cast<char>(c - 'A' + 'a');
    }
    return ss;
}

// Returns true iff `c` is an ASCII letter ('A'-'Z' or 'a'-'z').
// Unlike std::isalpha, this check does not depend on the locale and is safe
// for negative char values.
bool isAl(char c) {
    const bool upper = ('A' <= c && c <= 'Z');
    const bool lower = ('a' <= c && c <= 'z');
    return upper || lower;
}

// Tokenizes the first `x` lines of `v` into words and records their hashes.
// A word is a maximal run of ASCII letters (isAl). Each word is lowercased
// (changeToSmall), hashed (Hash), and inserted into G[i] (the word set of
// line i) and into the global allWord set.
// A word that ends exactly at the end of a line is also recorded, so lines do
// not need a trailing separator.
// @param v  at least x text lines
// @param G  at least x sets; G[i] receives the word hashes of v[i]
// @param x  number of lines to process
void divi(string v[], set<ll> G[], int x) {
    for (int i = 0; i < x; i++) {
        const string &line = v[i];
        size_t pos = 0;
        while (pos < line.size()) {
            // Skip separators up to the start of the next word.
            while (pos < line.size() && !isAl(line[pos])) pos++;
            const size_t start = pos;
            while (pos < line.size() && isAl(line[pos])) pos++;
            if (pos > start) {
                const ll h = Hash(changeToSmall(line.substr(start, pos - start)));
                allWord.insert(h);
                G[i].insert(h);
            }
        }
    }
}

// Guarded division: returns a / b, or 0 when |b| < eps, so callers get 0
// instead of inf/NaN for a (near-)zero denominator.
double divide(double a,double b) {
    return fabs(b) < eps ? 0 : a / b;
}

// Probability that a message is spam given that it contains `word`, by Bayes'
// formula:
//   ws * pSpam / (ws * pSpam + wg * pGood)
// where ws = wordIsSpam[word] and wg = wordIsGood[word] (each 0 if absent).
// Returns 0 when the denominator is below eps (via divide).
// Each table is searched once with find(), without inserting missing keys.
double P(ll word) {
    const auto spamIt = wordIsSpam.find(word);
    const double ws = (spamIt == wordIsSpam.end()) ? 0 : spamIt->second;

    const auto goodIt = wordIsGood.find(word);
    const double wg = (goodIt == wordIsGood.end()) ? 0 : goodIt->second;

    return divide(ws * pSpam, ws * pSpam + wg * pGood);
}

int main() {
    ifstream cin("spam.in");
    ofstream cout("spam.out");
    cin.sync_with_stdio(false);
    cin >> s >> g >> n >> t;
    pSpam = divide(s, s + g);
    pGood = divide(g, s + g);
    getline(cin, bankLine);
    for (int i = 0; i < s; i++) {
        getline(cin, sp[i]);
        sp[i] = sp[i] + " ";
    }
    for (int i = 0; i < g; i++) {
        getline(cin, go[i]);
        go[i] = go[i] + " ";
    }
    for (int i = 0; i < n; i++) {
        getline(cin, ma[i]);
        ma[i] = ma[i] + " ";
    }
    divi(sp, spam, s);
    divi(go, good, g);
    divi(ma, mail, n);

    for (auto word:allWord) {
        int cnt = 0;
        for (int i = 0; i < s; i++)
            if (spam[i].find(word) != spam[i].end())cnt++;
        wordIsSpam[word] = divide(cnt, s);
        cnt = 0;
        for (int i = 0; i < g; i++)
            if (good[i].find(word) != good[i].end())cnt++;
        wordIsGood[word] = divide(cnt, g);
    }
    for (int i = 0; i < n; i++) {
        double ans = 0;
        for (auto word:mail[i]) {
            double p = P(word);
            //cout<<word<<": "<<p<<endl;
            if (p > 0.5 || fabs(p - 0.5) < eps)ans++;
        }
        if (ans * 100 / mail[i].size() < t)cout << "good" << endl;
        else cout << "spam" << endl;
    }
    return 0;
}

 

posted @ 2015-08-21 19:19  好地方bug  阅读(330)  评论(2)  编辑  收藏  举报