背单词(AC自动机+线段树+dp+dfs序)

G. 背单词

内存限制:256 MiB 时间限制:1000 ms 标准输入输出
题目类型:传统 评测方式:文本比较
 

题目描述

给定一张包含N个单词的表,每个单词有个价值W。要求从中选出一个子序列使得其 中的每个单词是后一个单词的子串,最大化子序列中W的和。

输入格式

第一行一个整数TEST,表示数据组数。 接下来TEST组数据,每组数据第一行为一个整数N。 接下来N行,每行为一个字符串和一个整数W。

输出格式

TEST行,每行一个整数,表示W的和的最大值。

数据规模 设字符串的总长度为Len 30.的数据满足,TEST≤5,N≤500,Len≤10^4 100.的数据满足,TEST≤10,N≤20000,Len≤3*10^5


析:又是一道好题,这道题将AC自动机与线段树,dp,以及 dfs 序结合了起来;

  首先我们要明确这样一个事情,S是T的字串,相当于 T  的一个前缀可以通过 Fail 树遍历到 S 的末尾结点,也就是说, S 的末尾节点是 T 的某个前缀在 Fail 树上的祖先;

  那么这道题思路就清晰了,首先可以写出 dp 方程 :f[i]=max(x)+w[i] ,表示 在前 i 个单词中,当前枚举到第 i 个单词且选择它的最大值, max(x) 表示当前单词前缀的最大值;

  那么此时我们的问题就在于 1.如何求得前缀? 2.如何求得区间(单点)最大值?

  对于第一个问题,我们可以使用 fa 数组进行回溯查询:

 if(!use[now].son[p])
        {
            use[now].son[p]=++num;
            fa[use[now].son[p]]=now;
        }

   这是在构建字典树的过程中记录每一个字符的父亲节点,我们在计算过程中就可以:

 while(p)
        {
            f[i]=max(f[i],query(rt,1,num,l[p]));
            p=fa[p];
        }

   那么考虑第二个问题,区间查询,我们同时还要考虑到,每次在我们枚举一次选择的单词后,我们都要判断选或不选哪个是最优解,然后对一段区间进行更新,所以说,我们不仅需要区间查询,还需要区间修改的操作,

   看这数据范围,显然我们可以想到线段树。那么我们在查询,修改的过程中如何确定区间呢? 这里利用 dfs 序就是一种很妙的思路,我们求出 in[],与out[] 就可以知道当前单词的控制区间。

  那么问题来了,我们要构建一颗什么样的 dfs 树呢?

  显然,若考虑当前单词 i ,fail[i] ,fa[i],fail[fa[i]] ,那么很明显,fail[fa[i]] 应该是控制区间最大的那一个,所以,我们就要从每个节点的 fail 指针向当前单词 连一条边,进行 dfs ;

  这里我们再解释一下为什么是单点查询,注意题目要求:

>  从中选出一个子序列使得其中的每个单词是后一个单词的子串

  1.假设现在有个序列 ABCD ,那么假设我的单词分别为 AB , 和 CD,那么如果我两个同时拿的话就无法满足题义

  2.若每次我们都之考虑某一个前缀的最大值,那么递推过来的一定是满足条件的最大值!!

  那么,在这颗线段树中,我们要维护的就是每个单词选或不选的最大值,所以在区间更新的时候我们都要取 max,到这里应该解释的差不多了;

 代码:

