【GMOJ5363】生命之树

题目

题目链接:https://gmoj.net/senior/#main/show/5363

思路

这个异或很烦,二进制拆位搞掉。
然后两个节点能造成贡献当且仅当他们这一位下为 \(1\)。维护两棵 Trie,存二进制下这一位为 \(0/1\) 的所有字符串。
对于一个节点 \(x\),先计算它所有子树答案,然后发现需要将子树中所有节点的字符串扔到 Trie 中。暴力扔显然是不可以的,dsu on tree 搞一下即可。
计算答案的时候就在 Tire 的路劲上求一下即可。
时间复杂度 \(O(n\log n\log |S|)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=500010,LG=18;
int n,t,tot,head[N],a[N],pos[N],size[N],son[N];
ll ans[N],ans2[N];
char s[N],ss[N];

struct edge
{
	int next,to;
}e[N];

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

struct Trie
{
	int tot,c[N][27],size[N];
	
	void ins(int k)
	{
		int p=1;
		size[p]++;
		for (int i=pos[k];i<pos[k+1];i++)
		{
			if (!c[p][s[i]-'a'+1]) c[p][s[i]-'a'+1]=++tot;
			p=c[p][s[i]-'a'+1];
			size[p]++;
		}
	}
	
	ll query(int k)
	{
		int p=1;
		ll sum=0;
		for (int i=pos[k];i<pos[k+1];i++)
		{
			p=c[p][s[i]-'a'+1];
			sum+=size[p];
		}
		return sum;
	}
	
	void clr(int x)
	{
		for (int i=1;i<=26;i++)
			if (c[x][i]) clr(c[x][i]),c[x][i]=0;
		size[x]=0;
	}
}trie[2];

void dfs3(int x,int fa,int rt,int val)
{
	int id=((a[x]&val)!=0);
	ans[rt]+=trie[id^1].query(x)*val;
	trie[id].ins(x);
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa) dfs3(v,x,rt,val);
	}
}

void dfs1(int x,int fa)
{
	size[x]=pos[x+1]-pos[x];
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa)
		{
			dfs1(v,x);
			size[x]+=size[v];
			if (size[v]>size[son[x]]) son[x]=v;
		}
	}
}

void dfs2(int x,int fa,int val,bool flag)
{
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa && v!=son[x]) dfs2(v,x,val,0);
	}
	if (son[x])
	{
		dfs2(son[x],x,val,1);
		ans[x]+=ans[son[x]];
	}
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa && v!=son[x]) dfs3(v,x,x,val);
	}
	int id=((a[x]&val)!=0);
	ans[x]+=trie[id^1].query(x)*val;
	trie[id].ins(x);
	if (!flag)
	{
		trie[0].clr(1); trie[0].tot=1;
		trie[1].clr(1); trie[1].tot=1;
	}
}

int main()
{
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	pos[1]=1;
	for (int i=1;i<=n;i++)
	{
		scanf("%s",ss+1);
		int len=strlen(ss+1);
		for (int j=pos[i];j<pos[i]+len;j++)
			s[j]=ss[j-pos[i]+1];
		pos[i+1]=pos[i]+len;
	}
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	dfs1(1,0);
	trie[0].tot=trie[1].tot=1;
	for (int i=0;i<=LG;i++)
	{
		dfs2(1,0,(1<<i),0);
		for (int j=1;j<=n;j++)
			ans2[j]+=ans[j],ans[j]=0;
	}
	for (int i=1;i<=n;i++)
		printf("%lld\n",ans2[i]);
	return 0;
}
posted @ 2020-11-02 15:38  stoorz  阅读(69)  评论(0编辑  收藏  举报