$("head").append('')

[SCOI2016]背单词

题目链接:https://www.luogu.com.cn/problem/P3294

 

思路:

  首先看到题目,又是trie的题,所以自然想到用trie来存储单词,考虑到有两种存储顺序:正序插入和倒序插入(即前缀和后缀),根据题目,选择倒序更好,这样一来单词及其后缀就都在同一条链上了。

  对于题目的理解,我刚开始以为是:选择单词顺序,选择时ans加上当前节点到最近的有结尾标记的祖先或根节点的距离,为此我甚至还想出了按深度(dfs序)安排顺序(尽量避免情况1和情况2),但不久发现我想错了,不过想出来的贪心思路还是有作用的。

  那么结合贪心思路再捋一捋题意:选择单词顺序,对于当前单词,设它的编号为x,离它最近的后缀祖先的编号为y,那么ans加上y-x,否则加上x(以贪心思路,先选择后缀,尽量避免情况1和2);然后求ans的最小值。

  既然如此,那么每次计算的时候就要找到当前单词的祖先,然而一个一个地往上查找难免费力,所以打算压缩一下路径,由于计算当中只需要单词结尾的点,所以非结尾的点可尽情扔掉。

  再想一想怎样选择,才能让ans最小,由于我太蒻了,这里只提供一个思路:设两个子树,一个子节点多一个子节点少,计算两个不同顺序的优劣性,可以证得:遇到分枝时,优先选择子树规模小的单词节点,ans最小;

  最后,再模拟算出ans,问题解决!

 

总流程一览:

  1.构建后缀trie;

  2.构建关键点树;

  3.贪心+模拟算出答案。

 

历程及注意事项:

 

20分WA代码:

#include <bits/stdc++.h>
using namespace std;
int n;
string str;
int b[100005];
int ch[510005][30];
vector<int>to[100005];//dfs前向星链表
map <int,int> a;//trie到dfs的映射 
int fa[100005];
int size[100005];
int wz[100005];
int cnttt=0;
int cntt=0;
int cnt=1;
long long ans=0;
bool cmp(int x,int y){
    return size[x]<size[y];
}
void biuld(){//构建trie
    int u=1;
    int last=1;
    for(int i=str.size()-1;i>=0;i--){
        int c=str[i]-'a';
        if(!ch[u][c]){
            ch[u][c]=++cnt;
            u=cnt;
        }else{
            u=ch[u][c];
        }
        if(b[u])last=u;
    }
    fa[++cntt]=a[last];
    a[u]=cntt;
    b[u]=1;
}
void dfsprint(int now){
    cout<<now;
    for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfsprint(*it);
    }
}
void dfs1(int now){//求size
    if(to[now].size()==0){
        size[now]=1;
    }else{
        size[now]=1;
        for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
            dfs1(*it);
            size[now]+=size[*it];
        }
    }
}
void dfs2(int now,int last){//模拟求出ans
    ++cnttt;
    ans+=cnttt-wz[last];
    sort(to[now].begin(),to[now].end(),cmp);
    for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfs2(*it,now);
    }
}
int main() {
    cin>>n;
    for(int i=0;i<=26;i++)ch[0][i]=1;
    a[1]=0;
    for(int i=1;i<=n;i++){
        cin>>str;
        biuld();
    }
    for(int i=1;i<=cntt;i++)to[fa[i]].push_back(i);
    dfs1(0);
    wz[0]=0;
    sort(to[0].begin(),to[0].end(),cmp);
    for(vector<int>::iterator it=to[0].begin();it!=to[0].end();it++){
        dfs2(*it,0);
    }
    //dfsprint(0);
    //for(int i=0;i<=n;i++)cout<<size[i]<<endl;
    cout<<ans;
    return 0;
}

这里发现了在函数dfs2中wz数组没有更新的错误,在下面的代码进行了更新。

 

30分WA代码:

