牛客练习赛81D 小Q与树

题意

Link

给定一棵树,每个点 \(x\) 有点权 \(a_x\),求:

\[\sum_{u\neq v}\operatorname{dis}(u,v)\min\{a_u,a_v\} \]

Solution

考虑 dsu on tree。考虑当前我们在遍历 \(l\) 的后代,遍历到了 \(u\),那么其贡献为:

\[\sum_{\operatorname{lca}(u,v)=l} (dep_u+dep_v-2dep_l)\min\{a_u,a_v\} \]

对于所有 \(a_u<a_v\),其贡献为:

\[\begin{align*} &\sum_{\operatorname{lca}(u,v)=l}(dep_u+dep_v-2dep_l)a_u\\ =&cnt(dep_u-2dep_l)a_u+a_u\sum_{\operatorname{lca}(u,v)=l}dep_v \end{align*} \]

其中 \(cnt=\sum_{\operatorname{lca}(u,v)=l} [a_u<a_v]\)

对于所有 \(a_u\ge a_v\),其贡献为:

\[\begin{align*} &\sum_{\operatorname{lca}(u,v)=l}(dep_u+dep_v-2dep_l)a_v\\ =&(dep_u-2dep_l)\sum_{\operatorname{lca}(u,v)=l}a_v+\sum_{\operatorname{lca}(u,v)=l}dep_va_v \end{align*} \]

树状数组分别维护 \(cnt\)\(\sum a_v\)\(\sum dep_v\)\(\sum a_vdep_v\) 即可。

#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef vector<int> vi;
#define mp make_pair
#define pb push_back
#define fi first
#define se second
inline int read()
{
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
	return x*f;
}
const int N=2e5+10,M=4e5+10,maxn=2e5,mod=998244353;
struct bit{
	int c[N];
	bit(){memset(c,0,sizeof(c));}
	void modify(int x,int d){for(;x<=maxn;x+=x&-x)c[x]+=d;}
	int query(int x){int ans=0;for(;x;x^=x&-x)ans+=c[x];return ans;}
}T,T1,T2,T3;
//T:dep[x],   T1:cnt,   T2:a[i],   T3:a[i]dep[i] 
int head[N],ver[M],nxt[M],tot=0;
void add(int x,int y)
{
	ver[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}
int sz[N],son[N],f[N],dep[N];
void dfs(int x,int fa)
{
	sz[x]=1,dep[x]=dep[fa]+1;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa)continue;
		dfs(y,x),f[x]=(f[x]+1ll*sz[x]*sz[y]%mod)%mod,sz[x]+=sz[y];
		if(!son[x]||sz[y]>sz[son[x]])son[x]=y;
	}
}
int Ans[N],ans=0,a[N],t[N],dt=0;
void ff(int x)
{
	int sum1=T2.query(a[x]-1)%mod,xx=(dep[x]-dt*2+mod)%mod,Sum1=T3.query(a[x]-1)%mod	;
	int ans1=(1ll*sum1*xx%mod+Sum1)%mod;
	
	int cnt2=T1.query(maxn)-T1.query(a[x]-1),sum2=T.query(maxn)-T.query(a[x]-1);
	int ans2=(1ll*cnt2*t[a[x]]%mod*xx%mod+1ll*t[a[x]]*sum2%mod)%mod;
	
	ans+=(ans1+ans2)%mod;
	ans%=mod;
}
void calc(int x,int fa,int op)
{
	if(op==0)T.modify(a[x],dep[x]),T1.modify(a[x],1),T2.modify(a[x],t[a[x]]),T3.modify(a[x],t[a[x]]*dep[x]%mod);
	else if(op==1)ff(x);
	else T.modify(a[x],-dep[x]),T1.modify(a[x],-1),T2.modify(a[x],-t[a[x]]),T3.modify(a[x],-t[a[x]]*dep[x]%mod);;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa)continue;
		calc(y,x,op);
	}
}
void dsu(int x,int fa,int op)
{
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa||y==son[x])continue;
		dsu(y,x,0);
	}
	if(son[x])dsu(son[x],x,1);
	dt=dep[x];ff(x);
	T.modify(a[x],dep[x]),T1.modify(a[x],1),T2.modify(a[x],t[a[x]]),T3.modify(a[x],t[a[x]]*dep[x]%mod);
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];if(y==fa||y==son[x])continue;
		calc(y,x,1),calc(y,x,0);
	}
	Ans[x]=ans;
	if(!op)calc(x,fa,-1);
	ans=0;
}
signed main()
{
	int n=read(),m=n;
	for(int i=1;i<=n;i++)t[i]=a[i]=read();
	for(int i=1;i<n;i++){int u=read(),v=read();add(u,v),add(v,u);}
	sort(t+1,t+m+1),m=unique(t+1,t+m+1)-t-1;
	for(int i=1;i<=n;i++)a[i]=lower_bound(t+1,t+m+1,a[i])-t;
//	for(int i=1;i<=n;i++)printf("a[%d]=%d\n",i,a[i]);
	dfs(1,0),dsu(1,0,1);
//	for(int i=1;i<=n;i++)printf("dep[%d]=%d\n",i,dep[i]);
	int sum=0;
	for(int i=1;i<=n;i++)sum+=Ans[i],sum%=mod;
	printf("%lld",sum*2%mod);
        return 0;
}
posted @ 2021-10-29 10:01  zzt1208  阅读(75)  评论(0编辑  收藏  举报