[luogu3781]切树游戏

考虑暴力的dp,即用$f_{i,j}$表示以$i$为根的子树内,强制$i$必须选且异或为$j$的方案数,转移用FWT即可,求出该dp数组的时间复杂度为$o(nm\log_{2}m)$

由于是全局的方案数,再记录一个$sum_{i,j}=f_{i,j}+\sum_{son}sum_{son,j}$,那么即求$sum_{1,x}$

令$f'_{i}=FWT(f_{i})$,则有$f'_{i,j}=a_{i,j}\prod_{son}(f'_{son,j}+1)$(其中$a_{i,j}$指点$i$初始的dp数组(即$f_{i,v_{i}}=1$)FWT后的结果,加1是最后对$f_{son,0}$加1,FWT后即对所有位置加1)

根据FWT的分配律,可得$sum'_{i,j}=f'_{i,j}+\sum_{son}sum'_{son,j}$,最后求出$sum_{1}=IFWT(sum'_{1})$即可

这样做单次询问复杂度降为$o(nm)$,但还是无法通过

注意到这样的每一个$j$除了在最后$IFWT$以外,都是独立的,因此考虑求某一个$sum'_{1,j}$,以下就省略数组的第二维(都是$j$)

对其树链剖分,记其重儿子为$hs_{k}$,先统计轻儿子的信息,即:

