[bzoj4567] [loj#2012] [SCOI2016] 背单词
Description
\(Lweb\) 面对如山的英语单词,陷入了深深的沉思,「我怎么样才能快点学完,然后去玩三国杀呢?」。这时候睿智的凤老师从远处飘来,他送给了 \(Lweb\) 一本计划册和一大缸泡椒,然后凤老师告诉 \(Lweb\) ,我知道你要学习的单词总共有 \(n\) 个,现在我们从上往下完成计划表,对于一个序号为 \(x\) 的单词(序号 \(1...x-1\) 都已经被填入):
1.如果存在一个单词是它的后缀,并且当前没有被填入表内,那他需要吃 \(n \times n\) 颗泡椒才能学会;
2.当它的所有后缀都被填入表内的情况下,如果在 \(1...x-1\) 的位置上的单词都不是它的后缀,那么他吃 \(x\) 颗泡椒就能记住它;
3.当它的所有后缀都被填入表内的情况下,如果 \(1...x-1\) 的位置上存在是它后缀的单词,所有是它后缀的单词中,序号最大为 \(y\) ,那么他只要吃 \(x-y\) 颗泡椒就能把它记住。
\(Lweb\) 是一个吃到辣辣的东西会暴走的奇怪小朋友,所以请你帮助 \(Lweb\),寻找一种最优的填写单词方案,使得他记住这 \(n\) 个单词的情况下,吃最少的泡椒。
Input
输入一个整数 ,表示 \(Lweb\) 要学习的单词数。接下来 \(n\) 行,每行有一个单词(由小写字母构成,且保证任意单词两两互不相同)。
Output
\(Lweb\) 吃的最少泡椒数。
Sample Input
2
a
ba
Sample Output
2
HINT
\(1 \leq n \leq 10^5\),所有字符的长度总和 \(1 \leq |len| \leq 510000\)
想法
把所有串反过来,建在 \(trie\) 树上,形成一颗前缀树,那么原题中的后缀即这里的前缀
一个显然的贪心是所有串都在它所有前缀填入后再填,\(n \times n\) 的代价要不起。
相当于在树上进行遍历,使所有点的父节点都在子节点前访问,总代价是所有点与其父节点的访问时间差 的总和。
于是又一个贪心,每个点访问后按它所有子节点 \(size\) 的大小,从小到大依次访问。
很容易想到这个贪心,但咋证明是对的呢?【听说这是个经典问题,但网上也没看见多少证明啊(逃】
可以看 \(neither_nor\) 的证明 https://blog.csdn.net/neither_nor/article/details/51362523
代码
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
const int N = 110000;
typedef long long ll;
int n;
int fa[N],sz[N];
struct trie{
trie *ch[26];
int id;
}pool[N*5],*root;
int cnt;
void insert(char s[],int id){
trie *p=root;
int len=strlen(s);
for(int i=len-1;i>=0;i--){
if(!p->ch[s[i]-'a']) p->ch[s[i]-'a']=&pool[++cnt];
p=p->ch[s[i]-'a'];
}
p->id=id;
}
struct node{
node *nxt;
int v;
}pool2[N],*h[N];
int cnt2;
void addedge(int u,int v){
node *p=&pool2[++cnt2];
p->v=v;p->nxt=h[u];h[u]=p;
}
void dfs(trie *p,int cur){
int nc=p->id ? p->id : cur;
for(int i=0;i<26;i++)
if(p->ch[i])
dfs(p->ch[i],nc);
if(p->id){
sz[p->id]++;
fa[p->id]=cur;
sz[cur]+=sz[p->id];
addedge(cur,p->id);
}
}
ll ans;
int b[N],tot;
void work(int u){
for(node *p=h[u];p;p=p->nxt) work(p->v);
tot=0;
for(node *p=h[u];p;p=p->nxt) b[++tot]=sz[p->v];
sort(b+1,b+1+tot);
for(int i=1;i<=tot;i++)
ans+=1ll*b[i]*(tot-i);
ans+=tot;
}
int main()
{
char s[N];
scanf("%d",&n);
root=&pool[++cnt];
for(int i=1;i<=n;i++){
scanf("%s",s);
insert(s,i);
}
dfs(root,0);
ans=0; work(0);
printf("%lld\n",ans);
return 0;
}