#include <bits/stdc++.h>
using namespace std;
int n;
string str;
int b[100005];
int ch[510005][30];
vector<int>to[100005];//dfs前向星链表
map <int,int> a;//trie到dfs的映射 
int fa[100005];
int size[100005];
int wz[100005];
int cnttt=0;
int cntt=0;
int cnt=1;
long long ans=0;
bool cmp(int x,int y){
    return size[x]<size[y];
}
void biuld(){
    int u=1;
    int last=1;
    for(int i=str.size()-1;i>=0;i--){
        int c=str[i]-'a';
        if(!ch[u][c]){
            ch[u][c]=++cnt;
            u=cnt;
        }else{
            u=ch[u][c];
        }
        if(b[u])last=u;
    }
    fa[++cntt]=a[last];
    a[u]=cntt;
    b[u]=1;
}
void dfsprint(int now){
    cout<<now;
    for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfsprint(*it);
    }
}
void dfs1(int now){
    if(to[now].size()==0){
        size[now]=1;
    }else{
        size[now]=1;
        for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
            dfs1(*it);
            size[now]+=size[*it];
        }
    }
}
void dfs2(int now,int last){
    //cout<<now<<endl;
    ++cnttt;
    wz[now]=cnttt;//这里纠正了wz数组的错误
    ans+=cnttt-wz[last];
    sort(to[now].begin(),to[now].end(),cmp);
    for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfs2(*it,now);
    }
}
int main() {
    cin>>n;
    for(int i=0;i<=26;i++)ch[0][i]=1;
    a[1]=0;
    for(int i=1;i<=n;i++){
        cin>>str;
        biuld();
    }
    for(int i=1;i<=cntt;i++)to[fa[i]].push_back(i);
    dfs1(0);
    wz[0]=0;
    sort(to[0].begin(),to[0].end(),cmp);
    for(vector<int>::iterator it=to[0].begin();it!=to[0].end();it++){
        dfs2(*it,0);
    }
    //dfsprint(0);
    cout<<ans;
    return 0;
}

但这样的程序还是有问题,观察一下,我们发现:在biuld过程中,last本应指向最近的后缀祖先,但前提是祖先已经插入,由于数据的输入是无序的,所以将会导致致命的错误,我们需要将要插入的字符串按后缀排序。

 

50分WA代码:

#include <bits/stdc++.h>
using namespace std;
int n;
string str[100005];
int b[100005];
int ch[510005][30];
vector<int>to[100005];//dfs前向星链表
map <int,int> a;//trie到dfs的映射 
int fa[100005];
int size[100005];
int wz[100005];
int cnttt=0;
int cntt=0;
int cnt=1;
long long ans=0;
bool cmp(int x,int y){
    return size[x]<size[y];
}
bool cmpstr(string x,string y){
    return x<y;
}
void biuld(string ccc){
    int u=1;
    int last=1;
    for(int i=0;i<ccc.size();i++){
        int c=ccc[i]-'a';
        if(!ch[u][c]){
            ch[u][c]=++cnt;
            u=cnt;
        }else{
            u=ch[u][c];
        }
        if(b[u])last=u;
    }
    fa[++cntt]=a[last];
    a[u]=cntt;
    b[u]=1;
}
void dfsprint(int now){
    cout<<now;
    for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfsprint(*it);
    }
}
void dfs1(int now){
    if(to[now].size()==0){
        size[now]=1;
    }else{
        size[now]=1;
        for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
            dfs1(*it);
            size[now]+=size[*it];
        }
    }
}
void dfs2(int now,int last){
    //cout<<now<<endl;
    ++cnttt;
    wz[now]=cnttt;
    ans+=cnttt-wz[last];
    sort(to[now].begin(),to[now].end(),cmp);
    for(vector<int>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfs2(*it,now);
    }
}
int main() {
    cin>>n;
    for(int i=0;i<=26;i++)ch[0][i]=1;
    a[1]=0;
    string fk;
    for(int i=1;i<=n;i++){
        cin>>fk;
        int len=fk.size();
        str[i]=fk;
        for(int j=0;j<len;j++)str[i][j]=fk[len-1-j];//将字符串翻转
    }
    sort(str+1,str+n+1);//排序
    //for(int i=1;i<=n;i++)cout<<str[i]<<endl;
    for(int i=1;i<=n;i++)biuld(str[i]);
    for(int i=1;i<=cntt;i++)to[fa[i]].push_back(i);
    dfs1(0);
    wz[0]=0;
    sort(to[0].begin(),to[0].end(),cmp);
    for(vector<int>::iterator it=to[0].begin();it!=to[0].end();it++){
        dfs2(*it,0);
    }
    //dfsprint(0);
    //for(int i=1;i<=n;i++)cout<<size[i]<<endl;
    cout<<ans;
    return 0;
}

