LOJ 2537 「PKUWC2018」Minimax

BZOJ 5461。

线段树合并优化$dp$。

假设所有离散之后的权值$\in [1, m]$,对于一个点$x$它的权值是$i$的概率是$f(x, i)$,那么

1、假如这个点只有一个儿子$y$,那么$f(x, i) = f(y, i)$。

2、假如这个点有两个儿子$y, z$,那么

$$f(x, i) = f(y, i)\sum_{j = 1}^{m}f(z, j)(p_x[i \geq j] + (1 - p_x)[i \leq j]) + f(z, i)\sum_{j = 1}^{m}f(y, j)(p_x[i \geq j] + (1 - p_x)[i \leq j])$$

看上去$f(y, i) * f(z, i)$这一项被算了两次,但是题目保证了所有权值互不相同,所以似乎并没有问题。

$$f(x, i) = f(y, i)(p_x\sum_{j = 1}^{i}f(z, j) + (1 - p_x)\sum_{j = i}^{m}f(z, j)) + f(z, i)(p_x\sum_{j = 1}^{i}f(y, j) + (1 - p_x)\sum_{j = i}^{m}f(y, j))$$

这样子就把$f(x, i)$表示成了$f(y, i)$乘上$f(z)$的贡献加上$f(z, i)$乘上$f(y)$的贡献,对于一个特定的$x$,$p_x$是一个定值,只要在线段树合并的时候算出向下的贡献累加到结点中即可。

注意到当合并的时候会出现一个结点是空结点的情况,这时候其实相当于整段区间乘上了一段定值,可以采用在线段树上打标记的形式来维护。

以后线段树合并都直接合并到原来的结点上去,不要开新节点,这样子好写而且空间小,注意合并之后原来的权值可能会发生变化,所以开一个变量把原来的值记录下来。

时间复杂度$O(nlogn)$。

Code:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

const int N = 3e5 + 5;
const int M = 1e7 + 5;
const ll P = 998244353LL;

int n, fa[N], son[N][3], m = 0;
ll a[N], buc[N];
bool isLeaf[N];

template <typename T>
inline void read(T &X) {
    X = 0; char ch = 0; T op = 1;
    for (; ch > '9'|| ch < '0'; ch = getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

template <typename T>
inline void inc(T &x, T y) {
    x += y;
    if (x >= P) x -= P;
}

template <typename T>
inline void sub(T &x, T y) {
    x -= y;
    if (x < 0) x += P;
}

inline ll fpow(ll x, ll y) {
    ll res = 1;
    for (x %= P; y > 0; y >>= 1) {
        if (y & 1) res = res * x % P;
        x = x * x % P;
    }
    return res;
}

void dfs(int x) {
    isLeaf[x] = 1;
    for (int i = 1; i <= son[x][0]; i++) {
        isLeaf[x] = 0;
        dfs(son[x][i]);
    }
}

namespace SegT {
    struct Node {
        int lc, rc;
        ll sum, tag;
    } s[M];
    
    #define lc(p) s[p].lc
    #define rc(p) s[p].rc
    #define sum(p) s[p].sum
    #define tag(p) s[p].tag
    #define mid ((l + r) >> 1)
    
    int root[N], nodeCnt = 0;
        
    inline void down(int p) {
        if (tag(p) == 1LL) return;
        if (lc(p)) {
            (sum(lc(p)) *= tag(p)) %= P;
            (tag(lc(p)) *= tag(p)) %= P;        
        }
        if (rc(p)) {
            (sum(rc(p)) *= tag(p)) %= P;
            (tag(rc(p)) *= tag(p)) %= P;            
        }
        tag(p) = 1;
    }
    
    void ins(int &p, int l, int r, int x) {
        p = ++nodeCnt, tag(p) = sum(p) = 1;
        if (l == r) return;
        
        if (x <= mid) ins(lc(p), l, mid, x);
        else ins(rc(p), mid + 1, r, x);
    }
    
    int merge(int u, int v, ll su, ll sv, ll k) {        
        if (!u) {
            (tag(v) *= su) %= P;
            (sum(v) *= su) %= P;
            return v;
        }
        if (!v) {
            (tag(u) *= sv) %= P;
            (sum(u) *= sv) %= P;
            return u;            
        }
                
        down(u), down(v);
        ll k1 = k, k2 = 1, rsu = sum(rc(u)), lsu = sum(lc(u)), rsv = sum(rc(v)), lsv = sum(lc(v));
        sub(k2, k);
        
        lc(u) = merge(lc(u), lc(v), (su + k2 * rsu % P) % P, (sv + k2 * rsv % P) % P, k);
        rc(u) = merge(rc(u), rc(v), (su + k1 * lsu % P) % P, (sv + k1 * lsv % P) % P, k);
        
        sum(u) = 0;
        inc(sum(u), sum(lc(u))), inc(sum(u), sum(rc(u)));
        return u;
    }
    
    ll query(int p, int l, int r, int x) {
        if (l == r) return sum(p);
        
        down(p);
        if (x <= mid) return query(lc(p), l, mid, x);
        else return query(rc(p), mid + 1, r, x); 
    }    
    
    void deb(int rt) {
        for (int i = 1; i <= m; i++)
            printf("%lld%c", query(rt, 1, m, i), i == m ? '\n' : ' ');
    }
    
} using namespace SegT;

void solve(int x) {    
    if (isLeaf[x]) return;
    
    for (int i = 1; i <= son[x][0]; i++) solve(son[x][i]);
    
    if (son[x][0] == 1) root[x] = root[son[x][1]];
    if (son[x][0] == 2) root[x] = SegT :: merge(root[son[x][1]], root[son[x][2]], 0, 0, a[x] % P);
}

int main() {
    #ifndef ONLINE_JUDGE
        freopen("5.in", "r", stdin);
    #endif

    read(n);
    for (int i = 1; i <= n; i++) {
        read(fa[i]);
        if (fa[i] != 0) son[fa[i]][++son[fa[i]][0]] = i;
    }
    dfs(1);
    
    ll inv10000 = fpow(10000LL, P - 2);
    for (int i = 1; i <= n; i++) {
        read(a[i]);
        if (isLeaf[i]) buc[++m] = a[i];
        else a[i] = a[i] * inv10000 % P;
    }
    
    sort(buc + 1, buc + 1 + m);
    for (int i = 1; i <= n; i++) {
        if (!isLeaf[i]) continue;
        a[i] = lower_bound(buc + 1, buc + 1 + m, a[i]) - buc;
        ins(root[i], 1, m, a[i]);
    }
        
    solve(1);
    
    ll ans = 0;
    for (int i = 1; i <= m; i++) {
        ll p = query(root[1], 1, m, i);    
        inc(ans, 1LL * i * p % P * p % P * (buc[i] % P) % P);
    }
    
    printf("%lld\n", ans);
    
/*    for (int i = 1; i <= m; i++) {
        ll p = query(root[1], 1, m, i);
        printf("%lld%c", p, i == m ? '\n' : ' ');
    }   */
    return 0;
}
View Code

 

posted @ 2019-01-15 10:51  CzxingcHen  阅读(190)  评论(0编辑  收藏  举报