luogu P5305 [GXOI/GZOI2019]旧词

luogu P5305 [GXOI/GZOI2019]旧词

题意

一开始给出一棵以 1 为根的树（依次给出点 $2 \dots n$ 的父亲）和整数 $k$。
一共有若干询问
每次询问给出 $x, y$，求
$$\sum_{i=1}^{x} \mathrm{depth}(\mathrm{lca}(i, y))^k$$
（根的深度为 1，答案对 998244353 取模）。
首先这题就是luogu P4211 [LNOI2014]LCA这题的升级版
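同理，先把询问离线，按照 $x$ 从小到大排序，依次把点 $1 \dots x$ 加入，再考虑如何计算贡献。
关键在于差分：给每个点 $v$ 赋权 $w_v = \mathrm{depth}(v)^k - (\mathrm{depth}(v)-1)^k$，这样从任意点 $u$ 到根的路径上的权值之和恰好是 $\mathrm{depth}(u)^k$。加入点 $i$ 时，把 $i$ 到根路径上每个点的计数加 1；询问 $y$ 时，求 $y$ 到根路径上 $\sum \mathrm{cnt}_v \cdot w_v$，它恰好等于 $\sum_{i \le x} \mathrm{depth}(\mathrm{lca}(i,y))^k$。
所以直接用树链剖分 + 线段树维护 $a \cdot b$ 之和就好了：$a$ 表示每个位置当前被覆盖的次数（区间加），$b$ 表示该位置差分后的深度 $k$ 次方权值（静态，建树时算好）。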
具体还是看代码自行理解一下吧

#include<bits/stdc++.h>
#define int long long
#define N 1000005
#define mod 998244353
using namespace std;
// One offline query: count over vertices 1..x, evaluated at vertex y.
struct QQ{
	int x;   // prefix bound: vertices 1..x are inserted before answering
	int y;   // vertex whose root path is queried
	int id;  // original input position, used to print answers in order
};
QQ q[N];
// Sort predicate: orders queries by increasing prefix bound x so they can be
// answered in a single forward sweep.
int cmp(QQ lhs, QQ rhs){
	return rhs.x > lhs.x;
}
// Adjacency-list edge: target vertex v and the index of the next edge out of
// the same source (-1 ends the list, see init()).
struct edge{
	int v;
	int nxt;
};
edge e[N];
// p[u]: index of the most recently added edge leaving u, or -1 when none.
// eid: number of edges stored so far (also the next free slot in e[]).
int p[N], eid;
// Empties the adjacency list.
void init(){
	eid = 0;
	std::fill(p, p + N, -1);
}
// Adds the directed edge u -> v at the head of u's edge list.
void insert(int u, int v){
	int slot = eid++;
	e[slot].v = v;
	e[slot].nxt = p[u];
	p[u] = slot;
}
// Binary exponentiation: returns x^y reduced modulo `mod`.
int qpow(int x, int y){
	int result = 1;
	int base = x;
	while(y){
		if(y & 1) result = result * base % mod;
		y >>= 1;
		base = base * base % mod;
	}
	return result;
}
// Segment tree over HLD positions (node rt has children rt<<1, rt<<1|1):
//   sum[rt] - sum over the node's range of (cover count) * (weight), mod p
//   sd[rt]  - sum over the node's range of weights dep^K - (dep-1)^K, mod p
//   col[rt] - pending cover count not yet pushed down to the children
// ys[pos] - the vertex placed at HLD position pos (inverse of id[])
int sum[N], sd[N], col[N], ys[N];
// Tree / heavy-light decomposition data:
//   dep  - depth (the root has depth 1)     fa   - parent (fa[1] stays 0)
//   w    - heavy child (0 if none)          size - subtree size
//   id   - HLD position of a vertex         top  - head of its heavy chain
//   tot  - number of positions assigned so far
// n, Q, K - vertex count, query count, exponent; ANS - answers by query id.
// NOTE(review): with `using namespace std` under C++17, an unqualified
// `size` is ambiguous with std::size; refer to this array as ::size.
int dep[N], fa[N], w[N], size[N], id[N], top[N], tot, n, Q, K, ANS[N];
void build(int rt, int l, int r){
	if(l == r){
		sd[rt] = (qpow(dep[ys[l]], K) - qpow(dep[ys[l]] - 1, K) + mod) % mod;//sd表示深度的k次方差分后的和
		return;
	}
	int mid = (l + r) >> 1;
	build(rt << 1, l, mid);
	build(rt << 1 | 1, mid + 1, r);
	sd[rt] = sd[rt << 1] + sd[rt << 1 | 1], sd[rt] %= mod;
}
// Moves node rt's pending cover count to both children. Each child's tag
// grows by that count and its sum grows by count * (child's total weight).
void pushdown(int rt){
	int tag = col[rt];
	if(!tag) return;
	for(int child = rt << 1; child <= (rt << 1 | 1); child ++){
		col[child] += tag;
		sum[child] = (sum[child] + tag * sd[child] % mod) % mod;
	}
	col[rt] = 0;
}
// Covers every HLD position in [L, R] once more, within node rt's range
// [l, r]. Each covered position's contribution grows by its weight, so
// sum[rt] grows by sd[] of the covered part.
void add(int rt, int l, int r, int L, int R){
	if(L <= l && r <= R){
		// Fully covered: stop here and leave a lazy tag. Accumulate (+=),
		// because a tag from an earlier update may still be pending.
		col[rt] ++;
		sum[rt] = (sum[rt] + sd[rt]) % mod;
		return;
	}
	// Partial overlap means this node is not a leaf, so pushing down only
	// here never touches the nonexistent children of a leaf.
	pushdown(rt);
	int mid = (l + r) >> 1;
	if(L <= mid) add(rt << 1, l, mid, L, R);
	if(mid < R) add(rt << 1 | 1, mid + 1, r, L, R);
	sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}
// Returns the sum (mod p) of covered weights over the HLD positions that lie
// in both [L, R] and node rt's range [l, r].
int query(int rt, int l, int r, int L, int R){
	if(L <= l && r <= R) return sum[rt];
	// Push down only when descending, so a leaf never pushes into children
	// that do not exist (that would index past the tree arrays).
	pushdown(rt);
	int mid = (l + r) >> 1, ret = 0;
	if(L <= mid) ret = query(rt << 1, l, mid, L, R);
	if(R > mid) ret = (ret + query(rt << 1 | 1, mid + 1, r, L, R)) % mod;
	return ret;
}
void dfs1(int u){
	size[u] = 1;
	for(int i = p[u]; i + 1; i = e[i].nxt){
		int v = e[i].v;
		dep[v] = dep[u] + 1;
		dfs1(v);
		size[u] += size[v];
		if(size[v] > size[w[u]]) w[u] = v;
	}
}
// Second HLD pass. It numbers vertices in DFS order, visiting the heavy child
// first so each heavy chain gets a contiguous range of positions. ys[] is
// the inverse numbering, and every light child starts a new chain (top = itself).
void dfs2(int u){
	id[u] = ++ tot;
	ys[tot] = u;
	int heavy = w[u];
	if(heavy){
		top[heavy] = top[u];
		dfs2(heavy);
	}
	for(int i = p[u]; i != -1; i = e[i].nxt){
		int v = e[i].v;
		if(v != heavy){
			top[v] = v;
			dfs2(v);
		}
	}
}
// Sums the segment tree along the path from u to the root, one heavy-chain
// segment [id[top[u]], id[u]] at a time. It stops when it steps past the
// root (fa[1] == 0).
int Query(int u){
	int total = 0;
	for(; u; u = fa[top[u]])
		total = (total + query(1, 1, n, id[top[u]], id[u])) % mod;
	return total;
}
// Covers every vertex on the path from u to the root once, chain by chain.
void Insert(int u){
	for(; u != 0; u = fa[top[u]])
		add(1, 1, n, id[top[u]], id[u]);
}
signed main(){
	init();
	scanf("%lld%lld%lld", &n, &Q, &K);
	for(int i = 2; i <= n; i ++) scanf("%lld", &fa[i]), insert(fa[i], i);
	dep[1] = top[1] = 1;
	dfs1(1), dfs2(1);
	build(1, 1, n);
	for(int i = 1; i <= Q; i ++) scanf("%lld%lld", &q[i].x, &q[i].y), q[i].id = i;
	
	sort(q + 1, q + 1 + Q, cmp);
	int pos = 0;
	for(int i = 1; i <= Q; i ++){
		while(pos < q[i].x) Insert(++ pos);
		ANS[q[i].id] = Query(q[i].y);
	}
	for(int i = 1; i <= Q; i ++) printf("%lld\n", ANS[i]);
	return 0;
}

这题主要是码量大

posted @ 2019-09-03 20:42  lahlah  阅读(27)  评论(0)  编辑  收藏  举报