这里解决了字符串错乱的问题,但仍然WA了。百思不得其解后我去loj测了一下,原来是RE,看来洛谷的机子不行啊。

 

100分AC代码:

#include <bits/stdc++.h>
using namespace std;
long long n;//三年OI一场空,不开long long见祖宗
string str[510005];//数组开大一点
long long b[510005];
long long ch[510005][30];
vector<long long>to[510005];//dfs前向星链表
map <long long,long long> a;//trie到dfs的映射 
long long fa[510005];
long long size[510005];
long long wz[510005];
long long cnttt=0;
long long cntt=0;
long long cnt=1;
long long ans=0;
bool cmp(long long x,long long y){
    return size[x]<size[y];
}
void biuld(string ccc){
    long long u=1;
    long long last=1;
    for(long long i=0;i<ccc.size();i++){
        long long c=ccc[i]-'a';
        if(!ch[u][c]){
            ch[u][c]=++cnt;
            u=cnt;
        }else{
            u=ch[u][c];
        }
        if(b[u])last=u;
    }
    fa[++cntt]=a[last];
    a[u]=cntt;
    b[u]=1;
}
void dfsprint(long long now){
    cout<<now;
    for(vector<long long>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfsprint(*it);
    }
}
void dfs1(long long now){
    if(to[now].size()==0){
        size[now]=1;
    }else{
        size[now]=1;
        for(vector<long long>::iterator it=to[now].begin();it!=to[now].end();it++){
            dfs1(*it);
            size[now]+=size[*it];
        }
    }
}
void dfs2(long long now,long long last){
    //cout<<now<<endl;
    ++cnttt;
    wz[now]=cnttt;
    ans+=cnttt-wz[last];
    sort(to[now].begin(),to[now].end(),cmp);
    for(vector<long long>::iterator it=to[now].begin();it!=to[now].end();it++){
        dfs2(*it,now);
    }
}
int main() {
    cin>>n;
    for(long long i=0;i<=26;i++)ch[0][i]=1;
    a[1]=0;
    string fk;
    for(long long i=1;i<=n;i++){
        cin>>fk;
        long long len=fk.size();
        str[i]=fk;
        for(long long j=0;j<len;j++)str[i][j]=fk[len-1-j];
    }
    sort(str+1,str+n+1);
    //for(long long i=1;i<=n;i++)cout<<str[i]<<endl;
    for(long long i=1;i<=n;i++)biuld(str[i]);
    for(long long i=1;i<=cntt;i++)to[fa[i]].push_back(i);
    dfs1(0);
    wz[0]=0;
    sort(to[0].begin(),to[0].end(),cmp);
    for(vector<long long>::iterator it=to[0].begin();it!=to[0].end();it++)dfs2(*it,0);
    //dfsprint(0);
    //for(long long i=1;i<=n;i++)cout<<size[i]<<endl;
    cout<<ans;
    return 0;
}

 

大功告成!

posted @ 2020-06-08 19:34  returnG  阅读(118)  评论(0编辑  收藏  举报