[WC2024] 线段树
[WC2024] 线段树
考虑如何写指数暴力,即给定 \(S\) 判定是否合法,发现我们可以这么做:对线段树上每个结点,如果在 \(S\) 内则从 \(L\) 向 \(R\) 连一条边,表示若一直 \(s_L\) 则可以知道 \(s_R\),其中 \(s\) 为前缀和。那么可以确定询问区间 \([L, R)\) 当且仅当建出来的图中 \(L\) 和 \(R\) 连通。
在线段树上我们有性质,\([L, R)\) 区间中若左右端点不连通,则必存在 \(j\) 使得与 \(L\) 连通的点都在 \(j\) 左侧,与 \(R\) 连通的点都在 \(j\) 右侧。我们称 \(j\) 为这个区间的分割点。
考虑一个线段树结点表示的区间 \([L,R)\),其中有若干不交极大子区间 \([l,r)\) 使得其中的结点不与 \(L\) 和 \(R\) 连通,进而这些区间不和区间外的结点连通。那么一个必要条件是对于这些子区间中任意一个 \(I\),都不存在询问区间一个端点在 \(I\) 内,另一个端点在 \(I\) 外。
我们发现这些子区间具有很好的性质。对于 区间 \([L, R)\) 的儿子 \([L, m)\) 和 \([m, R)\),区间 \([L, R)\) 的子区间集合就是左右儿子集合的并再并上左儿子的分割点到右儿子的分割点这个新增的区间。
考虑从下到上根据分割点 dp,过程中考虑当前结点的子区间集合,只在合法时新增子区间,这样新区间由于左右部分尚未处理询问的点都与中间点 \(m\) 连通,所以询问都被处理完了。具体地,设 \(f_{u, j}\) 表示结点 \(u\) 中 \(j\) 为分割点的方案数,\(g_u\) 表示左右端点连通的方案数,有三类转移:
- \(2g_lg_r \rightarrow g_u\)。
- \(g_l f_{r, i} \rightarrow g_u\), \(g_l f_{r, i} \rightarrow f_{u, i}\),交换 \(l\),\(r\) 同理。
- \(f_{l, i} f_{r, j} \rightarrow g_u\), \(f_{l, i} f_{r, j} \rightarrow f_{u, i}\),表示新增一个子区间。
现在考虑如何快速判定一个子区间合法,即判断是否所有询问两个端点或者同时在区间内,或者同时在区间外,这件事情可以直接对于每个询问把询问的区间染色,然后只需判断子区间的两个端点的颜色集合是否相同,可以用哈希做,记哈希值为 \(h\)。
注意到在上述 dp 中我们记录分割点的实际作用就是看处理第三类贡献时 \(i\) 和 \(j\) 的 \(h\) 是否相同,所以我们不妨直接把第二维换成 \(h\),这样第三类贡献变成了形如这样:
同时前两类贡献形式不变,这样就可以直接做到 \(O(n^2)\),并且可以直接上线段树合并优化,做到 \(O(n \log n)\)。
const int N = 4e5 + 10;;
int n, m, t;
int a[N];
ull hh[N];
int h[N];
mt19937_64 rnd(94012);
int L[N], R[N], lc[N], rc[N], tot = 1, leaf = 0;
void build(int u, int l, int r) {
L[u] = l, R[u] = r;
if(a[u]) {
if(a[u] - l > 1) lc[u] = ++tot;
else lc[u] = n + leaf, ++leaf;
build(lc[u], l, a[u]);
if(r - a[u] > 1) rc[u] = ++tot;
else rc[u] = n + leaf, ++leaf;
build(rc[u], a[u], r);
}
}
struct SegmentTree {
#define mid (l + r) / 2
int tot;
struct node {
int sum, lc, rc, tag;
}tr[N * 32];
void update(int k) {
add(tr[k].sum = tr[tr[k].lc].sum, tr[tr[k].rc].sum);
}
void addtag(int k, int v) {
tr[k].sum = 1ll * tr[k].sum * v % P;
tr[k].tag = 1ll * tr[k].tag * v % P;
}
void pushdown(int k) {
if(tr[k].tag != 1) {
addtag(tr[k].lc, tr[k].tag);
addtag(tr[k].rc, tr[k].tag);
tr[k].tag = 1;
}
}
void set(int &k, int l, int r, int pos) {
if(l > pos || r < pos) return ;
if(!k) k = ++tot, tr[k].tag = 1;
if(l == r) {
tr[k].sum = 1;
return ;
}
pushdown(k);
if(pos <= mid) set(tr[k].lc, l, mid, pos);
else set(tr[k].rc, mid + 1, r, pos);
update(k);
}
int query(int x) {
return tr[x].sum;
}
int query(int k, int l, int r, int pos) {
if(l > pos || r < pos || !k) return 0;
if(l == r) return tr[k].sum;
pushdown(k);
if(pos <= mid) return query(tr[k].lc, l, mid, pos);
else return query(tr[k].rc, mid + 1, r, pos);
}
int merge(int x, int y, int l, int r, int vl, int vr) {
if(!x) {
addtag(y, vl);
return y;
}
if(!y) {
addtag(x, vr);
return x;
}
if(l == r) {
tr[x].sum = (1ll * tr[x].sum * vr + 1ll * tr[y].sum * vl + 1ll * tr[x].sum * tr[y].sum) % P;
return x;
}
pushdown(x), pushdown(y);
tr[x].lc = merge(tr[x].lc, tr[y].lc, l, mid, vl, vr);
tr[x].rc = merge(tr[x].rc, tr[y].rc, mid + 1, r, vl, vr);
update(x);
return x;
}
#undef mid
}seg;
int rt[N], g[N];
void dfs(int u) {
if(!a[u]) {
g[u] = 1;
seg.set(rt[u], 1, t, h[L[u]]);
return ;
}
dfs(lc[u]), dfs(rc[u]);
g[u] = 2ll * g[lc[u]] * g[rc[u]] % P;
rt[u] = seg.merge(rt[lc[u]], rt[rc[u]], 1, t, g[lc[u]], g[rc[u]]);
add(g[u], seg.query(rt[u]));
// g[u] = 2ll * g[lc[u]] * g[rc[u]] % P;
// for(int i = L[u]; i < a[u]; ++i) {
// int w = 1ll * f[lc[u]][i] * g[rc[u]] % P;
// add(g[u], w), add(f[u][i], w);
// }
// for(int i = a[u]; i < R[u]; ++i) {
// int w = 1ll * g[lc[u]] * f[rc[u]][i] % P;
// add(g[u], w), add(f[u][i], w);
// }
// for(int i = L[u]; i < a[u]; ++i)
// for(int j = a[u]; j < R[u]; ++j)
// if(h[i] == h[j]) {
// int w = 1ll * f[lc[u]][i] * f[rc[u]][j] % P;
// add(g[u], w), add(f[u][i], w);
// }
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n >> m;
for(int i = 1; i < n; ++i)
cin >> a[i];
for(int i = 1; i <= m; ++i) {
int l, r;
cin >> l >> r;
ull v = rnd();
hh[l] ^= v, hh[r] ^= v;
}
vector<ull> vec;
for(int i = 1; i <= n; ++i)
hh[i] ^= hh[i - 1];
for(int i = 0; i <= n; ++i)
vec.eb(hh[i]);
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
t = sz(vec);
for(int i = 0; i <= n; ++i)
h[i] = lower_bound(vec.begin(), vec.end(), hh[i]) - vec.begin() + 1;
build(1, 0, n);
dfs(1);
int ans = g[1];
if(vec.front() == 0) add(ans, seg.query(rt[1], 1, t, 1));
cout << ans << endl;
}
注释部分为 \(O(n^3)\) 代码