LOJ #2269. 「SDOI2017」切树游戏
题目链接
题目大意
给定一棵 \(n\) 个点的树,点有权值 \(0\leq v_i<m\),一个连通块的权值为所有点的权值 \(v_i\) 的异或和,有 \(q\) 次操作,分为两种:
Change x y
将 \(v_x\) 修改为 \(y\) 。Query k
求权值为 \(k\) 的连通子树个数,对 \(10007\) 取模。
\(n,q\leq 3\times 10^4,m\leq128\),修改操作最多有 \(10^4\) 个。
思路
设 \(f(i,j)\) 表示连通块最顶端节点为 \(i\),权值为 \(j\) 的方案数。在树上 DP 时,开始令 \(f(x,v_x)=1\),将儿子 \(y\) 的子树合并到 \(x\):\(f(x,i\oplus j)\leftarrow f(x,i)\cdot (f(y,j)+[j=0])\),\([j=0]\) 是不选取 \(y\) 部分的情况,而权值为 \(k\) 的答案即 \(f(x,k)\) 之和。
注意到转移是异或卷积的形式,设 \(F_u(x)\) 为 \(f(u,*)\) 的集合幂级数,则
其中乘法为异或卷积。定义 \(H_u(x)\) 为 \(u\) 子树内的所有 \(F_v(x)\) 之和,这样答案即 \([x^k]H_1(x)\) 。
注意到转移很简洁,考虑用动态 DP 维护此过程。令 \(\text{hson}(x)\) 为 \(x\) 的重儿子,\(\text{lson}(x)\) 为 \(x\) 的轻儿子集合,处理出
将树重链剖分,先直接计算轻边的转移,\(F'_u(x)\) 和 \(H'_u(x)\) 是只考虑了轻儿子的信息,然后对于每条重链集中转移,到链顶时一次性算出 \(F,H\)。设当前链上的点从浅到深为 \(a_1,a_2,..,a_k\),那么有
这可以写成矩阵乘法的形式(\(u=a_i,v=a_{i+1}\))
于是可以用线段树维护重链上的矩阵乘法,这个运算可以进一步简化:
所以矩阵只需要记录四元组 \((a,b,c,d)\) 即可。对于初始情况,\(u=a_k\) 则 \(v\) 是空节点,即 \((F_v\;\;H_v\;\;1)\) 是 \((0\;\;0\;\;1)\),乘上转移矩阵可以得到 \((F_{a_1}\;\;H_{a_1}\;\;1)=(c\;\;d\;\;1)\) 。
于是节点的权值 \(v_x\) 的修改就是修改了自己对应的转移矩阵,从 \(x\) 到根节点 \(1\) 路径上的值会发生变化,一条重链上的改动就是一个线段树单点修改、全局求乘积操作,跳重链用线段树维护,走轻边时直接更新父亲的 \(F'_u\) 和 \(H_u'\) 即可。是 \(O(\log^2n)\) 次操作。
异或卷积就拿 \(FWT\) 维护,注意到 \(FWT\) 本质做的是线性变换,所以可以直接将集合幂级数做一遍正变换,然后上述的所有操作都是变换下按位独立操作,最后再逆变换回来即可,于是级数的四则运算都是 \(O(m)\) 的。
预处理复杂度是 \(O(nm)\) 的,修改一次复杂度 \(O(m\log^2 n)\),查询是 \(O(1)\) 的,综上时间复杂度为 \(O(nm+qm\log^2 n)\),\(q\) 只算修改的操作 \(\leq 10^4\) 。仔细一算 \(2.87e8\),似乎有点危,考虑到树剖的常数通常很小,官方数据可过,但是在洛谷上就被卡掉了。
要优化复杂度,需要一个叫「全局平衡二叉树」的科技,本质是这样的:树剖是对于每条重链建出一棵二叉树,将它们按照树上结构连起来长这样
节点更新就是沿着这棵树往上走,虚边就是原树的轻边,实边是重链线段树内部的边。可以发现这树不太平衡,树高居然是 \(O(\log^2n)\) 的,那就把它调整成树高 \(O(\log n)\) 的全局平衡二叉树:
也就是对原来的线段树做手脚,一条重链上,每个点有权值 \(w_x=siz_x-siz_{\text{hson(x)}}\),即 \(x\) 以及下面挂着的轻儿子子树大小和,然后对重链建二叉树(线段树)时采取带权分治,每次找到带权的中点,然后两边分治建树。这样这棵树整体上就是平衡的了,走虚边和实边时,节点所代表的那一段的轻子树大小基本都会翻倍(反例是 “线段树” 部分可能会出现分不平均的情况,不过分析下来是不影响复杂度的),于是树高便是 \(O(\log n)\) 的。
时间复杂度遂降为 \(O(nm+qm\log n)\) 。
Code
一个细节,在 \(v\) 更新父亲 \(u\) 的 \(F_u'\) 和 \(H_u'\) 时,要把之前的贡献去掉的,如果之前 \(F_v(x)+1\) 对应位是个 \(0\),会有除零的情况,所以这里数字要维护二元组 \((val,zero)\),将非零部分的值和乘上的零分开来记录。
将集合幂级数的运算循环展开,再加上一些卡常就可以冲到洛谷最优解了,这里贴的是卡常前的代码。
#include<iostream>
#include<cstring>
#include<vector>
#define mem(a,b) memset(a, b, sizeof(a))
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 30030
#define M 128
#define mod 10007
#define inv2 5004
using namespace std;
int n, m, q, v[N];
int head[N], nxt[2*N], to[2*N];
int siz[N], hson[N], lsiz[N];
int cnt;
int inv[mod];
void init(){ mem(head, -1), cnt = -1; }
void add_e(int a, int b, bool id){
nxt[++cnt] = head[a], head[a] = cnt, to[cnt] = b;
if(id) add_e(b, a, 0);
}
void dfs(int x, int fa){
siz[x] = 1;
int mx = 0;
for(int i = head[x], y; ~i; i = nxt[i]) if((y = to[i]) != fa){
dfs(y, x);
if(siz[y] > mx) mx = siz[y], hson[x] = y;
siz[x] += siz[y];
}
lsiz[x] = siz[x] - siz[hson[x]];
}
struct Dat{
int a[M];
void FWT(int op){
for(int i = 1; i < m; i <<= 1)
for(int j = 0; j < m; j += i<<1) rep(k,0,i-1){
int v1 = a[j+k], v2 = a[i+j+k];
a[j+k] = (v1 + v2) % mod, a[i+j+k] = (v1 + mod - v2) % mod;
if(op == -1) (a[j+k] *= inv2) %= mod, (a[i+j+k] *= inv2) %= mod;
}
}
Dat operator * (Dat b){
Dat c; rep(i,0,m-1) c.a[i] = a[i] * b.a[i] % mod;
return c;
}
Dat operator + (Dat b){
Dat c; rep(i,0,m-1) c.a[i] = (a[i] + b.a[i]) % mod;
return c;
}
Dat operator - (Dat b){
Dat c; rep(i,0,m-1) c.a[i] = (a[i] + mod - b.a[i]) % mod;
return c;
}
} pow[M], ans;
struct Globally_Balanced_Binary_Tree{
struct mat{
Dat a, b, c, d;
mat operator * (mat u){
return {a*u.a, a*u.b+b, c*u.a+u.c, c*u.b+d+u.d}; }
} val[2*N];
struct dat{
int v, zero;
dat operator * (int b){ return (b %= mod) ? dat{v*b%mod, zero} : dat{v, zero+1}; }
dat operator / (int b){ return (b %= mod) ? dat{v*inv[b]%mod, zero} : dat{v, zero-1}; }
} F[N][M];
Dat H[N];
int par[2*N], c[2*N][2], cnt;
bool vis[N];
int build(vector<int> vec){
if(vec.size() == 1){
int x = vec[0]; vis[x] = true;
rep(i,0,m-1) F[x][i].v = pow[0].a[i];
for(int i = head[x]; ~i; i = nxt[i]) if(to[i] != hson[x] && !vis[to[i]]){
int y = to[i], z = y;
vector<int> cev;
while(z) cev.push_back(z), z = hson[z];
par[y = build(cev)] = x;
rep(i,0,m-1) F[x][i] = F[x][i] * (val[y].c.a[i] + 1);
H[x] = H[x] + val[y].d;
}
rep(i,0,m-1) val[x].a.a[i] = F[x][i].zero ? 0 : F[x][i].v;
val[x].b = val[x].c = val[x].a = val[x].a * pow[v[x]];
val[x].d = H[x] + val[x].a;
return x;
}
int tot = 0;
for(int x : vec) tot += lsiz[x];
vector<int> lft = {vec[0]}, rgt;
for(int i = 1, cur = lsiz[vec[0]]; i < vec.size(); cur += lsiz[vec[i++]])
(cur*2 < tot && i+1 < vec.size() ? lft : rgt).push_back(vec[i]);
int x = ++cnt, a = build(lft), b = build(rgt);
par[a] = par[b] = x, c[x][0] = a, c[x][1] = b;
val[x] = val[b] * val[a];
return x;
}
void update(int x, int v){
ans.FWT(1);
auto check = [&](int x){
int y = par[x];
if(y >= 1 && y <= n){
rep(i,0,m-1) F[y][i] = F[y][i] / (val[x].c.a[i] + 1);
H[y] = H[y] - val[x].d;
} else if(y == 0) ans = ans - val[x].d;
};
check(x);
::v[x] = v;
while(x){
if(x <= n){
rep(i,0,m-1) val[x].a.a[i] = F[x][i].zero ? 0 : F[x][i].v;
val[x].a = val[x].b = val[x].c = val[x].a * pow[::v[x]];
val[x].d = H[x] + val[x].a;
}
int y = par[x]; check(y);
if(y > n){
val[y] = val[c[y][1]] * val[c[y][0]];
} else if(y){
rep(i,0,m-1) F[y][i] = F[y][i] * (val[x].c.a[i] + 1);
H[y] = H[y] + val[x].d;
} else ans = ans + val[x].d;
x = y;
}
ans.FWT(-1);
}
} T;
void prework(){
inv[0] = inv[1] = 1;
rep(i,2,mod-1) inv[i] = (mod-mod/i) * inv[mod%i] % mod;
rep(i,0,m-1) pow[i].a[i] = 1, pow[i].FWT(1);
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>m;
rep(i,1,n) cin>>v[i];
init(); int u, v;
rep(i,1,n-1) cin>>u>>v, add_e(u, v, 1);
prework();
dfs(1, 0);
vector<int> vec; int cur = 1;
while(cur) vec.push_back(cur), cur = hson[cur];
T.cnt = n;
int root = T.build(vec);
ans = ans + T.val[root].d;
ans.FWT(-1);
string op; int x, y, k;
cin>>q;
while(q--){
cin>>op;
if(op == "Change") cin>>x>>y, T.update(x, y);
if(op == "Query") cin>>k, cout<< ans.a[k] <<endl;
}
return 0;
}