SDOI2017 切树游戏

给定一棵树，每个点有点权，进行多次操作：

1. 单点修改：修改一个点的点权；

2. 询问：有多少个非空连通块（连通子图）的点权异或和为 $k$，答案对 $10007$ 取模。

$n \leq 30000$，$q \leq 30000$，所有点权和询问的 $k$ 都在 $[0, m)$ 内，其中 $m \leq 128$。

sol:

动态 dp

为了防止自己忘记，这里再把思路写一遍。

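记 $F_x$ 为以 $x$ 为最高点的所有连通块的生成函数（乘法为异或卷积），则 $F_x = z^{v_x} \prod_{c \in \mathrm{son}(x)} (1 + F_c)$，答案为 $\sum_x F_x$。把轻儿子的部分单独记为 $f_x = z^{v_x} \prod_{c \in \text{轻儿子}} (1 + F_c)$，于是 $F_x = f_x \, (1 + F_{\text{重儿子}})$。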

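对于自上而下为 $x_1, x_2, \dots, x_k$ 的一条重链，展开得 $F_{x_1} = f_{x_1} + f_{x_1} f_{x_2} + \dots + f_{x_1} f_{x_2} \cdots f_{x_k}$（前缀积之和），整条链对答案的贡献 $\sum_j F_{x_j}$ 则是所有连续区间积之和，因此一条重链可以一起转移。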

用线段树维护重链上的转移：每个节点记录区间积、前缀积之和、后缀积之和以及所有子区间积之和；每次修改时沿重链向上跳，最多经过 $O(\log n)$ 条重链。

这题还需要先做 FWT，把异或卷积变成逐点乘法；修改时要用除法撤销轻儿子原来的贡献，而模 $10007$ 意义下因子可能为 $0$，所以还要写一个模 mod 的剩余类，单独维护乘了多少个 $0$。

#include <bits/stdc++.h>
#define LL long long
// Inclusive loops: rep counts s..t upward, dwn counts s..t downward.
// The bound t is evaluated once into i##end. The `register` specifier is
// deliberately absent: it never affected codegen and was removed in C++17.
#define rep(i, s, t) for (int i = (s), i##end = (t); i <= i##end; ++i)
#define dwn(i, s, t) for (int i = (s), i##end = (t); i >= i##end; --i)
using namespace std;
// Reads the next decimal integer from stdin.
// Every non-digit character before the first digit is skipped; each '-'
// among them flips the sign. The character right after the number is
// consumed. Returns 0 if input ends before any digit is seen.
inline int read() {
    int x = 0, f = 1;
    // int, not char: getchar() must be able to return EOF distinctly, and
    // isdigit() is only defined for EOF and unsigned-char values.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-')
            f = -f;
        ch = getchar();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at EOF
        x = 10 * x + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Answers are reported modulo this prime, so every nonzero residue is invertible.
const int mod = 10007;
// Capacity for vertices and for segment-tree nodes (one shared bound).
const int maxn = 65010;
// Capacity for the (power-of-two padded) transform length.
const int maxk = 256;

int n;            // number of vertices
int m;            // transform length, padded up to a power of two in main
int inv[maxn];    // inv[i]: modular inverse of i, filled for 1..mod-1 in main

int val[maxn];        // val[u]: current weight of vertex u
int e[maxk][maxk];    // e[v]: transformed unit vector at position v (built in main)
int ans[maxk];        // running sum of every chain's h, in the transform domain
int tmp[maxk];        // scratch copy of ans for the inverse transform at query time
// A residue modulo `mod` that can be multiplied by 0 and later divided by 0
// again without losing information. Zero factors are counted in `tms`
// instead of being folded into `val`, so `val` always holds the product of
// the nonzero factors. Mint(0) therefore starts as "one zero factor".
struct Mint {
    int val;  // product of all nonzero factors, reduced modulo mod
    int tms;  // how many zero factors are currently multiplied in

    Mint(int cur = 0) : val(cur ? cur : 1), tms(cur ? 0 : 1) {}

    // Multiply by b (expected to be in [0, mod)).
    friend Mint operator*(Mint a, int b) {
        if (b != 0) {
            a.val = a.val * b % mod;
        } else {
            ++a.tms;
        }
        return a;
    }

    // Divide by b, undoing an earlier multiplication by the same b.
    friend Mint operator/(Mint a, int b) {
        if (b != 0) {
            a.val = a.val * inv[b] % mod;
        } else {
            --a.tms;
        }
        return a;
    }

    // The ordinary residue: 0 while any zero factor remains, else val.
    int real() { return tms == 0 ? val : 0; }
} f[maxn][maxk];  // f[x][i]: weight factor of x times (F(c)+1) over x's light children c
// Modular addition for operands already in [0, mod).
inline int inc(int x, int y) {
    const int s = x + y;
    return s >= mod ? s - mod : s;
}
// Modular subtraction for operands already in [0, mod).
inline int dec(int x, int y) {
    const int d = x - y;
    return d < 0 ? d + mod : d;
}
void fwt(int *a, int n, int f) {
    for (register int i = 1; i < n; i <<= 1)
        for (register int j = 0; j < n; j += (i << 1)) rep(k, 0, i - 1) {
                int x = a[j + k], y = a[j + k + i];
                if (f == 1)
                    a[j + k] = inc(x, y), a[j + k + i] = dec(x, y);
                else
                    a[j + k] = inc(x, y) * inv[2] % mod, a[j + k + i] = dec(x, y) * inv[2] % mod;
            }
}
vector<int> G[maxn];   // adjacency lists of the tree
vector<int> ch[maxn];  // ch[t]: vertices of the heavy chain whose top is t, top first
int fa[maxn];          // parent (0 for the root)
int dep[maxn];         // depth (the root keeps 0)
int bl[maxn];          // bl[u]: top vertex of u's heavy chain
int mxs[maxn];         // heavy child (0 when there is none)
// NOTE(review): with `using namespace std` under C++17, an unqualified use of
// `size` is ambiguous with std::size; refer to it as `::size` (or rename it).
int size[maxn];        // subtree size
int stk[maxn];         // stk[1..top]: top vertices of all heavy chains
int top;
inline void dfs1(int x) {
    size[x] = 1;
    for (int i=0;i<G[x].size();i++) {
        int to = G[x][i];
        if (to == fa[x])
            continue;
        dep[to] = dep[x] + 1;
        fa[to] = x;
        dfs1(to);
        if (size[to] > size[mxs[x]])
            mxs[x] = to;
        size[x] += size[to];
    }
}
// Second pass of heavy-light decomposition: x lies on the chain whose top is col.
// Records bl[x], appends x to ch[col] (so each chain is listed top to bottom)
// and pushes every chain top onto stk. It then follows the heavy child first
// and starts a fresh chain at each light child.
inline void dfs2(int x, int col) {
    bl[x] = col;
    ch[col].push_back(x);
    if (col == x)
        stk[++top] = x;  // x is the top of its own chain
    const int heavy = mxs[x];
    if (heavy != 0) {
        dfs2(heavy, col);
        for (size_t k = 0; k < G[x].size(); ++k) {
            const int to = G[x][k];
            if (to != fa[x] && to != heavy)
                dfs2(to, to);
        }
    }
}
// Sort order for chain tops: deeper vertices come first.
int cmp(int i, int j) { return dep[j] < dep[i]; }
// Segment trees over the heavy chains, all drawing nodes from one pool (ids 1..ToT).
// A chain of length len uses 2*len-1 nodes, so ToT < 2n. Node ids were already
// bounded by maxn through h/lval/rval/sum below, so the link arrays use the same
// bound instead of maxn << 6 (about 64 MB of never-used memory).
int root[maxn];  // root[t]: root node of the chain whose top vertex is t
int ls[maxn];    // left child of a node (0 for leaves)
int rs[maxn];    // right child of a node (0 for leaves)
int anc[maxn];   // parent node (0 for a chain's root)
int pos[maxn];   // pos[u]: leaf node holding vertex u
int ToT;         // number of nodes allocated so far
// For a node covering chain factors a_1..a_k (top to bottom), per transform index i:
int h[maxn][maxk];     // sum of products over all contiguous sub-segments
int lval[maxn][maxk];  // sum of prefix products a_1, a_1*a_2, ...
int rval[maxn][maxk];  // sum of suffix products a_k, a_(k-1)*a_k, ...
int sum[maxn][maxk];   // product of all factors
void pushup(int x) {
    rep(i, 0, m - 1) {
        h[x][i] = (h[ls[x]][i] + h[rs[x]][i] + rval[ls[x]][i] * lval[rs[x]][i]) % mod;
        lval[x][i] = (lval[ls[x]][i] + sum[ls[x]][i] * lval[rs[x]][i]) % mod;
        rval[x][i] = (rval[rs[x]][i] + sum[rs[x]][i] * rval[ls[x]][i]) % mod;
        sum[x][i] = sum[ls[x]][i] * sum[rs[x]][i] % mod;
    }
}
inline void build(int &x, int l, int r, int ps) {
    x = ++ToT;
    if (l == r) {
        rep(i, 0, m - 1) h[x][i] = lval[x][i] = rval[x][i] = sum[x][i] = f[ch[ps][l - 1]][i].real();
        pos[ch[ps][l - 1]] = x;
        return;
    }
    int mid = (l + r) >> 1;
    build(ls[x], l, mid, ps);
    anc[ls[x]] = x;
    build(rs[x], mid + 1, r, ps);
    anc[rs[x]] = x;
    // anc[ls[x]] = anc[rs[x]] = x;
    pushup(x);
}
inline void modify(int u) {
    int x = pos[u], tp = bl[u];
    if (fa[tp])
        rep(i, 0, m - 1) f[fa[tp]][i] = f[fa[tp]][i] / ((lval[root[tp]][i] + e[0][i]) % mod);
    rep(i, 0, m - 1) ans[i] = (ans[i] - h[root[tp]][i] + mod) % mod;
    rep(i, 0, m - 1) sum[x][i] = lval[x][i] = rval[x][i] = h[x][i] = f[u][i].real();
    for (x = anc[x]; x; x = anc[x]) pushup(x);

    if (fa[tp])
        rep(i, 0, m - 1) f[fa[tp]][i] = f[fa[tp]][i] * ((lval[root[tp]][i] + e[0][i]) % mod);
    rep(i, 0, m - 1) ans[i] = (ans[i] + h[root[tp]][i]) % mod;
    // rep(i, 0, m-1) cout << ans[i] << " ";
    // cout << endl;
}
// Reads the tree, builds the per-chain segment trees bottom-up, then answers
// "Change x y" (set weight of x to y) and "Query k" (number of connected
// blocks with XOR k, mod 10007) commands.
int main() {
    n = read(), m = read();
    // Pad the transform length to the smallest power of two >= m. A `<=` test
    // here would double an m that is already a power of two, doubling all work.
    // NOTE(review): assumes every weight (initial or changed) lies in [0, m).
    int w = 1;
    while (w < m)
        w <<= 1;
    m = w;
    rep(i, 1, n) val[i] = read();
    // Modular inverses of 1..mod-1 (mod is prime): inv[i] = -(mod / i) * inv[mod % i].
    inv[1] = 1;
    rep(i, 2, mod - 1) inv[i] = (mod - (mod / i)) * inv[mod % i] % mod;
    rep(i, 2, n) {
        int u = read(), v = read();
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(1);
    dfs2(1, 1);
    // e[v]: transform of the unit vector at v.
    rep(i, 0, m - 1) {
        e[i][i] = 1;
        fwt(e[i], m, 1);
    }
    // Every vertex starts with just its own weight; light children are folded in below.
    rep(i, 1, n) rep(j, 0, m - 1) f[i][j] = Mint(e[val[i]][j]);
    // Build chains deepest-top first, so every light child's chain is finished
    // (and multiplied into f[parent]) before the parent's chain is built.
    sort(stk + 1, stk + top + 1, cmp);
    rep(i, 1, top) {
        int now = stk[i];
        build(root[now], 1, ch[now].size(), now);
        rep(j, 0, m - 1) ans[j] = (ans[j] + h[root[now]][j]) % mod;
        if (fa[now])
            rep(j, 0, m - 1) f[fa[now]][j] = f[fa[now]][j] * ((lval[root[now]][j] + e[0][j]) % mod);
    }
    int q = read();
    char opt[10];
    while (q--) {
        // Width limit keeps the token inside opt[1..9]; stop if the input runs out
        // instead of acting on a stale command.
        if (scanf("%8s", opt + 1) != 1)
            break;
        if (opt[1] == 'C') {
            int x = read(), y = read();
            // Swap x's weight factor, then walk up the chains fixing each one.
            rep(i, 0, m - 1) f[x][i] = f[x][i] / e[val[x]][i];
            val[x] = y;
            rep(i, 0, m - 1) f[x][i] = f[x][i] * e[val[x]][i];
            for (; x; x = fa[bl[x]]) modify(x);
        } else {
            int x = read();
            rep(i, 0, m - 1) tmp[i] = ans[i];
            fwt(tmp, m, -1);
            // XORs of weights below the padded power-of-two m stay below m, so
            // any k outside [0, m) has answer 0 (and must not index tmp).
            printf("%d\n", (x >= 0 && x < m) ? tmp[x] : 0);
        }
    }
    return 0;
}
View Code

 

posted @ 2019-03-11 09:31  探险家Mr.H  阅读(253)  评论(0编辑  收藏  举报