令$g_{k}=a_{k}\prod_{son\ne hs_{k}}(f'_{son}+1)$那么就有$f'_{k}=g_{k}(f'_{hs_{k}}+1)$

令$h_{k}=\sum_{son\ne hs_{k}}sum'_{son}$,则$sum'_{k}=h_{k}+sum'_{hs_{k}}+f_{k}$

考虑一条重链的维护,构建矩阵$A_{k}=[1\ f'_{k}\ sum'_{k}]$,那么即$A_{k}=A_{hs_{k}}\begin{bmatrix}1& g_{k}&h_{k}+g_{k}\\0&g_{k}&g_{k}\\0&0&1\end{bmatrix}$

根据矩阵乘法的结合律,用线段树维护区间转移矩阵的乘积,再通过将该点直至重链尾部的转移矩阵全部乘起来(初始状态为$[1\ 0\ 0]$),即可求出每一个$k$的$f'_{k}$以及$sum'_{k}$(询问即$k=1$)

对于修改,会改变$k$的转移矩阵,即改变了$A_{top}$(重链顶端),将其求出后再根据轻链的转移修改到$g_{fa_{top}}$和$h_{fa_{top}}$,重复此过程即可,复杂度即为$o(3^{3}q\log^{2}n)$

(特别的,对于$g_{k}$需要存储其轻儿子中0的个数,来支持除法)

事实上,矩阵只需要维护右上角的4个位置(其余位置相乘后不变),复杂度降为$o(2^{2}q\log^{2}n)$,

(另外,矩阵乘法不具备交换律,因此线段树上要右边乘左边)

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 30005
  4 #define M (1<<7)
  5 #define mod 10007
  6 #define L (k<<1)
  7 #define R (L+1)
  8 #define mid (l+r>>1)
  9 struct ji{
 10     int nex,to;
 11 }edge[N<<1];
 12 int E,n,m,x,y,head[N],v[N],fa[N],sz[N],son[N],id[N],top[N],las[N];
 13 char s[11];
 14 int ksm(int n,int m){
 15     int s=n,ans=1;
 16     while (m){
 17         if (m&1)ans=ans*s%mod;
 18         s=s*s%mod;
 19         m>>=1;
 20     }
 21     return ans;
 22 }
 23 void add(int x,int y){
 24     edge[E].nex=head[x];
 25     edge[E].to=y;
 26     head[x]=E++;
 27 }
 28 void dfs1(int k,int f){
 29     fa[k]=f;
 30     sz[k]=1;
 31     for(int i=head[k];i!=-1;i=edge[i].nex)
 32         if (edge[i].to!=f){
 33             dfs1(edge[i].to,k);
 34             sz[k]+=sz[edge[i].to];
 35             if ((!son[k])||(sz[son[k]]<sz[edge[i].to]))son[k]=edge[i].to;
 36         }
 37 }
 38 void dfs2(int k,int fa,int t){
 39     id[k]=++x;
 40     top[k]=t;
 41     if (!son[k])las[k]=k;
 42     else{
 43         dfs2(son[k],k,t);
 44         las[k]=las[son[k]];
 45     }
 46     for(int i=head[k];i!=-1;i=edge[i].nex){
 47         int x=edge[i].to;
 48         if ((x!=fa)&&(x!=son[k]))dfs2(x,k,x);
 49     }
 50 }
 51 struct num{
 52     int t,v;
 53     num operator * (const num &a){
 54         return num{t+a.t,v*a.v%mod};
 55     }
 56     num inv(){
 57         return num{-t,ksm(v,mod-2)};
 58     }
 59     int value(){
 60         if (t)return 0;
 61         return v;
 62     }
 63 };
 64 num turn(int k){
 65     k%=mod;
 66     if (!k)return num{1,1};
 67     return num{0,k};
 68 }
 69 struct mat{
 70     int a,b,c,d;
 71     mat operator * (const mat &k)const{
 72         mat ans;
 73         ans.a=(k.a+a*k.c)%mod;
 74         ans.b=(b+k.b+a*k.d)%mod;
 75         ans.c=c*k.c%mod;
 76         ans.d=(c*k.d+d)%mod;
 77         return ans;
 78     }
 79 }; 
 80 struct Seg{
 81     int h[N];
 82     num g[N];
 83     mat f[N<<2];
 84     void init(){
 85         f[0].c=1;
 86         for(int i=1;i<=n;i++)g[i]=turn(1);
 87     }
 88     void update(int k,int l,int r,int x){
 89         if (l==r){
 90             f[k].a=f[k].c=f[k].d=g[x].value();
 91             f[k].b=(g[x].value()+h[x])%mod;
 92             return;
 93         }
 94         if (x<=mid)update(L,l,mid,x);
 95         else update(R,mid+1,r,x);
 96         f[k]=f[R]*f[L];
 97     }
 98     mat query(int k,int l,int r,int x,int y){
 99         if ((l>y)||(x>r))return f[0];
100         if ((x<=l)&&(r<=y))return f[k];
101         return query(R,mid+1,r,x,y)*query(L,l,mid,x,y);
102     }
103     mat get(int k){
104         return query(1,1,n,id[k],id[las[k]]);
105     }
106     void update(int k,num x,int y){
107         while (k){
108             mat ans=get(top[k]);
109             g[id[k]]=g[id[k]]*x;
110             h[id[k]]+=y;
111             x=turn(ans.a+1).inv(),y=mod-ans.b;
112             update(1,1,n,id[k]);
113             ans=get(top[k]);
114             x=x*turn(ans.a+1),y=(y+ans.b)%mod;
115             k=fa[top[k]];
116         }
117     }
118 }T[M];
119 struct FWT{
120     int a[M];
121     void fwt(int p){
122         for(int i=0;i<7;i++)
123             for(int j=0;j<M;j++)
124                 if (j&(1<<i)){
125                     int x=a[j^(1<<i)],y=a[j];
126                     a[j^(1<<i)]=(x+y)%mod;
127                     a[j]=(x+mod-y)%mod;
128                 }
129         if (p){
130             int s=ksm(M,mod-2);
131             for(int i=0;i<M;i++)a[i]=1LL*a[i]*s%mod;
132         }
133     }
134 }ans;
135 void update(int k,int p){
136     for(int i=0;i<M;i++)ans.a[i]=(i==v[k]);
137     ans.fwt(0);
138     for(int i=0;i<M;i++)
139         if (!p)T[i].update(k,turn(ans.a[i]),0);
140         else T[i].update(k,turn(ans.a[i]).inv(),0);
141 }
142 int main(){
143     scanf("%d%*d",&n);
144     for(int i=1;i<=n;i++)scanf("%d",&v[i]);
145     memset(head,-1,sizeof(head));
146     for(int i=1;i<n;i++){
147         scanf("%d%d",&x,&y);
148         add(x,y);
149         add(y,x);
150     }
151     dfs1(1,0);
152     x=0;
153     dfs2(1,0,1);
154     for(int i=0;i<M;i++)T[i].init();
155     for(int i=1;i<=n;i++)update(i,0);
156     scanf("%d",&m);
157     for(int i=1;i<=m;i++){
158         scanf("%s%d",s,&x);
159         if (s[0]=='Q'){
160             for(int j=0;j<M;j++)ans.a[j]=T[j].get(1).b;
161             ans.fwt(1);
162             printf("%d\n",ans.a[x]);
163         }
164         else{
165             update(x,1);
166             scanf("%d",&v[x]);
167             update(x,0);
168         }
169     }
170 }
View Code

 

posted @ 2021-01-27 09:12  PYWBKTDA  阅读(126)  评论(0编辑  收藏  举报