【纪中集训2019.3.11】树上四次求和

题目

描述

给定一棵 \(n\) 个点的树和一个 \(n\) 元排列 \(a_1,\dots,a_n\),以及 \(q\) 个询问,每次询问给出一个 \(k\),求:

\[ \sum_{l=1}^{k}\sum_{r=l}^{k}\sum_{i=l}^{r}\sum_{j=i}^{r} dis(a_i,a_j) \]

其中 \(k\le n\),\(dis(u,v)\) 为 \(u\) 和 \(v\) 在树上的最短距离。

输出上式对 \(998244353\) 取模的值。

范围

\(n,q \le 10^5\),\(u,v,k \le n\)(\(u,v\) 为树边的两个端点)。

题解:

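  • 考虑每次的增量:记 \(ans_n\) 为 \(k=n\) 时的答案,\(\Delta_n = ans_n - ans_{n-1} = \sum_{l=1}^{n}\sum_{l\le i\le j\le n} dis(a_i,a_j)\),即右端点为 \(n\) 的所有区间的贡献;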

  • 在 \(\Delta_n\) 中,\(j<n\) 的点对 \((i,j)\) 的贡献之和恰好等于上一次的增量 \(\Delta_{n-1}\);

  • 对于 \(j=n\) 的点对,点对 \((i,n)\) 会被 \(l\le i\) 的 \(i\) 个区间计入,所以新增部分为 \(\sum_{i=1}^{n} i\cdot dis(a_i,a_n) = \sum_{i=1}^{n}\left(i\cdot dep(a_n) + i\cdot dep(a_i) - 2i\cdot dep(lca(a_i,a_n))\right)\);

  • 前两项分别等于 \(\frac{n(n+1)}{2}dep(a_n)\) 和一个可以直接维护的前缀和,因此只需要考虑 \(\sum_{i=1}^{n} i\cdot dep(lca(a_i,a_n))\);

  • 做法:每加入一个 \(a_i\),就把 \(a_i\) 到根的树链上每个点的权值加 \(i\);此时 \(a_n\) 到根的树链上的权值和恰好等于上式。树链加、树链求和可以用树链剖分或 \(LCT\) 维护;

    #include<bits/stdc++.h>
    // Typed constants instead of object-like macros, so they obey scope and type rules.
    typedef long long ll;
    const int mod=998244353;
    using namespace std;
    const int N=100010;
    // Global state. Vertex-indexed arrays use ids 1..n; index 0 serves as the null node
    // (no child / no parent / end of adjacency list).
    //   n, m     : number of vertices / number of queries
    //   a        : not referenced anywhere in this program
    //   o, hd, E : adjacency lists (next free arc index starting at 1, list heads, arc pool)
    //   ch, fa   : link-cut tree splay children and parent / path-parent pointers
    //   sum, sz  : sum of w (modulo mod) and node count over a splay subtree
    //   rev, ly  : lazy "swap children" flag and lazy "add to every node" amount
    //   w        : the vertex's own value
    //   dep      : depth in the tree rooted at vertex 1 (dep[1]=1)
    //   ans      : ans[k] is the answer printed for query k
    int n,m,a[N],o=1,hd[N],ch[N][2],fa[N],sum[N],rev[N],ly[N],sz[N],w[N],dep[N],ans[N];
    // Adjacency-list arc: target vertex v and index of the next arc in the same list.
    struct Edge{int v,nt;}E[N<<1];
    // Add the undirected edge u-v as two directed arcs, each pushed onto the front of its
    // source vertex's adjacency list. Arc indices start at 1 (o=1), so 0 ends a list.
    // A brace-initialised temporary is used because C99-style compound literals
    // "(Edge){...}" are not part of standard C++.
    void adde(int u,int v){
    	E[o]=Edge{v,hd[u]};hd[u]=o++;
    	E[o]=Edge{u,hd[v]};hd[v]=o++;
    }
    // Return the next byte of standard input. Bytes are served from a 1 MB buffer that is
    // refilled with fread when exhausted; at end of input EOF (narrowed to char) is returned.
    char gc(){
    	static char *head,*tail,buf[1000000];
    	if(head==tail){
    		head=buf;
    		tail=buf+fread(buf,1,1000000,stdin);
    		if(head==tail)return EOF;
    	}
    	return *head++;
    }
    // Read the next non-negative decimal integer from the buffered input, skipping any
    // non-digit characters in front of it. If the input ends before a digit is found,
    // 0 is returned instead of spinning forever on EOF.
    // NOTE: a literal 0xFF byte in the input is indistinguishable from EOF here.
    int rd(){
    	int x=0;
    	char c=gc();
    	while(c<'0'||c>'9'){
    		if(c==static_cast<char>(EOF))return x;
    		c=gc();
    	}
    	while(c>='0'&&c<='9'){
    		x=x*10+(c-'0');
    		c=gc();
    	}
    	return x;
    }
    // Output buffer: bytes collect in ps (pp points at the next free slot) and are handed
    // to stdout with fwrite in 1 MB chunks.
    char ps[1000000],*pp=ps;
    // Append one byte to the output buffer, draining it to stdout first if it is full.
    void push(char x){
    	if(pp==ps+1000000)fwrite(ps,1,1000000,stdout),pp=ps;
    	*pp++=x;
    }
    // Append x in decimal followed by '\n' to the output buffer; negative values get a
    // leading '-'. Digits are kept as char so the call always selects push(char), no
    // matter which other push overloads are visible at this point of the file.
    void write(int x){
    	char digits[12];
    	int top=0;
    	// Negate in unsigned arithmetic so INT_MIN is handled without signed overflow.
    	unsigned int u=static_cast<unsigned int>(x);
    	if(x<0){push('-');u=0u-u;}
    	do{
    		digits[top++]=static_cast<char>('0'+u%10);
    		u/=10;
    	}while(u);
    	while(top)push(digits[--top]);
    	push('\n');
    }
    // Write whatever is still buffered to stdout and reset the buffer; bytes that are
    // never flushed are not written at all.
    void flush(){fwrite(ps,1,pp-ps,stdout);pp=ps;}
    void pushup(int k){
    	sum[k]=((ll)sum[ch[k][0]]+sum[ch[k][1]]+w[k])%mod;
    	sz[k]=sz[ch[k][0]]+sz[ch[k][1]]+1; 
    }
    void pushdown(int k){
    	int &l=ch[k][0],&r=ch[k][1];
    	if(rev[k]){
    		rev[l]^=1,rev[r]^=1;
    		swap(l,r);
    		rev[k]^=1;
    	}
    	if(ly[k]){
    		int x=ly[k];
    		sum[l]=(sum[l]+1ll*sz[l]*x%mod)%mod;
    		sum[r]=(sum[r]+1ll*sz[r]*x%mod)%mod;
    		ly[l]+=x;if(ly[l]>=mod)ly[l]-=mod;
    		ly[r]+=x;if(ly[r]>=mod)ly[r]-=mod;
    		w[l]+=x;if(w[l]>=mod)w[l]-=mod;
    		w[r]+=x;if(w[r]>=mod)w[r]-=mod;
    		ly[k]=0;
    	}
    }
    // True when x is the root of its splay (auxiliary) tree: its parent pointer, if any,
    // is a path-parent pointer rather than a real child link of the parent.
    bool isrt(int x){
    	const int p=fa[x];
    	return !(ch[p][0]==x||ch[p][1]==x);
    }
    void push(int x){
    	if(!isrt(x))push(fa[x]);
    	pushdown(x);
    }
    // One splay rotation: lift x above its parent y while keeping in-order order. If y was
    // a real child of z, x takes y's slot in z; otherwise fa[x]=z just inherits y's
    // path-parent pointer.
    void rotate(int x){
    	const int y=fa[x],z=fa[y];
    	const int side=(ch[y][1]==x);   // which child of y x currently is (0 left, 1 right)
    	const int inner=ch[x][side^1];  // x's subtree that moves across to y
    	if(!isrt(y))ch[z][ch[z][1]==y]=x;
    	fa[x]=z;
    	fa[y]=x;
    	fa[inner]=y;
    	ch[y][side]=inner;
    	ch[x][side^1]=y;
    	pushup(y);
    	pushup(x);
    }
    void splay(int x){
    	push(x);
    	for(int y,z;!isrt(x);rotate(x)){
    		y=fa[x],z=fa[y];
    		if(!isrt(y))rotate((ch[y][0]==x)^(ch[z][0]==y) ? x : y); 
    	}
    }
    // Make the path from the represented tree's root to x one preferred path: walk up
    // through path-parent pointers, splaying each node and attaching the previously
    // processed part as its right (deeper) child. x itself ends with no deeper child.
    void access(int x){
    	int below=0;
    	while(x){
    		splay(x);
    		ch[x][1]=below;
    		pushup(x);
    		below=x;
    		x=fa[x];
    	}
    }
    // Re-root x's represented tree at x: expose the root-to-x path, splay x to the top of
    // it, then tag it for lazy reversal so depth order along the path is flipped.
    void mkrt(int x){
    	access(x);
    	splay(x);
    	rev[x]^=1;
    }
    // Expose the x..y path: re-root at x, then access y and splay it, so y's auxiliary tree
    // holds exactly that path with y at its root. Not called anywhere in this program.
    void split(int x,int y){
    	mkrt(x);
    	access(y);
    	splay(y);
    }
    // Add the edge x-y: re-root x's tree at x and hang it below y through a path-parent
    // pointer. Assumes x and y are currently in different trees (not checked).
    void link(int x,int y){
    	mkrt(x);
    	fa[x]=y;
    }
    void dfs(int u,int F){
    	dep[u]=dep[F]+1;
    	for(int i=hd[u];i;i=E[i].nt){
    		int v=E[i].v;
    		if(v==F)continue;
    		dfs(v,u);
    		link(v,u);
    	}
    }
    int main(){
    	freopen("sumsumsum.in","r",stdin);
    	freopen("sumsumsum.out","w",stdout);
    	n=rd();m=rd();
    	for(int i=1;i<=n;++i)sz[i]=1;
    	for(int i=1;i<n;++i){
    		int u=rd(),v=rd();
    		adde(u,v);
    	}
    	dfs(1,0);
    	for(int i=1,x,y=0,z=0;i<=n;++i){
    		x=rd();
    		y+=1ll*i*dep[x]%mod;if(y>=mod)y-=mod;
    		access(x),splay(x);
    		sum[x]=(sum[x]+1ll*sz[x]*i%mod)%mod;
    		ly[x]+=i;if(ly[x]>=mod)ly[x]-=mod;
    		w[x]+=i;if(w[x]>=mod)w[x]-=mod;
    		z=(z + 1ll*i*(i+1)/2%mod*dep[x]%mod + y - 2*sum[x])%mod;
    		if(z<0)z+=mod;
    		ans[i]=(ans[i-1]+z)%mod;
    	}
    	for(int i=1,x;i<=m;++i)/*printf("%d\n",ans[rd()]);*/write(ans[rd()]);
    	flush();
    	return 0;
    }
    
posted @ 2019-03-15 07:31  大米饼  阅读(229)  评论(0)  编辑  收藏  举报