Loading

【题解】P5305 [GXOI/GZOI2019]旧词

题面很清楚,不概括题意了。

思路

树剖。

\(k = 1\) 的情况是 P4211 [LNOI2014]LCA

具体来说,只需要 \(\forall 1 \leq i \leq x\),将 \(1\)\(i\) 的路径上每一个结点权值都加一,然后询问从 \(1\)\(y\) 的路径权值和即可。

具体证明可以画图分讨。

考虑 \(k > 1\) 的情况。

按照 \(k = 1\) 的情况,只需要考虑将 \(dep^k\) 分摊到去往根的路径即可。

根据我不会证,只需要将深度为 \(x\) 的结点增加 \(x^k - (x - 1)^k\)

这里的增量是固定的,所以只需要用线段树维护它的系数即可。

原本的路径加就是路径上的结点系数都加一。

后面按照套路直接按右端点升序排列,然后处理询问即可。

建议酱紫。

代码

#include <cstdio>
#include <algorithm>
using namespace std;

const int maxn = 5e4 + 5;
const int maxe = 5e4 + 5;
const int maxq = 5e4 + 5;
const int mod = 998244353;

struct t_node
{
    int l, r, sum, lazy, val;
} tree[maxn << 2];

struct e_node
{
    int to, nxt;
} edge[maxe];

struct ques
{
    int p, nd, id;

    bool operator < (const ques& rhs) const { return (p < rhs.p); }
} qry[maxq];

int n, q, pw, cnt;
int ans[maxq];
int fa[maxn], son[maxn], top[maxn];
int head[maxn], sz[maxn], pos[maxn], rk[maxn], dep[maxn];

int qpow(int base, int power)
{
    int res = 1;
    while (power)
    {
        if (power & 1) res = 1ll * res * base % mod;
        base = 1ll * base * base % mod;
        power >>= 1;
    }
    return res;
}

void add_edge(int u, int v)
{
    cnt++;
    edge[cnt].to = v;
    edge[cnt].nxt = head[u];
    head[u] = cnt;
}

void dfs1(int u, int f)
{
    dep[u] = dep[f] + 1;
    sz[u] = 1;
    for (int i = head[u]; i; i = edge[i].nxt)
    {
        int v = edge[i].to;
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

void dfs2(int u, int t)
{
    top[u] = t;
    pos[u] = ++cnt;
    rk[cnt] = u;
    if (son[u]) dfs2(son[u], t);
    for (int i = head[u]; i; i = edge[i].nxt)
    {
        int v = edge[i].to;
        if ((v != fa[u]) && (v != son[u])) dfs2(v, v);
    }
}

void push_up(int k) { tree[k].sum = (tree[k << 1].sum + tree[k << 1 | 1].sum) % mod; }

void push_down(int k)
{
    if (!tree[k].lazy) return;
    tree[k << 1].sum = (tree[k << 1].sum + 1ll * tree[k].lazy * tree[k << 1].val % mod) % mod;
    tree[k << 1 | 1].sum = (tree[k << 1 | 1].sum + 1ll * tree[k].lazy * tree[k << 1 | 1].val % mod) % mod;
    tree[k << 1].lazy += tree[k].lazy, tree[k << 1 | 1].lazy += tree[k].lazy;
    tree[k].lazy = 0;
}

void build(int k, int l, int r)
{
    tree[k].l = l, tree[k].r = r;
    if (l == r)
    {
        tree[k].val = (qpow(dep[rk[l]], pw) - qpow(dep[rk[l]] - 1, pw) + mod) % mod;
        return;
    }
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
    tree[k].val = (tree[k << 1].val + tree[k << 1 | 1].val) % mod;
}

void update(int k, int l, int r, int w)
{
    if ((tree[k].l >= l) && (tree[k].r <= r))
    {
        tree[k].sum = (tree[k].sum + 1ll * w * tree[k].val) % mod;
        tree[k].lazy += w;
        return;
    }
    push_down(k);
    int mid = (tree[k].l + tree[k].r) >> 1;
    if (l <= mid) update(k << 1, l, r, w);
    if (r > mid) update(k << 1 | 1, l, r, w);
    push_up(k);
}

int query(int k, int l, int r)
{
    if ((tree[k].l >= l) && (tree[k].r <= r)) return tree[k].sum;
    push_down(k);
    int mid = (tree[k].l + tree[k].r) >> 1, sum = 0;
    if (l <= mid) sum = (sum + query(k << 1, l, r)) % mod;
    if (r > mid) sum = (sum + query(k << 1 | 1, l, r)) % mod;
    return sum;
}

void modify(int u, int v, int w)
{
    while (top[u] != top[v])
    {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(1, pos[top[u]], pos[u], w);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    update(1, pos[u], pos[v], w);
}

int queryp(int u, int v)
{
    int res = 0;
    while (top[u] != top[v])
    {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        res = (res + query(1, pos[top[u]], pos[u])) % mod;
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    res = (res + query(1, pos[u], pos[v])) % mod;
    return res;
}

int main()
{
    scanf("%d%d%d", &n, &q, &pw);
    for (int i = 2; i <= n; i++)
    {
        scanf("%d", &fa[i]);
        add_edge(fa[i], i);
    }
    cnt = 0;
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);
    for (int i = 1; i <= q; i++)
    {
        qry[i].id = i;
        scanf("%d%d", &qry[i].p, &qry[i].nd);
    }
    sort(qry + 1, qry + q + 1);
    int cur = 1;
    for (int i = 1; i <= n; i++)
    {
        modify(1, i, 1);
        while ((cur <= q) && (qry[cur].p == i))
        {
            ans[qry[cur].id] = queryp(1, qry[cur].nd);
            cur++;
        }
    }
    for (int i = 1; i <= q; i++) printf("%d\n", (ans[i] % mod + mod) % mod);
    return 0;
}
posted @ 2023-01-04 21:47  kymru  阅读(24)  评论(0编辑  收藏  举报