[SCOI2016]背单词
题目链接:https://www.luogu.com.cn/problem/P3294
思路:
首先看到题目,又是trie的题,所以自然想到用trie来存储单词,考虑到有两种存储顺序:正序插入和倒序插入(即前缀和后缀),根据题目,选择倒序更好,这样一来单词及其后缀就都在同一条链上了。
对于题目的理解,我刚开始以为是:选择单词顺序,选择时ans加上当前节点到最近的有结尾标记的祖先或根节点的距离,为此我甚至还想出了按深度(dfs序)安排顺序(尽量避免情况1和情况2),但不久发现我想错了,不过想出来的贪心思路还是有作用的。
那么结合贪心思路再捋一捋题意:选择单词顺序,对于当前单词,设它的编号为x,离它最近的后缀祖先的编号为y,那么ans加上y-x,否则加上x(以贪心思路,先选择后缀,尽量避免情况1和2);然后求ans的最小值。
既然如此,那么每次计算的时候就要找到当前单词的祖先,然而一个一个地往上查找难免费力,所以打算压缩一下路径,由于计算当中只需要单词结尾的点,所以非结尾的点可尽情扔掉。
再想一想怎样选择,才能让ans最小,由于我太蒻了,这里只提供一个思路:设两个子树,一个子节点多一个子节点少,计算两个不同顺序的优劣性,可以证得:遇到分枝时,优先选择子树规模小的单词节点,ans最小;
最后,再模拟算出ans,问题解决!
总流程一览:
1.构建后缀trie;
2.构建关键点树;
3.贪心+模拟算出答案。
历程及注意事项:
20分WA代码:
#include <bits/stdc++.h> using namespace std; int n; string str; int b[100005]; int ch[510005][30]; vector<int>to[100005];//dfs前向星链表 map <int,int> a;//trie到dfs的映射 int fa[100005]; int size[100005]; int wz[100005]; int cnttt=0; int cntt=0; int cnt=1; long long ans=0; bool cmp(int x,int y){ return size[x]<size[y]; } void biuld(){//构建trie int u=1; int last=1; for(int i=str.size()-1;i>=0;i--){ int c=str[i]-'a'; if(!ch[u][c]){ ch[u][c]=++cnt; u=cnt; }else{ u=ch[u][c]; } if(b[u])last=u; } fa[++cntt]=a[last]; a[u]=cntt; b[u]=1; } void dfsprint(int now){ cout<<now; for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfsprint(*it); } } void dfs1(int now){//求size if(to[now].size()==0){ size[now]=1; }else{ size[now]=1; for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs1(*it); size[now]+=size[*it]; } } } void dfs2(int now,int last){//模拟求出ans ++cnttt; ans+=cnttt-wz[last]; sort(to[now].begin(),to[now].end(),cmp); for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs2(*it,now); } } int main() { cin>>n; for(int i=0;i<=26;i++)ch[0][i]=1; a[1]=0; for(int i=1;i<=n;i++){ cin>>str; biuld(); } for(int i=1;i<=cntt;i++)to[fa[i]].push_back(i); dfs1(0); wz[0]=0; sort(to[0].begin(),to[0].end(),cmp); for(vector<int>::iterator it=to[0].begin();it!=to[0].end();it++){ dfs2(*it,0); } //dfsprint(0); //for(int i=0;i<=n;i++)cout<<size[i]<<endl; cout<<ans; return 0; }
这里发现了在函数dfs2中wz数组没有更新的错误,在下面的代码进行了更新。
30分WA代码:
#include <bits/stdc++.h> using namespace std; int n; string str; int b[100005]; int ch[510005][30]; vector<int>to[100005];//dfs前向星链表 map <int,int> a;//trie到dfs的映射 int fa[100005]; int size[100005]; int wz[100005]; int cnttt=0; int cntt=0; int cnt=1; long long ans=0; bool cmp(int x,int y){ return size[x]<size[y]; } void biuld(){ int u=1; int last=1; for(int i=str.size()-1;i>=0;i--){ int c=str[i]-'a'; if(!ch[u][c]){ ch[u][c]=++cnt; u=cnt; }else{ u=ch[u][c]; } if(b[u])last=u; } fa[++cntt]=a[last]; a[u]=cntt; b[u]=1; } void dfsprint(int now){ cout<<now; for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfsprint(*it); } } void dfs1(int now){ if(to[now].size()==0){ size[now]=1; }else{ size[now]=1; for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs1(*it); size[now]+=size[*it]; } } } void dfs2(int now,int last){ //cout<<now<<endl; ++cnttt; wz[now]=cnttt;//这里纠正了wz数组的错误 ans+=cnttt-wz[last]; sort(to[now].begin(),to[now].end(),cmp); for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs2(*it,now); } } int main() { cin>>n; for(int i=0;i<=26;i++)ch[0][i]=1; a[1]=0; for(int i=1;i<=n;i++){ cin>>str; biuld(); } for(int i=1;i<=cntt;i++)to[fa[i]].push_back(i); dfs1(0); wz[0]=0; sort(to[0].begin(),to[0].end(),cmp); for(vector<int>::iterator it=to[0].begin();it!=to[0].end();it++){ dfs2(*it,0); } //dfsprint(0); cout<<ans; return 0; }
但这样的程序还是有问题,观察一下,我们发现:在biuld过程中,last本应指向最近的后缀祖先,但前提是祖先已经插入,由于数据的输入是无序的,所以将会导致致命的错误,我们需要将要插入的字符串按后缀排序。
50分WA代码:
#include <bits/stdc++.h> using namespace std; int n; string str[100005]; int b[100005]; int ch[510005][30]; vector<int>to[100005];//dfs前向星链表 map <int,int> a;//trie到dfs的映射 int fa[100005]; int size[100005]; int wz[100005]; int cnttt=0; int cntt=0; int cnt=1; long long ans=0; bool cmp(int x,int y){ return size[x]<size[y]; } bool cmpstr(string x,string y){ return x<y; } void biuld(string ccc){ int u=1; int last=1; for(int i=0;i<ccc.size();i++){ int c=ccc[i]-'a'; if(!ch[u][c]){ ch[u][c]=++cnt; u=cnt; }else{ u=ch[u][c]; } if(b[u])last=u; } fa[++cntt]=a[last]; a[u]=cntt; b[u]=1; } void dfsprint(int now){ cout<<now; for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfsprint(*it); } } void dfs1(int now){ if(to[now].size()==0){ size[now]=1; }else{ size[now]=1; for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs1(*it); size[now]+=size[*it]; } } } void dfs2(int now,int last){ //cout<<now<<endl; ++cnttt; wz[now]=cnttt; ans+=cnttt-wz[last]; sort(to[now].begin(),to[now].end(),cmp); for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs2(*it,now); } } int main() { cin>>n; for(int i=0;i<=26;i++)ch[0][i]=1; a[1]=0; string fk; for(int i=1;i<=n;i++){ cin>>fk; int len=fk.size(); str[i]=fk; for(int j=0;j<len;j++)str[i][j]=fk[len-1-j];//将字符串翻转 } sort(str+1,str+n+1);//排序 //for(int i=1;i<=n;i++)cout<<str[i]<<endl; for(int i=1;i<=n;i++)biuld(str[i]); for(int i=1;i<=cntt;i++)to[fa[i]].push_back(i); dfs1(0); wz[0]=0; sort(to[0].begin(),to[0].end(),cmp); for(vector<int>::iterator it=to[0].begin();it!=to[0].end();it++){ dfs2(*it,0); } //dfsprint(0); //for(int i=1;i<=n;i++)cout<<size[i]<<endl; cout<<ans; return 0; }
这里解决了字符串错乱的问题,但仍然WA了。百思不得其解后我去loj测了一下,原来是RE,看来洛谷的机子不行啊。
100分AC代码:
#include <bits/stdc++.h> using namespace std; long long n;//三年OI一场空,不开long long见祖宗 string str[510005];//数组开大一点 long long b[510005]; long long ch[510005][30]; vector<long long>to[510005];//dfs前向星链表 map <long long,long long> a;//trie到dfs的映射 long long fa[510005]; long long size[510005]; long long wz[510005]; long long cnttt=0; long long cntt=0; long long cnt=1; long long ans=0; bool cmp(long long x,long long y){ return size[x]<size[y]; } void biuld(string ccc){ long long u=1; long long last=1; for(long long i=0;i<ccc.size();i++){ long long c=ccc[i]-'a'; if(!ch[u][c]){ ch[u][c]=++cnt; u=cnt; }else{ u=ch[u][c]; } if(b[u])last=u; } fa[++cntt]=a[last]; a[u]=cntt; b[u]=1; } void dfsprint(long long now){ cout<<now; for(vector<long long>::iterator it=to[now].begin();it!=to[now].end();it++){ dfsprint(*it); } } void dfs1(long long now){ if(to[now].size()==0){ size[now]=1; }else{ size[now]=1; for(vector<long long>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs1(*it); size[now]+=size[*it]; } } } void dfs2(long long now,long long last){ //cout<<now<<endl; ++cnttt; wz[now]=cnttt; ans+=cnttt-wz[last]; sort(to[now].begin(),to[now].end(),cmp); for(vector<long long>::iterator it=to[now].begin();it!=to[now].end();it++){ dfs2(*it,now); } } int main() { cin>>n; for(long long i=0;i<=26;i++)ch[0][i]=1; a[1]=0; string fk; for(long long i=1;i<=n;i++){ cin>>fk; long long len=fk.size(); str[i]=fk; for(long long j=0;j<len;j++)str[i][j]=fk[len-1-j]; } sort(str+1,str+n+1); //for(long long i=1;i<=n;i++)cout<<str[i]<<endl; for(long long i=1;i<=n;i++)biuld(str[i]); for(long long i=1;i<=cntt;i++)to[fa[i]].push_back(i); dfs1(0); wz[0]=0; sort(to[0].begin(),to[0].end(),cmp); for(vector<long long>::iterator it=to[0].begin();it!=to[0].end();it++)dfs2(*it,0); //dfsprint(0); //for(long long i=1;i<=n;i++)cout<<size[i]<<endl; cout<<ans; return 0; }
大功告成!