【PKUSC2019】树染色【线段树合并】【树形DP】

Description

给出一棵n个点的树,现在有m种颜色,要给每个节点染色,相邻节点不能同色。
另外有k条限制,形如x号点不能为颜色y
同一节点有可能有多条限制。
求方案数对998244353取模的结果。

n<=200000,m<=1e9,k<=400000

Solution

考场上一直在想怎么容斥做,怎么都弄不出来。
学傻了。

考虑暴力DP
\(f[i][j]\)为当前处理了以i为根的子树,i的颜色为j的方案数。
\(g[i]=\sum\limits_{k}f[i][k]\)

显然有转移$$f[i][j]=[!ban[i][j]]\prod_{p\in son[i]}(g[p]-f[p][j])$$

但是这样的状态数是\(n*m\)的,我们发现只需要记下子树中有的颜色,其他的颜色的答案都是一样的。
这样状态数缩减到\(n*k\),但还是很大,于是我们考虑采用线段树来维护。

转移的时候我们将子树一个个的合并到根
大概是\(f[i][j]=(g[p]-f[p][j])*f[i][j]\)
根据这个我们就可以线段树合并了。
如果只有父亲有,就直接乘
儿子父亲都有暴力合并
只有儿子有的话把括号拆开,就是乘上\(-f[i][j]\)加上\(g[p]*f[i][j]\)
需要维护区间乘区间加,类似一次函数维护即可。
时间复杂度大概是\(O((n+k)\log m)\),具体可以看代码。

Code

写了个对拍没问题,姑且当它是对的吧

#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(inti i=a;i>=b;--i)
#define N 200005
#define M 13000005
#define LL long long
#define mo 998244353
using namespace std;
int n,m,l,fs[N],nt[2*N],dt[2*N],m1;
vector<int> qs[N];
LL ksm(LL k,LL n)
{
	LL s=1;
	for(;n;n>>=1,k=k*k%mo) if(n&1) s=s*k%mo;
	return s;
}
int n1,t[M][2],sz[M],rt[N];
LL sp[M],g[N],f[N],lz[M][2];
void nwp(int &k)
{
	if(!k) k=++n1,lz[k][0]=1,lz[k][1]=0;
}
void ins(int k,int l,int r,int x,int v)
{
	if(l==r) {sp[k]=0,sz[k]=1;return;}
	int mid=(l+r)>>1;
	if(x<=mid) nwp(t[k][0]),ins(t[k][0],l,mid,x,v);
	else nwp(t[k][1]),ins(t[k][1],mid+1,r,x,v);
	sp[k]=(sp[t[k][0]]+sp[t[k][1]]);
	if(sp[k]>=mo) sp[k]-=mo;
	sz[k]=sz[t[k][0]]+sz[t[k][1]];
}
LL gp,fp,fk,vs;
void upd(int k,LL u,LL v)
{
	sp[k]=(u*sp[k]+v*sz[k])%mo;
	lz[k][0]=lz[k][0]*u%mo;
	lz[k][1]=(lz[k][1]*u%mo+v)%mo;
}
void down(int k)
{
	if(lz[k][0]!=1||lz[k][1]!=0)
	{
		if(t[k][0]) upd(t[k][0],lz[k][0],lz[k][1]);
		if(t[k][1]) upd(t[k][1],lz[k][0],lz[k][1]);
		lz[k][0]=1,lz[k][1]=0;	
	}	
}
void mrg(int &k,int x,int l,int r)
{
	if(!k)
	{
		if(!x) return;
		k=x,upd(k,mo-fk,gp*fk%mo);
		return;
	}
	if(!x) {upd(k,(gp-fp+mo)%mo,0);return;}
	if(l==r) {sp[k]=(gp-sp[x]+mo)%mo*sp[k]%mo,sz[k]=sz[k]|sz[x];return;}
	int mid=(l+r)>>1;
	down(k),down(x);
	mrg(t[k][0],t[x][0],l,mid);
	mrg(t[k][1],t[x][1],mid+1,r);
	sp[k]=(sp[t[k][0]]+sp[t[k][1]])%mo;
	sz[k]=sz[t[k][0]]+sz[t[k][1]];
}
void dfs(int k,int fa)
{
	f[k]=1;
	nwp(rt[k]);
	int r=qs[k].size();
	fo(j,0,r-1) ins(rt[k],1,m,qs[k][j],0);
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa) 
		{
			dfs(p,k);
			gp=g[p],fk=f[k],fp=f[p];
			mrg(rt[k],rt[p],1,m);
			f[k]=(g[p]-f[p]+mo)%mo*f[k]%mo;
		}
	}
	g[k]=(f[k]*(LL)(m-sz[rt[k]])%mo+sp[rt[k]])%mo;
}
void link(int x,int y)
{
	nt[++m1]=fs[x];
	dt[fs[x]=m1]=y;
}
int main()
{
	cin>>n>>m>>l;
	fo(i,1,n-1)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		link(x,y),link(y,x);		
	}
	fo(i,1,l)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		qs[x].push_back(y);
	}
	dfs(1,0);
	printf("%lld\n",g[1]);
}
posted @ 2019-05-29 11:56  BAJim_H  阅读(714)  评论(0编辑  收藏  举报