线段树合并
以前一直以为像 treap 的 merge 一样的, 从来没写过
突然发现处理树上问题的复杂度不会证
概述
就是很朴素的递归合并
// x 和 y 是当前合并的两个节点号
int merge(int x, int y, int l, int r)
{
if (!x || !y) return x + y;
if (l == r)
{
<DO SOMETHING>
return x;
}
ch[x][0] = merge(ch[x][0], ch[y][0], l, mid);
ch[x][1] = merge(ch[x][1], ch[y][1], mid + 1, r);
update(x);
return x;
}
分析这段代码, 容易得到, 一次合并两棵线段树的复杂度即为这个函数经过的节点数, 也就是两棵线段树重合的节点数
那么有这样一个问题:
树上有点权, 你需要给每个点开一个线段树(动态开点), 来维护其子数内点权的信息
于是需要不断把儿子的线段树合并到自己, 那么这样的复杂度如何呢?
每个点都会加一条链, 所以这样只加入会有 \(n \log n\) 个点
然后每次合并两个重合的点, 那么其中一个之后都不回去访问他了, 相当于删去了这个点, 复杂度即为访问一次的复杂度 \(O(1)\)
显然最后不会删完所有加入的点, 所以复杂度是 \(O(n \log n)\)
CF600E Lomsat gelral
题意:
一棵树有颜色, 1 为根, 求出每个子数内数量最多的(可能并列最多)颜色的和(比如颜色 2, 3 最多, 和为 2 + 3)
Sol:
权值线段树统计区间中最多出现的次数, 以及所有出现这个次数的和
namespace SEG
{
const int MAXLG = 18, MAXT = MAXN * MAXLG;
int tot, root[MAXN];
LL sum[MAXT];
int cnt[MAXT], mx[MAXT], ch[MAXT][2];
#define mid ((l + r) >> 1)
#define ls (ch[now][0])
#define rs (ch[now][1])
void update(int now)
{
if (mx[ls] == mx[rs]) sum[now] = sum[ls] + sum[rs], mx[now] = mx[ls];
else if (mx[ls] < mx[rs]) sum[now] = sum[rs], mx[now] = mx[rs];
else sum[now] = sum[ls], mx[now] = mx[ls];
}
void insert(int & now, int l, int r, int c)
{
if (!now) now = ++ tot;
if (l == r)
{
sum[now] = c;
cnt[now] += 1;
mx[now] = cnt[now];
return ;
}
if (c <= mid) insert(ls, l, mid, c);
else insert(rs, mid + 1, r, c);
update(now);
}
int merge(int x, int y, int l, int r)
{
if (!x || !y) return x + y;
if (l == r)
{
sum[x] = sum[y];
cnt[x] += cnt[y];
mx[x] = cnt[x];
return x;
}
ch[x][0] = merge(ch[x][0], ch[y][0], l, mid);
ch[x][1] = merge(ch[x][1], ch[y][1], mid + 1, r);
update(x);
return x;
}
void merge(int u, int v)
{
root[u] = merge(root[u], root[v], 1, n);
}
LL getmax(int u)
{
return sum[root[u]];
}
#undef mid
#undef ls
#undef rs
}
LL ans[MAXN];
void DFS(int u, int fa)
{
SEG::insert(SEG::root[u], 1, n, c[u]);
for (int i = head[u]; i; i = adj[i].nex)
{
int v = adj[i].to;
if (v == fa) continue;
DFS(v, u);
SEG::merge(u, v);
}
ans[u] = SEG::getmax(u);
}
int main()
{
n = in();
for (int i = 1; i <= n; ++ i) c[i] = in();
for (int i = 1; i < n; ++ i) link(in(), in());
DFS(1, 0);
for (int i = 1; i <= n; ++ i) printf("%lld ", ans[i]);
return 0;
}
BZOJ5461[PKUWC2018]Minimax
这道题很适合用来加深理解, 我就是因为这题才来做线段是合并
题意:
二叉树, 1为根, 叶子有点权且各不相同, 非叶子节点有概率 \(p_i\), 选择他儿子中较大哪个, 否则较小哪个
将根可能的取值从小到大标号为 \(V_0, \dots, V_m\), 概率分别为 \(D_i\) 求
Sol:
一开始当做了推式子的题, 还漏看了了许多条件
理一下条件, 容易发现对于一个节点 \(u\):
- 一个儿子, 那么 \(100\%\) 取他的儿子
- 否则, 对于可能出现在左儿子的一个值 \(x\) (右儿子没有), \(D[u][x] = D[l][x] \cdot p_u \cdot D[r][<x] + D[l][x] \cdot (1 - p_u) \cdot D[r][>x]\)
如果枚举正确的话, 复杂度相当于子数合并, \(O(n^2)\), 40 分
康康线段树合并的优化:
权值线段树记录一段下标的概率和
合并要求只访问重合的点, 于是发现, 从上往下, 给左右儿子分别记录上述的前缀和和后缀和,
然后到达不重合的那一点, 那么他的子数内如上的前缀和和后缀和都不会再改变了, 因为权值各不相同, 此时另一个儿子能做贡献的都搞定了, 所以此时打上标记, 正确更新这个节点的所有下标,
然后一路上推回去就完成了
复杂度就是合并的复杂度, \(O(n \log n)\)
const LL MOD = 998244353;
LL p[MAXN];
void add(LL & x, LL y)
{
x += y;
x %= MOD;
}
void mul(LL & x, LL y)
{
x *= y;
x %= MOD;
}
namespace SEG
{
const int MAXLG = 20, MAXT = MAXN * MAXLG;
int tot, root[MAXN];
LL sum[MAXT], tag[MAXT];
int ch[MAXT][2];
#define mid ((l + r) >> 1)
#define ls (ch[now][0])
#define rs (ch[now][1])
void update(int now)
{
sum[now] = (sum[ls] + sum[rs]) % MOD;
}
void addtag(int now, LL v)
{
if (!now) return ;
mul(sum[now], v);
mul(tag[now], v);
}
void pushdown(int now)
{
if (!now || tag[now] == 1) return ;
addtag(ls, tag[now]); addtag(rs, tag[now]);
tag[now] = 1;
}
void insert(int & now, int l, int r, int x, int v)
{
now = ++ tot;
tag[now] = 1;
if (l == r)
{
sum[now] = v;
return ;
}
if (x <= mid) insert(ls, l, mid, x, v);
else insert(rs, mid + 1, r, x, v);
update(now);
}
int merge(int x, int y, LL p, LL xl, LL xg, LL yl, LL yg)
{
pushdown(x); pushdown(y);
if (!x || !y)
{
if (x) addtag(x, (p * yl % MOD + (1 + MOD - p) * yg % MOD) % MOD);
else if (y) addtag(y, (p * xl % MOD + (1 + MOD - p) * xg % MOD) % MOD);
return x + y;
}
LL sx1 = sum[ch[x][1]], sx0 = sum[ch[x][0]], sy1 = sum[ch[y][1]], sy0 = sum[ch[y][0]];
ch[x][0] = merge(ch[x][0], ch[y][0], p, xl, (xg + sx1) % MOD, yl, (yg + sy1) % MOD);
ch[x][1] = merge(ch[x][1], ch[y][1], p, (xl + sx0) % MOD, xg, (yl + sy0) % MOD, yg);
update(x);
return x;
}
LL query(int now, int l, int r, int x)
{
if (l == r) return sum[now];
pushdown(now);
if (x <= mid) return query(ls, l, mid, x);
else return query(rs, mid + 1, r, x);
}
#undef mid
#undef ls
#undef rs
}
bool hvson[MAXN];
int tot;
LL buc[MAXN], val[MAXN];
void init()
{
sort(buc + 1, buc + tot + 1);
for (int i = 1; i <= n; ++ i)
if (val[i]) val[i] = lower_bound(buc + 1, buc + tot + 1, val[i]) - buc;
}
int fa[MAXN];
int son[MAXN][2];
void DFS(int u)
{
for (int i = head[u]; i; i = adj[i].nex)
{
int v = adj[i].to;
DFS(v);
if (!son[u][0]) son[u][0] = v;
else son[u][1] = v;
}
if (!son[u][0]) SEG::insert(SEG::root[u], 1, tot, val[u], 1);
else if (!son[u][1]) SEG::root[u] = SEG::root[son[u][0]];
else SEG::root[u] = SEG::merge(SEG::root[son[u][0]], SEG::root[son[u][1]], p[u], 0, 0, 0, 0);
}
int main()
{
n = in();
for (int i = 1; i <= n; ++ i)
{
hvson[fa[i] = in()] = true;
addedge(fa[i], i);
}
LL i10000 = 796898467;
for (int i = 1; i <= n; ++ i)
{
int x = in();
if (!hvson[i]) val[i] = buc[++ tot] = x;
else p[i] = 1ll * x * i10000 % MOD;
}
init();
DFS(1);
LL ans = 0;
for (int i = 1; i <= tot; ++ i)
{
LL d = SEG::query(SEG::root[1], 1, tot, i);
add(ans, 1ll * i * buc[i] % MOD * d % MOD * d % MOD);
}
printf("%lld\n", ans);
return 0;
}