CF1254D Tree Queries
CF1254D Tree Queries
好题一道。至少让我一种新的套路(bush)。
首先我们来考虑操作 1,我们发现当我们选好一个点 \(r\) 后,\(r\) 所在的 \(u\) 的一颗子树中的点总是能不经过 \(u\) 到达 \(r\),所以一个 \(r\) 能对一个点 \(v\) 产生贡献,当且仅当它不在 \(v\) 所在的 \(u\) 的子树中。
很容易得到,每次修改后,对于每一颗子树,其中所有点的期望都会增加 \(\frac{n-size_v}{n}\);而对于子树 \(u\) 以外的部分,都会增加 \(\frac{size_u}{n}\);对于点 \(u\),则会增加 \(d\)。那我们可以暂时忽略这个分母 \(n\),最后再乘上去,这样问题就转化成了如何维护这个增加的 \(n-size[v]\) 即可。
首先我们可以直接树剖暴力搞,拿线段树之类的维护区间加单点查。然鹅,这样的做法菊花图复杂度直接起飞(因为每次都要把 \(u\) 的所有子树搞掉 xwx )。
第一个优化就是我们可以考虑搞树上差分,这样就可以免除区间修改;但是这样还是无法做到快速修改。然后我就贺翻题解学了个新套路:只对重儿子进行修改,然后每次查询把链顶贡献加上即可。
具体怎么干捏?就是,假如我们现在进行操作 \(1\),节点为 \(u\),权值为 \(d\),那么节点 \(u\) 要加上 \(d\) 的贡献,而 \(1\) 又要多上 \(size_u \times d\) 的贡献,\(u\) 的重儿子要加上 \(n-size_u\) 的贡献,综合起来就是,在 1 处加上 \(size_u \times d\),在 \(u\) 处加上 \((n-size_u)\times d\),在重儿子 \(son\) 处减走 \((size_{son})\times d\)。这里可以自己导一导儿子与父亲之间的关系。
对于每次询问,因为如果是轻链,则不会被修改,也就是说,会少一个差分,我们只需要每次把这个差分加上即可。我们发现,链顶处就是轻重儿子分开的地方,所以我们要加上链顶的贡献。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 150050, mod = 998244353;
inline int read(){
int x = 0, f = 1; char ch = getchar();
while(ch<'0' || ch>'9'){if(ch == '-') f = -1; ch = getchar();}
while(ch>='0'&&ch<='9'){x = x*10+ch-48; ch = getchar();}
return x * f;
}
inline int fpow(int a, int b){
int ret = 1;
while(b){
if(b & 1){
ret = (1ll*ret*a)%mod;
}
b>>=1;
a = (1ll*a*a)%mod;
}
return ret;
}
int n, Q, p;
int inv;
//————————————
struct node{
int nxt, to;
}edge[N<<1];
int head[N], tot;
void add(int u, int v){
edge[++tot].nxt = head[u];
edge[tot].to = v;
head[u] = tot;
}
int dfn[N], dep[N], fa[N], son[N], siz[N], totd;
void dfs1(int u, int fath){
fa[u] = fath;
siz[u] = 1;
dep[u] = dep[fath]+1;
for(int i = head[u]; i; i = edge[i].nxt){
int v = edge[i].to;
if(v == fath) continue;
dfs1(v, u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u] = v;
}
}
int top[N];
void dfs2(int u, int Top){
top[u] = Top;
dfn[u] = ++totd;
if(!son[u]) return;
dfs2(son[u], Top);
for(int i = head[u]; i; i = edge[i].nxt){
int v = edge[i].to;
if(!dfn[v]) dfs2(v, v);
}
}
//————————————————
int val[N];
int tc[N];
inline int lowbit(int x){
return x&(-x);
}
void insert(int pos, int val){
for(int i = pos; i<=n; i+=lowbit(i)){
tc[i] = (1ll*tc[i]+val)%mod;
}
}
inline int query(int pos){
int ret = 0;
for(int i = pos; i; i-=lowbit(i)){
ret = (1ll*ret+tc[i])%mod;
}
return ret;
}
int main(){
n = read(), Q = read();
inv = fpow(n, mod-2);
for(int i = 1; i<n; ++i){
int u = read(), v = read();
add(u, v);
add(v, u);
}
dfs1(1, 0);
dfs2(1, 1);
while(Q--){
int op = read();
if(op == 1){
int v = read(), d = read();
(val[v]+=d)%=mod;
if(v ^ 1){
insert(dfn[v], 1ll*d*(n-siz[v])%mod);
}
insert(1, 1ll*d*siz[v]%mod);
if(son[v]){
insert(dfn[son[v]], 1ll*d*(mod-siz[son[v]])%mod);
}
} else{
int v = read();
int ans = 0;
while(top[v] ^ 1){
ans = (((ans+query(dfn[v]))%mod-query(dfn[top[v]]-1))%mod+1ll*val[fa[top[v]]]*(mod-siz[top[v]])%mod)%mod;//求差分前缀和+补差价
v = fa[top[v]];
}
(ans+=query(dfn[v]))%=mod;
ans = (ans+mod)%mod;
printf("%d\n", 1ll*ans*inv%mod);
}
}
system("pause");
return 0;
}