【PKUSC2019】树染色【线段树合并】【树形DP】
Description
给出一棵n个点的树,现在有m种颜色,要给每个节点染色,相邻节点不能同色。
另外有k条限制,形如x号点不能为颜色y
同一节点有可能有多条限制。
求方案数对998244353取模的结果。
n<=200000,m<=1e9,k<=400000
Solution
考场上一直在想怎么容斥做,怎么都弄不出来。
学傻了。
考虑暴力DP
设\(f[i][j]\)为当前处理了以i为根的子树,i的颜色为j的方案数。
记\(g[i]=\sum\limits_{k}f[i][k]\)
显然有转移$$f[i][j]=[!ban[i][j]]\prod_{p\in son[i]}(g[p]-f[p][j])$$
但是这样的状态数是\(n*m\)的,我们发现只需要记下子树中有的颜色,其他的颜色的答案都是一样的。
这样状态数缩减到\(n*k\),但还是很大,于是我们考虑采用线段树来维护。
转移的时候我们将子树一个个的合并到根
大概是\(f[i][j]=(g[p]-f[p][j])*f[i][j]\)
根据这个我们就可以线段树合并了。
如果只有父亲有,就直接乘
儿子父亲都有暴力合并
只有儿子有的话把括号拆开,就是乘上\(-f[i][j]\)加上\(g[p]*f[i][j]\)
需要维护区间乘区间加,类似一次函数维护即可。
时间复杂度大概是\(O((n+k)\log m)\),具体可以看代码。
Code
写了个对拍没问题,姑且当它是对的吧
#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(inti i=a;i>=b;--i)
#define N 200005
#define M 13000005
#define LL long long
#define mo 998244353
using namespace std;
int n,m,l,fs[N],nt[2*N],dt[2*N],m1;
vector<int> qs[N];
LL ksm(LL k,LL n)
{
LL s=1;
for(;n;n>>=1,k=k*k%mo) if(n&1) s=s*k%mo;
return s;
}
int n1,t[M][2],sz[M],rt[N];
LL sp[M],g[N],f[N],lz[M][2];
void nwp(int &k)
{
if(!k) k=++n1,lz[k][0]=1,lz[k][1]=0;
}
void ins(int k,int l,int r,int x,int v)
{
if(l==r) {sp[k]=0,sz[k]=1;return;}
int mid=(l+r)>>1;
if(x<=mid) nwp(t[k][0]),ins(t[k][0],l,mid,x,v);
else nwp(t[k][1]),ins(t[k][1],mid+1,r,x,v);
sp[k]=(sp[t[k][0]]+sp[t[k][1]]);
if(sp[k]>=mo) sp[k]-=mo;
sz[k]=sz[t[k][0]]+sz[t[k][1]];
}
LL gp,fp,fk,vs;
void upd(int k,LL u,LL v)
{
sp[k]=(u*sp[k]+v*sz[k])%mo;
lz[k][0]=lz[k][0]*u%mo;
lz[k][1]=(lz[k][1]*u%mo+v)%mo;
}
void down(int k)
{
if(lz[k][0]!=1||lz[k][1]!=0)
{
if(t[k][0]) upd(t[k][0],lz[k][0],lz[k][1]);
if(t[k][1]) upd(t[k][1],lz[k][0],lz[k][1]);
lz[k][0]=1,lz[k][1]=0;
}
}
void mrg(int &k,int x,int l,int r)
{
if(!k)
{
if(!x) return;
k=x,upd(k,mo-fk,gp*fk%mo);
return;
}
if(!x) {upd(k,(gp-fp+mo)%mo,0);return;}
if(l==r) {sp[k]=(gp-sp[x]+mo)%mo*sp[k]%mo,sz[k]=sz[k]|sz[x];return;}
int mid=(l+r)>>1;
down(k),down(x);
mrg(t[k][0],t[x][0],l,mid);
mrg(t[k][1],t[x][1],mid+1,r);
sp[k]=(sp[t[k][0]]+sp[t[k][1]])%mo;
sz[k]=sz[t[k][0]]+sz[t[k][1]];
}
void dfs(int k,int fa)
{
f[k]=1;
nwp(rt[k]);
int r=qs[k].size();
fo(j,0,r-1) ins(rt[k],1,m,qs[k][j],0);
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa)
{
dfs(p,k);
gp=g[p],fk=f[k],fp=f[p];
mrg(rt[k],rt[p],1,m);
f[k]=(g[p]-f[p]+mo)%mo*f[k]%mo;
}
}
g[k]=(f[k]*(LL)(m-sz[rt[k]])%mo+sp[rt[k]])%mo;
}
void link(int x,int y)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
}
int main()
{
cin>>n>>m>>l;
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
fo(i,1,l)
{
int x,y;
scanf("%d%d",&x,&y);
qs[x].push_back(y);
}
dfs(1,0);
printf("%lld\n",g[1]);
}