[GXOI/GZOI2019]旧词

相关链接(雾:[LNOI2014]LCA

实际上这题就是加了一个幂。。原题爆破比赛

原题:当\(k=1\)

暴力:求LCA再求深度。。然后观察一下求LCA的方法。。最暴力的方法就是把根节点到节点\(i\)的路径上的点都打上标记,然后由节点\(y\)往上,直到一个有标记的点为止。。

然而这题并不需要求LCA的序号,只需要深度,于是对于每个询问中\([1,x]\)的每个节点\(i\),把上面打标记换成权值\(+1\),把扫到一个有标记的点为止变为统计根节点到节点\(y\)的路径上节点的权值和。。

然后我们发现每个询问中中的节点可以一起做,于是这就变成了树链剖分的模版。。

\(k\not=1\)

先观察\(k=1\)时,点的权值其实是\((dep_i)^1-(dep_i-1)^1=1\)

于是,当\(k\not=1\)时,点权就是\((dep_i)^k-(dep_i-1)^k=1\)。当然,处理这个的时候要预处理一下。。其他的和\(k=1\)时一样。


#include <bits/stdc++.h>
#define lson (n << 1)
#define rson (n << 1 | 1)
#define next ___________________________________________________________________________________________________
using namespace std;

const int MAXN = 50004,MOD = 998244353;
int n, q, K, f[MAXN], x, y, value[MAXN], size[MAXN], h, head[MAXN], dep[MAXN], v[MAXN], cnt, t[MAXN << 2], k[MAXN << 2], delta[MAXN << 2], id[MAXN], nid[MAXN], top[MAXN], son[MAXN], ans[MAXN];

struct edge {
    int to, next;
} g[MAXN << 1];

struct data {
    int x, y, id;
} s[MAXN];

inline void addedge(int x, int y) {g[++h].next = head[x],head[x] = h,g[h].to = y;}

inline int qpow(int a, int b) {
    int s = 1;
    while (b) {
        if (b & 1) s = 1LL * s * a % MOD;
        a = 1LL * a * a % MOD,b >>= 1;
    }
    return s;
}

inline void pushdown(int n) {
    int x = delta[n];
    delta[n] = 0,
    t[lson] = (t[lson] + 1LL * x * k[lson]) % MOD,t[rson] = (t[rson] + 1LL * x * k[rson]) % MOD,
    delta[lson] = (delta[lson] + x) % MOD,delta[rson] = (delta[rson] + x) % MOD;
}

inline void dfs1(int x) {
    int j, maxs = -1;
    size[x] = 1;
    for (int i = head[x]; i; i = g[i].next) {
        j = g[i].to;
        if (dep[j]) continue;
        dep[j] = dep[x] + 1,
        dfs1(j),
        size[x] += size[j];
        if (size[j] >= maxs) son[x] = j, maxs = size[j];
    }
}

inline void dfs2(int x, int ttop) {
    top[x] = ttop,id[x] = ++cnt,nid[cnt] = x;
    if (!son[x]) return;
    dfs2(son[x], ttop);
    int j;
    for (int i = head[x]; i; i = g[i].next) {
        j = g[i].to;
        if (j == son[x] || j == f[x]) continue;
        dfs2(j, j);
    }
}

inline bool cmp(data a, data b) { return a.x < b.x; }

inline int add(int x, int y) {
    x += y;
    if (x >= MOD) x -= MOD;
    return x;
}

inline void build(int n, int l, int r) {
    if (l == r) return (void)(k[n] = v[nid[l]],t[n] = delta[n] = 0);
    int mid = (l + r) >> 1;
    build(lson, l, mid),build(rson, mid + 1, r),
    t[n] = add(t[lson], t[rson]),k[n] = add(k[lson], k[rson]);
}

inline int query(int n, int l, int r, int ll, int rr) {
    if (ll <= l && r <= rr) return t[n];
    int mid = (l + r) >> 1, ans = 0;
    if (delta[n]) pushdown(n);
    if (ll <= mid) ans = query(lson, l, mid, ll, rr);
    if (rr > mid) ans = add(ans, query(rson, mid + 1, r, ll, rr));
    return ans;
}

inline void update(int n, int l, int r, int ll, int rr) {
    if (ll <= l && r <= rr) return (void)(t[n] = add(t[n], k[n]),delta[n] = add(delta[n], 1));
    int mid = (l + r) >> 1;
    if (delta[n]) pushdown(n);
    if (ll <= mid) update(lson, l, mid, ll, rr);
    if (rr > mid) update(rson, mid + 1, r, ll, rr);
    t[n] = add(t[lson], t[rson]);
}

inline void change(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) swap(x, y);
        update(1, 1, n, id[top[y]], id[y]),
        y = f[top[y]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    update(1, 1, n, id[x], id[y]);
}

int solve(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) swap(x, y);
        ans = add(ans, query(1, 1, n, id[top[y]], id[y])),
        y = f[top[y]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    ans = add(ans, query(1, 1, n, id[x], id[y]));
    return ans;
}

int main() {
    scanf("%d%d%d", &n, &q, &K);
    for (int i = 2; i <= n; ++i)
        scanf("%d", &f[i]),
        addedge(f[i], i);
    dep[1] = 1,
    dfs1(1),dfs2(1, 1);
    for (int i = 1; i <= n; ++i) value[i] = qpow(dep[i], K);
    for (int i = 1; i <= n; ++i) v[i] = (value[i] - value[f[i]] + MOD) % MOD;
    for (int i = 1; i <= q; ++i)
        scanf("%d%d", &s[i].x, &s[i].y),
        s[i].id = i;
    sort(s + 1, s + 1 + q, cmp),
    build(1, 1, n);
    int l = 1;
    for (int i = 1; i <= q; ++i) {
        while (l <= s[i].x && l <= n)
			change(1, l),
			l++;
        ans[s[i].id] = solve(1, s[i].y);
    }
    for (int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
    return 0;
}
posted @ 2019-05-26 13:28  蒟蒻SLS  阅读(175)  评论(0编辑  收藏  举报