#include<bits/stdc++.h>
#define re register int
using namespace std;
const int N=3e6+10;
int T,n,cnt,tot,rt,timi,num=1;
char s[N];
int w[N],ed[N],l[N],r[N],fa[N];
int head[N],to[N<<1],next[N<<1];
long long f[N];
long long maxx;
bool vis[N];
queue<int> q;
struct CUN
{
    int flag,fail;
    int son[30];
    void clean()
    {
        flag=0;
        fail=0;
        memset(son,0,sizeof(son));
    }
}use[N];
struct C2
{
    int lc,rc,sum,lazy;
    void clean()
    {
    	lc=0;
    	rc=0;
    	sum=0;
    	lazy=0;
    }
}t[N];
void in()
{
    for(re i=0;i<=max(cnt,tot);i++)
        use[i].clean();
    for(re i=0;i<=max(cnt,tot);i++)
    	t[i].clean();
    cnt=0;
    num=1;
    maxx=0;
    tot=0;
    timi=0;
    memset(l,0,sizeof(l));
    memset(r,0,sizeof(r));
    memset(vis,0,sizeof(vis));
    memset(f,0,sizeof(f));
    memset(w,0,sizeof(w));
    memset(ed,0,sizeof(ed));
    memset(head,0,sizeof(head));
    memset(to,0,sizeof(to));
    memset(next,0,sizeof(next));
    memset(fa,0,sizeof(fa));
    while(!q.empty())
        q.pop();
}
void insert(char ss[],int pos)
{
    int now=1;
    int l=strlen(ss);
    for(re i=0;i<l;i++)
    {
        int p=ss[i]-'a';
        if(!use[now].son[p])
        {
            use[now].son[p]=++num;
        	fa[use[now].son[p]]=now;
        }
        now=use[now].son[p];
    }
    ed[pos]=now;
}
void Add(int x,int y)
{
    to[++tot]=y;
    next[tot]=head[x];
    head[x]=tot;
}
void dfs(int x)
{
    if(x)
        l[x]=++timi;
    for(re i=head[x];i;i=next[i])
        dfs(to[i]);
    r[x]=timi;    
}
void build(int &p,int L,int R)
{
    p=++cnt;
    if(L==R)
        return;
    int mid=(L+R)>>1;
    build(t[p].lc,L,mid);
    build(t[p].rc,mid+1,R);
}
void get_fail()
{
    for(re i=0;i<26;i++)
        use[0].son[i]=1;
    use[1].fail=0;
    q.push(1);
    while(!q.empty())
    {
        int u=q.front();
        q.pop();
        int Fail=use[u].fail;
        for(re i=0;i<26;i++)
        {
            int v=use[u].son[i];
            if(!v)
            {
                use[u].son[i]=use[Fail].son[i];
                continue;
            }
            use[v].fail=use[Fail].son[i];
            q.push(v);
        }
    }
    for(re i=1;i<=num;i++)
        Add(use[i].fail,i);
    dfs(0);
    build(rt,1,num);
}
void pd(int p)
{
    if(t[p].lazy==0)
        return;
    t[t[p].lc].lazy=max(t[t[p].lc].lazy,t[p].lazy);
    t[t[p].rc].lazy=max(t[t[p].rc].lazy,t[p].lazy);
    t[t[p].lc].sum=max(t[t[p].lc].sum,t[p].lazy);
    t[t[p].rc].sum=max(t[t[p].rc].sum,t[p].lazy);
    t[p].lazy=0;
}
void pp(int rt)
{
    t[rt].sum=max(t[t[rt].lc].sum,t[t[rt].rc].sum);
}
long long query(int rt,int L,int R,int p)
{
    if(L==R)
        return t[rt].sum;
    int mid=(L+R)>>1;
    pd(rt);
    if(p<=mid)
        return query(t[rt].lc,L,mid,p);
    return query(t[rt].rc,mid+1,R,p);
}
void updata(int p,int L,int R,int l,int r,int z)
{
    if(l<=L&&R<=r)
    {
        t[p].sum=max(t[p].sum,z);
        t[p].lazy=max(t[p].lazy,z);
        return;
    }
    pd(p);
    int mid=(L+R)>>1;
    if(mid>=l)
        updata(t[p].lc,L,mid,l,r,z);
    if(mid<r)
        updata(t[p].rc,mid+1,R,l,r,z);
    pp(p);
}
void dp()
{
    //f[i]=max{x}+w[i];
    for(re i=1;i<=n;i++)
    {
        int p=ed[i];
        while(p)
        {
            f[i]=max(f[i],query(rt,1,num,l[p]));
            p=fa[p];
        }
        f[i]=max(0*1ll,f[i]+w[i]);
        updata(rt,1,num,l[ed[i]],r[ed[i]],f[i]);
    }
    for(re i=1;i<=n;i++)
        maxx=max(maxx,f[i]);
    printf("%lld\n",maxx);
}
signed main()
{
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        in();
        for(re i=1;i<=n;i++)
        {
        	scanf("%s",s);
        	scanf("%d",&w[i]);
            insert(s,i);
        }
        get_fail();
        dp();
    }
}

 

posted @ 2021-06-14 07:07  WindZR  阅读(129)  评论(0编辑  收藏  举报