luogu P5305 [GXOI/GZOI2019]旧词
题意
一开始给出一棵树和k
一共有若干询问
就是每次给出 $x, y$，求 $\sum_{i=1}^{x} \operatorname{depth}(\operatorname{lca}(i, y))^k$（根的深度为 $1$，答案对 $998244353$ 取模）
首先这题就是luogu P4211 [LNOI2014]LCA这题的升级版
同理,先把问题离线,按照x排序,考虑如何计算贡献
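发现直接用线段树维护 $\sum a \cdot b$ 就好了：$a$ 表示当前该点被覆盖的次数（即已插入的点中有多少个在它的子树里），$b$ 表示该点深度的 $k$ 次方差分后的值，即 $b_v = dep_v^k - (dep_v - 1)^k$。
每插入一个点 $i$，就把 $i$ 到根路径上所有点的 $a$ 加一；查询 $y$ 时求 $y$ 到根路径上的 $\sum a_v b_v$。点 $i$ 只会在 $\operatorname{lca}(i, y)$ 到根这一段上产生贡献，而这段路径上 $b$ 的和为 $\sum_{j=1}^{d}\left(j^k - (j-1)^k\right) = d^k$（$d$ 为 $\operatorname{lca}(i,y)$ 的深度），所以查询结果恰好就是答案，这就是差分的作用。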
具体还是看代码自行理解一下吧
#include<bits/stdc++.h>
#define int long long
#define N 1000005
#define mod 998244353
using namespace std;
// One offline query: prefix bound x, target vertex y, and the query's
// original input position id (used to print answers in input order).
struct QQ {
    int x;
    int y;
    int id;
};
// Queries are stored 1-indexed.
QQ q[N];
// Comparator for std::sort: orders queries by ascending prefix bound x.
int cmp(QQ x, QQ y){
    return y.x > x.x;
}
// Forward-star edge: target vertex v and index of the next edge from the
// same source (-1 terminates the list).
struct edge {
    int v;
    int nxt;
};
edge e[N];
int p[N], eid;
void init(){
memset(p, -1, sizeof p);
eid = 0;
}
void insert(int u, int v){
e[eid].v = v;
e[eid].nxt = p[u];
p[u] = eid ++;
}
// Modular exponentiation by squaring: returns x^y mod `mod` for y >= 0.
// The base is first normalized into [0, mod). This keeps x * x below 2^63
// (`int` is long long via the file's macro) for any input x. It also makes
// negative bases work, since built-in `%` keeps the sign of the dividend.
int qpow(int x, int y){
    x %= mod;
    if(x < 0) x += mod;
    int ret = 1;
    for(; y; y >>= 1, x = x * x % mod)
        if(y & 1) ret = ret * x % mod;
    return ret;
}
// Segment-tree storage, indexed by tree node rt:
//   sum[rt] - sum over the node's range of (cover count * weight), mod `mod`
//   sd[rt]  - sum over the node's range of the weight dep^K - (dep-1)^K
//   col[rt] - pending lazy "+col[rt]" to every cover count in the range
int sum[N];
int sd[N];
int col[N];
// ys[pos]: vertex placed at DFS position pos (inverse of id[]).
int ys[N];
// Per-vertex heavy-light data: dep (root = 1), fa (parent, 0 above the root),
// w (heavy child), size (subtree size), id (DFS position), top (chain head).
// NOTE: with `using namespace std;` under C++17, an unqualified `size` is
// ambiguous with std::size; refer to this array as `::size`.
int dep[N], fa[N], w[N], size[N], id[N], top[N];
// tot: DFS position counter; n, Q, K: input; ANS: answers by query id.
int tot, n, Q, K, ANS[N];
void build(int rt, int l, int r){
if(l == r){
sd[rt] = (qpow(dep[ys[l]], K) - qpow(dep[ys[l]] - 1, K) + mod) % mod;//sd表示深度的k次方差分后的和
return;
}
int mid = (l + r) >> 1;
build(rt << 1, l, mid);
build(rt << 1 | 1, mid + 1, r);
sd[rt] = sd[rt << 1] + sd[rt << 1 | 1], sd[rt] %= mod;
}
// Propagate node rt's pending cover-count increment to both children.
// Each child's sum grows by (increment * child's total weight).
void pushdown(int rt){
    int tag = col[rt];
    if(!tag) return;
    for(int c = rt << 1; c <= (rt << 1 | 1); c ++){
        col[c] += tag;
        sum[c] = (sum[c] + tag * sd[c] % mod) % mod;
    }
    col[rt] = 0;
}
// Increase the cover count of every position in [L, R] by one
// (node rt spans [l, r]).
void add(int rt, int l, int r, int L, int R){
    pushdown(rt);
    const bool covered = L <= l && r <= R;
    if(!covered){
        int mid = (l + r) >> 1, lc = rt << 1, rc = lc | 1;
        if(L <= mid) add(lc, l, mid, L, R);
        if(R > mid) add(rc, mid + 1, r, L, R);
        sum[rt] = (sum[lc] + sum[rc]) % mod;
        return;
    }
    // pushdown just cleared col[rt], so 1 is exactly the pending increment.
    col[rt] = 1;
    sum[rt] = (sum[rt] + sd[rt]) % mod;
}
// Sum of (cover count * weight) over positions [L, R], mod `mod`
// (node rt spans [l, r]).
int query(int rt, int l, int r, int L, int R){
    pushdown(rt);
    if(L <= l && r <= R) return sum[rt];
    int mid = (l + r) >> 1;
    int res = 0;
    if(L <= mid) res = query(rt << 1, l, mid, L, R);
    if(R > mid) res = (res + query(rt << 1 | 1, mid + 1, r, L, R)) % mod;
    return res;
}
void dfs1(int u){
size[u] = 1;
for(int i = p[u]; i + 1; i = e[i].nxt){
int v = e[i].v;
dep[v] = dep[u] + 1;
dfs1(v);
size[u] += size[v];
if(size[v] > size[w[u]]) w[u] = v;
}
}
// Second heavy-light pass: assign DFS positions (heavy child first, so each
// heavy chain is contiguous) and the chain head top[] of every vertex.
void dfs2(int u){
    tot ++;
    id[u] = tot;
    ys[tot] = u;
    int heavy = w[u];
    if(heavy){
        top[heavy] = top[u];
        dfs2(heavy);
    }
    for(int i = p[u]; i != -1; i = e[i].nxt){
        int v = e[i].v;
        if(v != heavy){
            // A light child starts its own chain.
            top[v] = v;
            dfs2(v);
        }
    }
}
// Sum of (cover count * weight) along the path from u up to the root,
// one heavy-chain segment at a time, mod `mod`.
int Query(int u){
    int total = 0;
    for(; u; u = fa[top[u]])
        total = (total + query(1, 1, n, id[top[u]], id[u])) % mod;
    return total;
}
// Add one to the cover count of every vertex on the path from u to the root.
void Insert(int u){
    for(; u; u = fa[top[u]])
        add(1, 1, n, id[top[u]], id[u]);
}
signed main(){
init();
scanf("%lld%lld%lld", &n, &Q, &K);
for(int i = 2; i <= n; i ++) scanf("%lld", &fa[i]), insert(fa[i], i);
dep[1] = top[1] = 1;
dfs1(1), dfs2(1);
build(1, 1, n);
for(int i = 1; i <= Q; i ++) scanf("%lld%lld", &q[i].x, &q[i].y), q[i].id = i;
sort(q + 1, q + 1 + Q, cmp);
int pos = 0;
for(int i = 1; i <= Q; i ++){
while(pos < q[i].x) Insert(++ pos);
ANS[q[i].id] = Query(q[i].y);
}
for(int i = 1; i <= Q; i ++) printf("%lld\n", ANS[i]);
return 0;
}
这题主要是码量大