luogu P5298 [PKUWC2018]Minimax
luogu P5298 [PKUWC2018]Minimax
题目大意
不可描述
比较清楚就不讲了
题解
首先注意到题目中的这样一句话
保证这类点中每个结点的权值互不相同
显然线段树合并QWQ
考虑如何合并
先考虑只有一边有节点的情况
然后直接把贡献乘在那个子树上(tag)
假设把 以 x 和 y为根的两颗线段树合并
那发现如果 x, y都不为0就直接先往下做然后再 把概率加起来就行了
还是看代码吧
code:
#include<bits/stdc++.h>
#define int long long
#define mod 998244353
#define N 2000005
using namespace std;
struct A{
int tag, d;
}a[N << 4];
struct AA{
int l, r, val, d;
}son[N << 4];
int P, n, ans, m, tot, b[N], ch[N << 4][2], root[N];
void TAG(int rt, int o){
o %= mod;
a[rt].tag = a[rt].tag * o % mod;
a[rt].d = a[rt].d * o % mod;
}
void pushdown(int rt){
if(a[rt].tag == 1) return;
if(ch[rt][0]) TAG(ch[rt][0], a[rt].tag);
if(ch[rt][1]) TAG(ch[rt][1], a[rt].tag);
a[rt].tag = 1;
}
void insert(int &rt, int l, int r, int x){
if(! rt) rt = ++ tot;
a[rt].d = a[rt].tag = 1;
if(l == r) return;
int mid = (l + r) >> 1;
if(x <= mid) insert(ch[rt][0], l, mid, x);
else insert(ch[rt][1], mid + 1, r, x);
}
int merge(int x, int y, int xl, int xr, int yl, int yr){
xl %= mod, xr %= mod, yl %= mod, yr %= mod;
if(!x && !y) return 0;
if(x && y){
pushdown(x), pushdown(y);
int ll = a[ch[x][0]].d, rr = a[ch[y][0]].d;
ch[x][0] = merge(ch[x][0], ch[y][0], xl, xr + a[ch[x][1]].d, yl, yr + a[ch[y][1]].d);
ch[x][1] = merge(ch[x][1], ch[y][1], xl + ll, xr, yl + rr, yr);
a[x].d = (a[ch[x][0]].d + a[ch[x][1]].d) % mod;
}else{
if(!x) swap(x, y), swap(xl, yl), swap(xr, yr);
TAG(x, yl * P + (mod + 1 - P) * yr);
}
return x;
}
void dfs(int u){
if(son[u].val){
insert(root[u], 1, m, son[u].val);
}else if(son[u].r){
dfs(son[u].l), dfs(son[u].r);
P = son[u].d;
root[u] = merge(root[son[u].l], root[son[u].r], 0ll, 0ll, 0ll, 0ll);
}else{
dfs(son[u].l);
root[u] = root[son[u].l];
}
}
void getans(int rt, int l, int r){
pushdown(rt);
if(l == r){
ans += l * b[l] % mod * a[rt].d % mod * a[rt].d % mod, ans %= mod;
return;
}
int mid = (l + r) >> 1;
if(ch[rt][0]) getans(ch[rt][0], l, mid);
if(ch[rt][1]) getans(ch[rt][1], mid + 1, r);
}
signed main(){
scanf("%lld", &n);
for(int i = 1; i <= n; i ++){
int fa;
scanf("%lld", &fa);
if(!son[fa].l) son[fa].l = i;
else son[fa].r = i;
}
for(int i = 1; i <= n; i ++){
int x;
scanf("%lld", &x);
if(son[i].l) son[i].d = x * 796898467ll % mod;
else {
son[i].val = x;
b[++ m] = x;
}
}
sort(b + 1, b + 1 + m);
for(int i = 1; i <= n; i ++)
if(son[i].val) son[i].val = lower_bound(b + 1, b + 1 + m, son[i].val) - b;
dfs(1);
getans(root[1], 1, m);
printf("%lld", ans);
return 0;
}
好题!
终于把这个史前巨坑填了