bzoj 4911: [Sdoi2017]切树游戏
题目描述
Solution
考虑暴力DP:设 \(f[x][i]\) 表示 \(x\) 子树内, \(x\) 作为深度最小的点的连通块的数量
\(f[x][i]=f[x][j]*f[u][k]\,j \bigoplus k=i\)
这个过程可以用 \(FWT\) 优化
由于有修改,用链分治动态维护这个DP
按树链剖分的方法,把树分成若干条重链
每一条重链看作一个序列 \(P_L,...P_R\),按照深度从 \(L\) 到 \(R\) 递减的顺序排列,线段树维护
分别记录以下东西:
\(sum[x][i]\) 表示线段树中 \(x\) 所代表的区间的异或和为 \(i\) 的连通块的答案和
\(li[x][i]\) 表示 线段树中 \(x\) 所代表的区间中包含左端点的异或和为 \(i\) 的连通块的答案和
\(ri[x][i]\) 表示 线段树中 \(x\) 所代表的区间中包含右端点的异或和为 \(i\) 的连通块的答案和
\(siz[x][i]\) 表示 线段树中 \(x\) 所代表的区间 \([L,R]\) 这个完整的异或和为 \(i\) 的连通块的答案(也就是每一个位置权值的乘积)
同一条链的转移十分简单,考虑链与链之间的转移:
我们把这一条链直接当作 链顶的父亲 的权值就行了
更新的时候在链上暴力跳就行了
复杂度是 \(log^2\) 的
考虑这个转移是需要 \(FWT\) 优化的,复杂度又多了个 \(log\)
有一种方法优化:
我们 \(FWT\) 时,是先 \(FWT(a,1)\),再做点值多项式乘法,再转回来的过程
我们可以一开始就转好点值多项式,然后运算过程全程用点值多项式的值来代入,中间的运算过程就可以变成普通的点值乘法了
在询问的时候再 \(FWT\) 回来就行了
这样复杂度就是 \(O(n*m*log^2)\) 的了
另外注意:
\(0\) 没有逆元,由于会除以 \(0\),所以要定义一种新运算维护 \(0\) 的个数,重载一下乘除号就行了
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N=30005,M=130,mod=10007;
int n,m,Q,a[N],sz[N],son[N],dep[N],head[N],nxt[N*2],to[N*2],num=0;
int top[N],fa[N],inv[N],E[M][M],lis[N],tt=0,ans[M],re[M];
vector<int>v[N];
inline void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
inline void dfs(int x){
sz[x]=1;
for(int i=head[x];i;i=nxt[i]){
int u=to[i];if(sz[u])continue;
dep[u]=dep[x]+1;fa[u]=x;dfs(u);
sz[x]+=sz[u];if(sz[u]>sz[son[x]])son[x]=u;
}
}
inline void dfs2(int x,int tp){
top[x]=tp;
if(son[x])dfs2(son[x],tp);
for(int i=head[x];i;i=nxt[i])
if(to[i]!=fa[x] && to[i]!=son[x])dfs2(to[i],to[i]);
v[tp].pb(x);
}
inline void fwt(int *A,int o){
for(int i=1;i<m;i<<=1)
for(int j=0;j<m;j+=i<<1)
for(int k=0;k<i;k++){
int x=A[j+k],y=A[j+k+i];
if(!o)A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
else A[j+k]=(x+y)*inv[2]%mod,A[j+k+i]=(x-y+mod)*inv[2]%mod;
}
}
struct data{
int a,b;
inline void biu(int x){x%=mod;if(x)a=x,b=0;else a=1,b=1;}
inline int val(){return b?0:a;}
inline void operator *=(const int x){
if(!x)b++;
else a=a*x%mod;
}
inline void operator /=(const int x){
if(!x)b--;
else a=a*inv[x]%mod;
}
}f[N][M];
void priwork(){
inv[1]=1;
for(int i=2;i<mod;i++)inv[i]=(mod-(mod/i)*inv[mod%i]%mod)%mod;
int len;for(len=1;len<m;len<<=1);m=len;
for(int i=0;i<m;i++)E[i][i]=1,fwt(E[i],0); //预处理出单位矩阵 E
//因为我们是先把 f[i][a[i]]=1 赋为 1 再转点值表达式的,我们预处理出E[i]表示把 i 赋成1时的单位多项式
for(int i=1;i<=n;i++)
for(int j=0;j<m;j++)f[i][j].biu(E[a[i]][j]);
}
inline bool comp(int i,int j){return dep[i]>dep[j];}
int ls[N*4],rs[N*4],rt[N],li[N*4][M],ri[N*4][M];
int ft[N*4],sum[N*4][M],siz[N*4][M],id[N];
inline void upd(int o){
for(int i=0;i<m;i++){
sum[o][i]=(sum[ls[o]][i]+sum[rs[o]][i]+ri[ls[o]][i]*li[rs[o]][i])%mod;
li[o][i]=(li[ls[o]][i]+li[rs[o]][i]*siz[ls[o]][i])%mod;
ri[o][i]=(ri[rs[o]][i]+ri[ls[o]][i]*siz[rs[o]][i])%mod;
siz[o][i]=siz[ls[o]][i]*siz[rs[o]][i]%mod;
}
}
inline void build(int &x,int l,int r,int t){
x=++tt;
if(l==r){
id[v[t][l]]=x;
for(int i=0;i<m;i++)
li[x][i]=ri[x][i]=sum[x][i]=siz[x][i]=f[v[t][l]][i].val();
return ;
}
int mid=(l+r)>>1;
build(ls[x],l,mid,t);build(rs[x],mid+1,r,t);
if(ls[x])ft[ls[x]]=x;if(rs[x])ft[rs[x]]=x;
upd(x);
}
inline void solve(int x){
int t=top[x];
if(fa[t])for(int i=0;i<m;i++)f[fa[t]][i]/=(ri[rt[t]][i]+E[0][i])%mod;
for(int i=0;i<m;i++)ans[i]=(ans[i]-sum[rt[t]][i]+mod)%mod;
int p=id[x];
for(int i=0;i<m;i++)
li[p][i]=ri[p][i]=sum[p][i]=siz[p][i]=f[x][i].val();
for(p=ft[p];p;p=ft[p])upd(p);
if(fa[t])for(int i=0;i<m;i++)f[fa[t]][i]*=(ri[rt[t]][i]+E[0][i])%mod;
for(int i=0;i<m;i++)ans[i]=(ans[i]+sum[rt[t]][i])%mod;
}
int main(){
freopen("pp.in","r",stdin);
freopen("pp.out","w",stdout);
int x,y;char S[8];
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
for(int i=1;i<n;i++)scanf("%d%d",&x,&y),link(x,y),link(y,x);
dep[1]=1;dfs(1);dfs2(1,1);
priwork();
int cnt=0;
for(int i=1;i<=n;i++)if(top[i]==i)lis[++cnt]=i;
sort(lis+1,lis+cnt+1,comp);
for(int i=1;i<=cnt;i++){
x=lis[i];
build(rt[x],0,v[x].size()-1,x);
if(fa[x])
for(int j=0;j<m;j++)f[fa[x]][j]*=(ri[rt[x]][j]+E[0][j])%mod;
for(int j=0;j<m;j++)ans[j]=(ans[j]+sum[rt[x]][j])%mod;
}
cin>>Q;
while(Q--){
scanf("%s%d",S,&x);
if(S[0]=='Q'){
for(int i=0;i<m;i++)re[i]=ans[i];
fwt(re,1);
printf("%d\n",re[x]);
}
else{
scanf("%d",&y);
for(int i=0;i<m;i++)f[x][i]/=E[a[x]][i];
a[x]=y;
for(int i=0;i<m;i++)f[x][i]*=E[a[x]][i];
for(;x;x=fa[top[x]])solve(x);
}
}
return 0;
}