背单词(AC自动机+线段树+dp+dfs序)
G. 背单词
题目描述
给定一张包含N个单词的表,每个单词有个价值W。要求从中选出一个子序列使得其 中的每个单词是后一个单词的子串,最大化子序列中W的和。
输入格式
第一行一个整数TEST,表示数据组数。 接下来TEST组数据,每组数据第一行为一个整数N。 接下来N行,每行为一个字符串和一个整数W。
输出格式
TEST行,每行一个整数,表示W的和的最大值。
数据规模 设字符串的总长度为Len 30.的数据满足,TEST≤5,N≤500,Len≤10^4 100.的数据满足,TEST≤10,N≤20000,Len≤3*10^5
析:又是一道好题,这道题将AC自动机与线段树,dp,以及 dfs 序结合了起来;
首先我们要明确这样一个事情,S是T的字串,相当于 T 的一个前缀可以通过 Fail 树遍历到 S 的末尾结点,也就是说, S 的末尾节点是 T 的某个前缀在 Fail 树上的祖先;
那么这道题思路就清晰了,首先可以写出 dp 方程 :f[i]=max(x)+w[i] ,表示 在前 i 个单词中,当前枚举到第 i 个单词且选择它的最大值, max(x) 表示当前单词前缀的最大值;
那么此时我们的问题就在于 1.如何求得前缀? 2.如何求得区间(单点)最大值?
对于第一个问题,我们可以使用 fa 数组进行回溯查询:
if(!use[now].son[p]) { use[now].son[p]=++num; fa[use[now].son[p]]=now; }
这是在构建字典树的过程中记录每一个字符的父亲节点,我们在计算过程中就可以:
while(p) { f[i]=max(f[i],query(rt,1,num,l[p])); p=fa[p]; }
那么考虑第二个问题,区间查询,我们同时还要考虑到,每次在我们枚举一次选择的单词后,我们都要判断选或不选哪个是最优解,然后对一段区间进行更新,所以说,我们不仅需要区间查询,还需要区间修改的操作,
看这数据范围,显然我们可以想到线段树。那么我们在查询,修改的过程中如何确定区间呢? 这里利用 dfs 序就是一种很妙的思路,我们求出 in[],与out[] 就可以知道当前单词的控制区间。
那么问题来了,我们要构建一颗什么样的 dfs 树呢?
显然,若考虑当前单词 i ,fail[i] ,fa[i],fail[fa[i]] ,那么很明显,fail[fa[i]] 应该是控制区间最大的那一个,所以,我们就要从每个节点的 fail 指针向当前单词 连一条边,进行 dfs ;
这里我们再解释一下为什么是单点查询,注意题目要求:
> 从中选出一个子序列使得其中的每个单词是后一个单词的子串
1.假设现在有个序列 ABCD ,那么假设我的单词分别为 AB , 和 CD,那么如果我两个同时拿的话就无法满足题义
2.若每次我们都之考虑某一个前缀的最大值,那么递推过来的一定是满足条件的最大值!!
那么,在这颗线段树中,我们要维护的就是每个单词选或不选的最大值,所以在区间更新的时候我们都要取 max,到这里应该解释的差不多了;
代码:
#include<bits/stdc++.h> #define re register int using namespace std; const int N=3e6+10; int T,n,cnt,tot,rt,timi,num=1; char s[N]; int w[N],ed[N],l[N],r[N],fa[N]; int head[N],to[N<<1],next[N<<1]; long long f[N]; long long maxx; bool vis[N]; queue<int> q; struct CUN { int flag,fail; int son[30]; void clean() { flag=0; fail=0; memset(son,0,sizeof(son)); } }use[N]; struct C2 { int lc,rc,sum,lazy; void clean() { lc=0; rc=0; sum=0; lazy=0; } }t[N]; void in() { for(re i=0;i<=max(cnt,tot);i++) use[i].clean(); for(re i=0;i<=max(cnt,tot);i++) t[i].clean(); cnt=0; num=1; maxx=0; tot=0; timi=0; memset(l,0,sizeof(l)); memset(r,0,sizeof(r)); memset(vis,0,sizeof(vis)); memset(f,0,sizeof(f)); memset(w,0,sizeof(w)); memset(ed,0,sizeof(ed)); memset(head,0,sizeof(head)); memset(to,0,sizeof(to)); memset(next,0,sizeof(next)); memset(fa,0,sizeof(fa)); while(!q.empty()) q.pop(); } void insert(char ss[],int pos) { int now=1; int l=strlen(ss); for(re i=0;i<l;i++) { int p=ss[i]-'a'; if(!use[now].son[p]) { use[now].son[p]=++num; fa[use[now].son[p]]=now; } now=use[now].son[p]; } ed[pos]=now; } void Add(int x,int y) { to[++tot]=y; next[tot]=head[x]; head[x]=tot; } void dfs(int x) { if(x) l[x]=++timi; for(re i=head[x];i;i=next[i]) dfs(to[i]); r[x]=timi; } void build(int &p,int L,int R) { p=++cnt; if(L==R) return; int mid=(L+R)>>1; build(t[p].lc,L,mid); build(t[p].rc,mid+1,R); } void get_fail() { for(re i=0;i<26;i++) use[0].son[i]=1; use[1].fail=0; q.push(1); while(!q.empty()) { int u=q.front(); q.pop(); int Fail=use[u].fail; for(re i=0;i<26;i++) { int v=use[u].son[i]; if(!v) { use[u].son[i]=use[Fail].son[i]; continue; } use[v].fail=use[Fail].son[i]; q.push(v); } } for(re i=1;i<=num;i++) Add(use[i].fail,i); dfs(0); build(rt,1,num); } void pd(int p) { if(t[p].lazy==0) return; t[t[p].lc].lazy=max(t[t[p].lc].lazy,t[p].lazy); t[t[p].rc].lazy=max(t[t[p].rc].lazy,t[p].lazy); t[t[p].lc].sum=max(t[t[p].lc].sum,t[p].lazy); t[t[p].rc].sum=max(t[t[p].rc].sum,t[p].lazy); t[p].lazy=0; } void pp(int rt) { t[rt].sum=max(t[t[rt].lc].sum,t[t[rt].rc].sum); } long long query(int rt,int L,int R,int p) { if(L==R) return t[rt].sum; int mid=(L+R)>>1; pd(rt); if(p<=mid) return query(t[rt].lc,L,mid,p); return query(t[rt].rc,mid+1,R,p); } void updata(int p,int L,int R,int l,int r,int z) { if(l<=L&&R<=r) { t[p].sum=max(t[p].sum,z); t[p].lazy=max(t[p].lazy,z); return; } pd(p); int mid=(L+R)>>1; if(mid>=l) updata(t[p].lc,L,mid,l,r,z); if(mid<r) updata(t[p].rc,mid+1,R,l,r,z); pp(p); } void dp() { //f[i]=max{x}+w[i]; for(re i=1;i<=n;i++) { int p=ed[i]; while(p) { f[i]=max(f[i],query(rt,1,num,l[p])); p=fa[p]; } f[i]=max(0*1ll,f[i]+w[i]); updata(rt,1,num,l[ed[i]],r[ed[i]],f[i]); } for(re i=1;i<=n;i++) maxx=max(maxx,f[i]); printf("%lld\n",maxx); } signed main() { scanf("%d",&T); while(T--) { scanf("%d",&n); in(); for(re i=1;i<=n;i++) { scanf("%s",s); scanf("%d",&w[i]); insert(s,i); } get_fail(); dp(); } }