[P3676]小清新数据结构题
Description:
给你一棵树,每次询问以一个点为根时所有子树点权和的平方和
带修改
Hint:
\(n\le 2*10^5\)
Solution:
这题只要推出式子就很简单了
如果不换根这个平方和树剖直接做就行了
考虑换根的影响了哪些点的贡献
显然只影响了\(1\)到\(u\)的路径上的点
把1到\(u\)这条路径上的点依次标记为\(1,2,3......k\)
我们设\(a_i\)为以1为根时\(i\)的点权和,\(b_i\)为以\(u\)为根的点权和
\(Ans=ans_1-\sum a_i^2 + \sum b_i^2\)
注意到\(a_{i+1}+b_i=sum\)
\(Ans=ans_1-\sum a_i^2 -a_1^2+b_k^2 + \sum (sum-a_{i+1})^2\)
消掉\(\sum a_i^2\)
\(Ans=ans_1-k*sum^2-2*sum*\sum a_i\)
预处理出\(ans1\),每次算一条链就行
(注意最后并没有算\(a_1\))
#include <map>
#include <set>
#include <stack>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#define ls p<<1
#define rs p<<1|1
using namespace std;
typedef long long ll;
const ll mxn=1e6+5;
ll n,m,cnt,hd[mxn];
inline ll read() {
char c=getchar(); ll x=0,f=1;
while(c>'9'||c<'0') {if(c=='-') f=-1;c=getchar();}
while(c<='9'&&c>='0') {x=(x<<3)+(x<<1)+(c&15);c=getchar();}
return x*f;
}
inline void chkmax(ll &x,ll y) {if(x<y) x=y;}
inline void chkmin(ll &x,ll y) {if(x>y) x=y;}
struct ed {
ll to,nxt;
}t[mxn<<1];
ll df;
ll a[mxn],f[mxn],sz[mxn],rk[mxn],dfn[mxn],top[mxn],son[mxn];
ll tr[mxn<<2],pw[mxn<<2],tag[mxn<<2],len[mxn<<2],sum[mxn];
inline void add(ll u,ll v) {
t[++cnt]=(ed) {v,hd[u]}; hd[u]=cnt;
}
void dfs1(ll u,ll fa) {
f[u]=fa; sz[u]=1; sum[u]=a[u];
for(ll i=hd[u];i;i=t[i].nxt) {
ll v=t[i].to;
if(v==fa) continue ;
dfs1(v,u); sz[u]+=sz[v]; sum[u]+=sum[v];
if(sz[son[u]]<sz[v]) son[u]=v;
}
}
void dfs2(ll u,ll tp) {
dfn[u]=++df; rk[df]=u; top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(ll i=hd[u];i;i=t[i].nxt) {
ll v=t[i].to;
if(v==f[u]||v==son[u]) continue ;
dfs2(v,v);
}
}
void push_up(ll p) {
tr[p]=tr[ls]+tr[rs];
pw[p]=pw[ls]+pw[rs];
}
void push_down(ll p) {
if(tag[p]) {
tag[ls]+=tag[p]; tag[rs]+=tag[p];
pw[ls]+=2*tr[ls]*tag[p]+tag[p]*tag[p]*len[ls];
pw[rs]+=2*tr[rs]*tag[p]+tag[p]*tag[p]*len[rs];
tr[ls]+=len[ls]*tag[p];
tr[rs]+=len[rs]*tag[p];
tag[p]=0;
}
}
void build(ll l,ll r,ll p) {
if(l==r) {
len[p]=1;
tr[p]=sum[rk[l]];
pw[p]=sum[rk[l]]*sum[rk[l]];
return ;
}
ll mid=(l+r)>>1;
build(l,mid,ls); build(mid+1,r,rs);
push_up(p); len[p]=r-l+1;
}
void update(ll l,ll r,ll ql,ll qr,ll val,ll p) {
if(ql<=l&&r<=qr) {
tag[p]+=val;
pw[p]+=val*val*len[p]+2*val*tr[p];
tr[p]+=val*len[p];
return ;
}
ll mid=(l+r)>>1; push_down(p);
if(ql<=mid) update(l,mid,ql,qr,val,ls);
if(qr>mid) update(mid+1,r,ql,qr,val,rs);
push_up(p);
}
ll query(ll l,ll r,ll ql,ll qr,ll p) {
if(ql<=l&&r<=qr) return tr[p];
ll mid=(l+r)>>1; push_down(p); ll res=0;
if(ql<=mid) res+=query(l,mid,ql,qr,ls);
if(qr>mid) res+=query(mid+1,r,ql,qr,rs);
return res;
}
ll tp;
void modify(ll x,ll y) {
y-=a[x]; a[x]+=y; tp+=y;
while(x) {
update(1,n,dfn[top[x]],dfn[x],y,1);
x=f[top[x]];
}
}
ll ask(ll x) {
ll ans=pw[1],res1=0,res2=0;
while(top[x]!=1) {
res1+=dfn[x]-dfn[top[x]]+1;
res2+=query(1,n,dfn[top[x]],dfn[x],1);
x=f[top[x]];
}
res1+=dfn[x]-1;
if(x!=1) res2+=query(1,n,dfn[1]+1,dfn[x],1);
return ans+tp*(res1*tp-res2*2);
}
int main()
{
n=read(); m=read(); ll u,v,opt,x,y;
for(ll i=1;i<n;++i) {
u=read(); v=read();
add(u,v); add(v,u);
}
for(ll i=1;i<=n;++i) a[i]=read();
dfs1(1,0); dfs2(1,1); build(1,n,1); tp=sum[1];
for(ll i=1;i<=m;++i) {
opt=read();
if(opt==1) {
x=read(); y=read();
modify(x,y);
}
else x=read(),printf("%lld\n",ask(x));
}
return 0;
}