【BZOJ2905】背单词 fail树+DFS序+线段树

【BZOJ2905】背单词

Description

给定一张包含N个单词的表,每个单词有个价值W。要求从中选出一个子序列使得其中的每个单词是后一个单词的子串,最大化子序列中W的和。 

Input

第一行一个整数TEST,表示数据组数。 
接下来TEST组数据,每组数据第一行为一个整数N。 
接下来N行,每行为一个字符串和一个整数W。 

Output

TEST行,每行一个整数,表示W的和的最大值。 
数据规模 
设字符串的总长度为Len 
30.的数据满足,TEST≤5,N≤500,Len≤10^4 
100.的数据满足,TEST≤10,N≤20000,Len≤3*10^5

题解:先建出AC自动机,然后A串是B串的子串当且仅当B中某个节点沿着fail树往根走,能走到A的结束节点。那么我们先将权值<=0的串都扔掉,然后从前往后枚举每个字符串,对于每个串,我们查询它的每个节点到根路径上的所有节点的DP值的最大值,然后用最大值+当前串的价值得到当前点的DP值,最后将当前串的DP值存到当前串的结束节点位置上。

查询最大值的时候可以将链查询,点修改变成点查询,子树修改,然后用线段树维护即可。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
#define lson x<<1
#define rson x<<1|1
using namespace std;
const int maxn=20010;
const int maxm=300010;
typedef long long ll;
struct node
{
	int ch[26],fail;
}p[maxm];
queue<int> q;
int T,n,tot,cnt;
ll ans,sum;
int v[maxn],to[maxm],next[maxm],head[maxm],p1[maxm],p2[maxm];
char str[maxm];
vector<int> pos[maxn];
ll s[maxm<<2],tag[maxm<<2];
inline void build()
{
	register int i,u;
	q.push(1);
	while(!q.empty())
	{
		u=q.front(),q.pop();
		for(i=0;i<26;i++)
		{
			if(!p[u].ch[i])
			{
				if(u==1)	p[u].ch[i]=1;
				else	p[u].ch[i]=p[p[u].fail].ch[i];
				continue;
			}
			q.push(p[u].ch[i]);
			if(u==1)	p[p[u].ch[i]].fail=1;
			else	p[p[u].ch[i]].fail=p[p[u].fail].ch[i];
		}
	}
}
inline void add(int a,int b)
{
	to[cnt]=b,next[cnt]=head[a],head[a]=cnt++;
}
void dfs(int x)
{
	p1[x]=++p2[0];
	for(int i=head[x];i!=-1;i=next[i])	dfs(to[i]);
	p2[x]=p2[0];
}
inline void pushdown(int x)
{
	if(tag[x])
	{
		s[lson]=max(s[lson],tag[x]),s[rson]=max(s[rson],tag[x]);
		tag[lson]=max(tag[lson],tag[x]),tag[rson]=max(tag[rson],tag[x]);
		tag[x]=0;
	}
}
void updata(int l,int r,int x,int a,int b,ll c)
{
	if(a<=l&&r<=b)
	{
		s[x]=max(s[x],c),tag[x]=max(tag[x],c);
		return ;
	}
	pushdown(x);
	int mid=(l+r)>>1;
	if(a<=mid)	updata(l,mid,lson,a,b,c);
	if(b>mid)	updata(mid+1,r,rson,a,b,c);
	s[x]=max(s[lson],s[rson]);
}
ll query(int l,int r,int x,int a)
{
	if(l==r)	return s[x];
	pushdown(x);
	int mid=(l+r)>>1;
	if(a<=mid)	return query(l,mid,lson,a);
	return query(mid+1,r,rson,a);
}
inline void work()
{
	register int i,j,a,b,u;
	memset(s,0,sizeof(s[0])*4*(tot+1)),memset(tag,0,sizeof(tag[0])*4*(tot+1)),memset(p,0,sizeof(p[0])*(tot+1));
	scanf("%d",&n);
	tot=1,cnt=p2[0]=0,ans=0;
	for(i=1;i<=n;i++)
	{
		scanf("%s%d",str,&v[i]),a=strlen(str);
		if(v[i]<=0)	continue;
		pos[i].clear();
		for(u=1,j=0;j<a;j++)
		{
			b=str[j]-'a';
			if(!p[u].ch[b])	p[u].ch[b]=++tot;
			u=p[u].ch[b],pos[i].push_back(u);
		}
	}
	build();
	memset(head,-1,sizeof(head[0])*(tot+1));
	for(i=2;i<=tot;i++)	add(p[i].fail,i);
	dfs(1);
	for(i=1;i<=n;i++)	if(v[i]>0)
	{
		for(a=pos[i].size(),sum=0,j=0;j<a;j++)	sum=max(sum,query(1,tot,1,p1[pos[i][j]]));
		sum+=v[i],ans=max(ans,sum);
		updata(1,tot,1,p1[pos[i][a-1]],p2[pos[i][a-1]],sum);
	}
	printf("%lld\n",ans);
}
int main()
{
	//freopen("data.in","r",stdin);
	//freopen("data.out","w",stdout);
	scanf("%d",&T);
	while(T--)	work();
	return 0;
}//1 5 a 1 ab 1 ac 4 abc 2 aa 1
posted @ 2017-11-10 08:55  CQzhangyu  阅读(586)  评论(0编辑  收藏  举报