【BZOJ4911】[SDOI2017]切树游戏(动态dp,FWT)

【BZOJ4911】[SDOI2017]切树游戏(动态dp,FWT)

题面

BZOJ
洛谷
LOJ

题解

首先考虑如何暴力\(dp\),设\(f[i][S]\)表示当前以\(i\)节点为根节点,联通子树权值和为\(S\)的方案数,转移就是\(FWT\)的卷积,最后只需要把所有的\(f[i][k]\)全部加起来就可以得到最终的答案。
于是这样子的复杂度就是\(O(Qnmlogm)\)。但实际上转移的时候不需要\(FWT\)回来,直接拿点值表示的数组做就可以了,这样子可以少一个\(log\)
那么我们我们额外设一个变量\(S_u\)表示其子树内所有的\(f[u]\)的和。
令矩阵的每个元素都是一个长度为\(m\)的向量,向量的乘法就是每一位对应乘,加法就是每一位对应加,\(0,1\)分别表示全\(0\)、全\(1\)的向量。那么可以得到转移:

\[\begin{bmatrix}f'_u&0&f'_u\\f'_u&1&f'_u+S'_u\\0&0&1\end{bmatrix}\times \begin{bmatrix}f_v\\S_v\\1\end{bmatrix}=\begin{bmatrix}f_u\\S_u\\1\end{bmatrix} \]

其中\(f'_u,S'_u\)表示只考虑轻儿子的转移的结果,或者说只有重儿子没有转移的结果。
这样子单次矩乘的复杂度似乎是\(27m\),但是发现很多位置是\(0\),所以可以手动算一算结果:

\[\begin{bmatrix}a_1&0&b_1\\c_1&1&d_1\\0&0&1\end{bmatrix}\times\begin{bmatrix}a_2&0&b_2\\c_2&1&d_2\\0&0&1\end{bmatrix}=\begin{bmatrix}a_1a_2&0&a_1b_2+b_1\\c_1a_2+c_2&1&c_1b_2+d_2+d_1\\0&0&1\end{bmatrix} \]

这样子就只需要维护\(4\)个地方的值了,那么常数就大大的减少了。
然后我因为全部都是\(operator\),导致常数空间很大,所以就得开\(short\ int\)
本机、洛谷、LOJ都能过,\(BZOJ\ TLE\)

#include<iostream>
#include<cstdio>
using namespace std;
#define MOD 10007
#define inv2 5004
#define MAX 30030
inline int read()
{
	int x=0;bool t=false;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=true,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return t?-x:x;
}
int n,m,Q,V[MAX],inv[MOD];char ch[10];
struct Number{short a,z;short v(){return z?0:a;}};
Number operator+(Number a,Number b){return (Number){(a.v()+b.v())%MOD,0};}
Number operator-(Number a,Number b){return (Number){(a.v()-b.v()+MOD)%MOD,0};};
Number operator*(Number a,Number b)
{
	int x=b.v();
	if(x)a.a=a.a*x%MOD;
	else a.z+=1;
	return a;
}
Number operator*(Number a,int b){return a*(Number){b,0};}
Number operator/(Number a,Number b)
{
	int x=b.v();
	if(x)a.a=a.a*inv[x]%MOD;
	else a.z-=1;
	return a;
}
Number operator/(Number a,int b){return a/(Number){b,0};}
struct Array{Number s[128];}f[MAX],S[MAX],pre[129];
Array operator*(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]*b.s[i];return a;}
Array operator+(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]+b.s[i];return a;}
Array operator-(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]-b.s[i];return a;}
Array operator/(Array a,Array b){for(int i=0;i<m;++i)a.s[i]=a.s[i]/b.s[i];return a;}
void FWT(Array &a,int opt)
{
	for(int i=1;i<m;i<<=1)
		for(int p=i<<1,j=0;j<m;j+=p)
			for(int k=0;k<i;++k)
			{
				Number x=a.s[j+k],y=a.s[i+j+k];
				a.s[j+k]=x+y,a.s[i+j+k]=x-y;
				if(opt==-1)a.s[j+k]=a.s[j+k]*inv2,a.s[i+j+k]=a.s[i+j+k]*inv2;
			}
}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
struct Matrix{Array a,b,c,d;}t[MAX<<2],tmp[MAX];
Matrix operator*(Matrix a,Matrix b){return (Matrix){a.a*b.a,a.a*b.b+a.b,a.c*b.a+b.c,a.d+b.d+a.c*b.b};}
int fa[MAX],dfn[MAX],tim,hson[MAX],size[MAX],top[MAX],bot[MAX],ln[MAX];
void dfs1(int u,int ff)
{
	fa[u]=ff;size[u]=1;
	for(int i=h[u];i;i=e[i].next)
	{
		int v=e[i].v;if(v==ff)continue;
		dfs1(v,u);size[u]+=size[v];
		if(size[v]>size[hson[u]])hson[u]=v;
	}
}
void dfs2(int u,int tp)
{
	top[u]=tp;dfn[u]=++tim,ln[tim]=u;
	if(hson[u])dfs2(hson[u],tp),bot[u]=bot[hson[u]];
	else bot[u]=u;
	for(int i=h[u];i;i=e[i].next)
		if(e[i].v!=fa[u]&&e[i].v!=hson[u])
			dfs2(e[i].v,e[i].v);
}
void dp(int u,int ff)
{
	f[u]=pre[V[u]];
	for(int i=h[u];i;i=e[i].next)
	{
		int v=e[i].v;if(v==ff)continue;dp(v,u);
		f[u]=f[u]*(f[v]+pre[0]);S[u]=S[u]+S[v];
	}
	S[u]=S[u]+f[u];
}
#define lson (now<<1)
#define rson (now<<1|1)
void Build(int now,int l,int r)
{
	if(l==r)
	{
		int u=ln[l];Array f0=pre[V[u]],s0=pre[m];
		for(int i=h[u];i;i=e[i].next)
			if(e[i].v!=hson[u]&&e[i].v!=fa[u])
				f0=f0*(f[e[i].v]+pre[0]),s0=s0+S[e[i].v];
		tmp[l]=t[now]=(Matrix){f0,f0,f0,s0+f0};
		return;
	}
	int mid=(l+r)>>1;
	Build(lson,l,mid);Build(rson,mid+1,r);
	t[now]=t[rson]*t[lson];
}
void Modify(int now,int l,int r,int p)
{
	if(l==r){t[now]=tmp[l];return;}
	int mid=(l+r)>>1;
	if(p<=mid)Modify(lson,l,mid,p);
	else Modify(rson,mid+1,r,p);
	t[now]=t[rson]*t[lson];
}
Matrix Query(int now,int l,int r,int L,int R)
{
	if(L==l&&r==R)return t[now];
	int mid=(l+r)>>1;
	if(R<=mid)return Query(lson,l,mid,L,R);
	if(L>mid)return Query(rson,mid+1,r,L,R);
	return Query(rson,mid+1,r,mid+1,R)*Query(lson,l,mid,L,mid);
}
Matrix GetTop(int x){return Query(1,1,n,dfn[top[x]],dfn[bot[x]]);}
void Modify(int u,int y)
{
	Array f0=tmp[dfn[u]].a,s0=tmp[dfn[u]].d;s0=s0-f0;
	f0=f0/pre[V[u]];f0=f0*pre[y];V[u]=y;
	tmp[dfn[u]]=(Matrix){f0,f0,f0,s0+f0};
	while(u)
	{
		Matrix a=GetTop(u);
		Modify(1,1,n,dfn[u]);
		Matrix b=GetTop(u);
		u=fa[top[u]];if(!u)break;int x=dfn[u];
		f0=tmp[x].a;s0=tmp[x].d;s0=s0-f0;
		f0=f0/(a.c+pre[0]);f0=f0*(b.c+pre[0]);s0=s0-a.d;s0=s0+b.d;
		tmp[x]=(Matrix){f0,f0,f0,s0+f0};
	}
}
int main()
{
	n=read(),m=read();
	for(int i=1;i<=n;++i)V[i]=read();
	for(int i=1,u,v;i<n;++i)u=read(),v=read(),Add(u,v),Add(v,u);
	inv[0]=inv[1]=1;for(int i=2;i<MOD;++i)inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
	for(int i=0;i<m;++i)pre[i].s[i]=(Number){1,0},FWT(pre[i],1);
	dfs1(1,0);dfs2(1,1);dp(1,0);Build(1,1,n);
	Q=read();
	while(Q--)
	{
		scanf("%s",ch);
		if(ch[0]=='Q')
		{
			int k=read();
			Array ans=GetTop(1).d;
			for(int i=0;i<m;++i)ans.s[i].a=ans.s[i].v();
			FWT(ans,-1);
			printf("%d\n",ans.s[k].v());
		}
		else
		{
			int x=read(),y=read();
			Modify(x,y);
		}
	}
	return 0;
}
posted @ 2019-03-21 11:48  小蒟蒻yyb  阅读(731)  评论(0编辑  收藏  举报