[luogu3676] 小清新数据结构题 [树链剖分+线段树]

题面

传送门

思路

本来以为这道题可以LCT维护子树信息直接做的,后来发现这样会因为splay形态改变影响子树权值平方和,是splay本身的局限性导致的

所以只能另辟蹊径

首先,我们考虑询问点都在1的情况

考虑一次修改带来的影响:

假设当前节点的值变动量为$delta$,修改节点为$u$

那么对于所有位于路径$(1,u)$上的节点而言,它们的子树和以及子树平方和都会有改变

设$sum(u)$表示子树点权和,$sumsqr(u)$表示点权和的平方

那么$\forall v \in (1,u)$,$sum(v)+=delta$,$sumsqr(v)+=delta\ast 2\ast sum(v)+delta\ast delta$

又可见$ans=\sum_{u} sumsqr(u)$

那么$ans$在这次事件中的变化量可以表示如下:

设$len$为路径$(1,u)$的长度

那么$ans+=len\ast delta\ast delta+delta\ast 2\ast \sum_{v \in (1,u)} sum(v)$

所以,我们可以使用一棵线段树维护$sum(u)$的值,树链剖分一下,只需要支持区间修改和区间查询

接下来考虑询问点在$x$的情况

可以发现,如果我们考虑$ans(1)$到$ans(x)$中各个位置的贡献,容易发现,依然只有$v\in (1,x)$的节点贡献改变了

设$a_i$表示节点$i$在以1为根(也就是我们的树剖维护的东西)时的子树点权和,$b_i$表示以$x$为根的时候的点权和,$sum$为总点权和

可以得到:

$ans(x)=ans(1)-\sum_{v\in (1,x)} a_v^2 +\sum_{v\in (1,x)} b_v^2$

同时,$a_i$和$b_i$有一个性质:

$a_1=b_x=sum=a_v+b_{fa(v)}$,其中$v \neq 1,x$

那么化简上面式子

$ans(x)=ans(1)-\sum_{v\in (1,x),v\neq 1}a_v^2 + sum_{v\in(1,x),v\neq x}b_v^2$

为了方便,我们用$a_i$表示路径上的第$i$个点,$b_i$同理

$ans(x)=ans(1)-\sum_{i=2}{len}a_i2 + sum_{i=2}{len}(sum-a_i)2$

$ans(x)=ans(1)+(len-1)\ast (len-1)\ast sum-2\ast (len-1)ast \sum_{i=2}^{len}a_i$

这样就可以算了,依然是树链剖分+线段树解决

注意最后面这个式子,如果为了方便直接查询到根,可以把$len-1$提出来,并把后面的$i=2...len$变成$1...len$,在前面把一个$len-1$变成$len+1$,即:

$ans(x)=ans(1)+(len-1)\ast ((len+1)\ast sum-2\ast \sum_{i=1}^{len}a_i)$

Code

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#define ll long long
using namespace std;
inline int read(){
    int re=0,flag=1;char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') flag=-1;
        ch=getchar();
    }
    while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
    return re*flag;
}
int n,first[200010],dep[200010],siz[200010],son[200010],top[200010],pos[200010],back[200010],fa[200010],clk,cnte;
struct edge{
    int to,next;
}a[400010];
ll w[200010],s[200010],ans1;
inline void adde(const int u,const int v){
    a[++cnte]=(edge){v,first[u]};first[u]=cnte;
    a[++cnte]=(edge){u,first[v]};first[v]=cnte;
}
void dfs1(const int u,const int f){
    int i,v,maxn=0;
    dep[u]=dep[f]+1;fa[u]=f;
    siz[u]=1;son[u]=0;s[u]=w[u];
    for(i=first[u];~i;i=a[i].next){
        v=a[i].to;if(v==f) continue;
        dfs1(v,u);
        siz[u]+=siz[v];s[u]+=s[v];
        if(maxn<siz[v]) son[u]=v,maxn=siz[v];
    }
}
void dfs2(const int u,const int t){
    int i,v;
    pos[u]=++clk;back[clk]=u;top[u]=t;
    if(son[u]) dfs2(son[u],t);
    for(i=first[u];~i;i=a[i].next){
        v=a[i].to;if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
ll len[800010],sum[800010],lazy[800010];
void update(int num){
    sum[num]=sum[num<<1|1]+sum[num<<1];
}
void push(int l,int r,int num){
    if(l==r||!lazy[num]) return ;
    int mid=(l+r)>>1;
    sum[num<<1]+=(ll)(mid-l+1)*lazy[num];
    sum[num<<1|1]+=(ll)(r-mid)*lazy[num];
    lazy[num<<1]+=lazy[num];
    lazy[num<<1|1]+=lazy[num];
    lazy[num]=0;
}
int ql,qr;ll val;
void build(const int l,const int r,const int num){
    if(l==r){sum[num]=s[back[l]];return;}
    const int mid=(l+r)>>1;
    build(l,mid,num<<1);build(mid+1,r,num<<1|1);
    update(num);
}
void change(const int l,const int r,const int num){
    if(l>=ql&&r<=qr){
        sum[num]+=(r-l+1)*val;
        lazy[num]+=val;
        return;
    }
    push(l,r,num);
    const int mid=(l+r)>>1;
    if(mid>=ql) change(l,mid,num<<1);
    if(mid<qr) change(mid+1,r,num<<1|1);
    update(num);
}
ll query(const int l,const int r,const int num){
    if(l>=ql&&r<=qr) return sum[num];
    push(l,r,num);
    const int mid=(l+r)>>1;ll re=0;
    if(mid>=ql) re+=query(l,mid,num<<1);
    if(mid<qr) re+=query(mid+1,r,num<<1|1);
    return re;
}
void add(int u,const int v){
    int f;val=v;
    while(u){
        f=top[u];
        ql=pos[f];qr=pos[u];
        change(1,n,1);
        u=fa[f];
    }
}
ll ask(int u){
    int f;ll re=0;
    while(u){
        f=top[u];
        ql=pos[f];qr=pos[u];
        re+=query(1,n,1);
        u=fa[f];
    }
    return re;
}
int main(){
    n=read();int Q=read(),i,t1,t2,t3;ll s1;
    memset(first,-1,sizeof(first));
    for(i=1;i<n;i++){
        t1=read();t2=read();
        adde(t1,t2);
    }
    for(i=1;i<=n;i++) w[i]=read();
    dfs1(1,0);dfs2(1,0);build(1,n,1);
    for(i=1;i<=n;i++) ans1+=s[i]*s[i];
    while(Q--){
        t3=read();t1=read();
        if(t3==1){
            t2=read();
            t2=t2-w[t1];w[t1]+=t2;
            ans1+=t2*t2*dep[t1]+t2*2*ask(t1);
            add(t1,t2);
        }
        else{
            ql=qr=1;s1=query(1,n,1);
            printf("%lld\n",ans1+s1*((dep[t1]+1)*s1-2*ask(t1)));
        }
    }
}
posted @ 2018-10-12 18:42  dedicatus545  阅读(180)  评论(0编辑  收藏  举报