[Trie][LuoguP3294][SCOI2016]背单词
题意很迷,然后你可以发现一种情况就是为了告诉你这是贪心。
显然我们可以将单词倒序插入Trie树,然后按照子树内单词个数从小到大遍历,模拟统计答案。
然后我们就得到了一个优秀的40分代码。
#include <string.h>
#include <stdio.h>
#include <iostream>
#include <vector>
#include <algorithm>
#define LL long long
#define ULL unsigned long long
#define Uint unsigned int
using namespace std;
const int MAXN=510000+5;
int ch[MAXN][26],node,n,size[MAXN];
char s[MAXN];
bool ed[MAXN];
vector <int> nxt[MAXN];
inline void insert(char *x){
int len=strlen(x+1),now=0;
for(int i=len;i>=1;--i){
int c=x[i]-'a';
if(!ch[now][c]) ch[now][c]=++node;
now=ch[now][c];
}
ed[now]=true;
}
inline void get_size(int now){
size[now]=0;
for(Uint i=0;i<nxt[now].size();++i){
int v=nxt[now][i];
get_size(v);
size[now]+=size[v];
}
if(ed[now]) size[now]++;
}
LL ans=0;int tot=0;
bool cmp(const int &x,const int &y){
return size[x]<size[y];
}
inline void get_ans(int now,int last){
if(ed[now]) tot++,ans+=tot-last,last=tot;
for(Uint i=0;i<nxt[now].size();++i){
int v=nxt[now][i];
get_ans(v,last);
}
}
int main(){
cin>>n;
for(int i=1;i<=n;++i){
scanf("%s",s+1);
insert(s);
}
for(int i=0;i<=node;++i){
for(int j=0;j<26;++j){
if(ch[i][j]) nxt[i].push_back(ch[i][j]);
}
}
get_size(0);
for(int i=0;i<=node;++i) if(nxt[i].size())sort(nxt[i].begin(),nxt[i].end(),cmp);
get_ans(0,0);
cout<<ans;
return 0;
}
为什么只有40分呢?
观察这样一组数据。(转自)
7
aa
aba
abb
ba
bb
bc
bd
我们可以得到这样一棵Trie树。
然后我们会先走右边再走左边。这样就会导致先走bba,再走aa,显然不优。
我们可以通过删去不要的节点来避免这种情况。
然后这就很优秀了。
事实上代码也就多一个rebuild函数。
#include <string.h>
#include <stdio.h>
#include <iostream>
#include <vector>
#include <algorithm>
#define LL long long
#define ULL unsigned long long
#define Uint unsigned int
using namespace std;
const int MAXN=510000+5;
int ch[MAXN][26],node,n,size[MAXN];
char s[MAXN];
bool ed[MAXN];
vector <int> nxt[MAXN];
inline void insert(char *x){
int len=strlen(x+1),now=0;
for(int i=len;i>=1;--i){
int c=x[i]-'a';
if(!ch[now][c]) ch[now][c]=++node;
now=ch[now][c];
}
ed[now]=true;
}
int cnt;
inline void rebuild(int now,int last){
if(ed[now]) nxt[last].push_back(++cnt),last=cnt;
for(int i=0;i<26;++i){
if(!ch[now][i]) continue;
rebuild(ch[now][i],last);
}
}
inline void get_size(int now){
size[now]=1;
for(Uint i=0;i<nxt[now].size();++i){
int v=nxt[now][i];
get_size(v);
size[now]+=size[v];
}
}
LL ans=0;int tot=0;
bool cmp(const int &x,const int &y){
return size[x]<size[y];
}
inline void get_ans(int now,int last){
if(now) tot++,ans+=tot-last,last=tot;
for(Uint i=0;i<nxt[now].size();++i){
int v=nxt[now][i];
get_ans(v,last);
}
}
int main(){
cin>>n;
for(int i=1;i<=n;++i){
scanf("%s",s+1);
insert(s);
}
rebuild(0,0);
get_size(0);
for(int i=0;i<=cnt;++i) if(nxt[i].size())sort(nxt[i].begin(),nxt[i].end(),cmp);
get_ans(0,0);
cout<<ans;
return 0;
}