牛客练习赛81D 小Q与树
题意
给定一棵树,每个点 \(x\) 有点权 \(a_x\),求:
\[\sum_{u\neq v}\operatorname{dis}(u,v)\min\{a_u,a_v\}
\]
Solution
考虑 dsu on tree。考虑当前我们在遍历 \(l\) 的后代,遍历到了 \(u\),那么其贡献为:
\[\sum_{\operatorname{lca}(u,v)=l} (dep_u+dep_v-2dep_l)\min\{a_u,a_v\}
\]
对于所有 \(a_u<a_v\),其贡献为:
\[\begin{align*}
&\sum_{\operatorname{lca}(u,v)=l}(dep_u+dep_v-2dep_l)a_u\\
=&cnt(dep_u-2dep_l)a_u+a_u\sum_{\operatorname{lca}(u,v)=l}dep_v
\end{align*}
\]
其中 \(cnt=\sum_{\operatorname{lca}(u,v)=l} [a_u<a_v]\)。
对于所有 \(a_u\ge a_v\),其贡献为:
\[\begin{align*}
&\sum_{\operatorname{lca}(u,v)=l}(dep_u+dep_v-2dep_l)a_v\\
=&(dep_u-2dep_l)\sum_{\operatorname{lca}(u,v)=l}a_v+\sum_{\operatorname{lca}(u,v)=l}dep_va_v
\end{align*}
\]
树状数组分别维护 \(cnt\)、\(\sum a_v\)、\(\sum dep_v\) 和 \(\sum a_vdep_v\) 即可。
#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef vector<int> vi;
#define mp make_pair
#define pb push_back
#define fi first
#define se second
inline int read()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
return x*f;
}
const int N=2e5+10,M=4e5+10,maxn=2e5,mod=998244353;
struct bit{
int c[N];
bit(){memset(c,0,sizeof(c));}
void modify(int x,int d){for(;x<=maxn;x+=x&-x)c[x]+=d;}
int query(int x){int ans=0;for(;x;x^=x&-x)ans+=c[x];return ans;}
}T,T1,T2,T3;
//T:dep[x], T1:cnt, T2:a[i], T3:a[i]dep[i]
int head[N],ver[M],nxt[M],tot=0;
void add(int x,int y)
{
ver[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}
int sz[N],son[N],f[N],dep[N];
void dfs(int x,int fa)
{
sz[x]=1,dep[x]=dep[fa]+1;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];if(y==fa)continue;
dfs(y,x),f[x]=(f[x]+1ll*sz[x]*sz[y]%mod)%mod,sz[x]+=sz[y];
if(!son[x]||sz[y]>sz[son[x]])son[x]=y;
}
}
int Ans[N],ans=0,a[N],t[N],dt=0;
void ff(int x)
{
int sum1=T2.query(a[x]-1)%mod,xx=(dep[x]-dt*2+mod)%mod,Sum1=T3.query(a[x]-1)%mod ;
int ans1=(1ll*sum1*xx%mod+Sum1)%mod;
int cnt2=T1.query(maxn)-T1.query(a[x]-1),sum2=T.query(maxn)-T.query(a[x]-1);
int ans2=(1ll*cnt2*t[a[x]]%mod*xx%mod+1ll*t[a[x]]*sum2%mod)%mod;
ans+=(ans1+ans2)%mod;
ans%=mod;
}
void calc(int x,int fa,int op)
{
if(op==0)T.modify(a[x],dep[x]),T1.modify(a[x],1),T2.modify(a[x],t[a[x]]),T3.modify(a[x],t[a[x]]*dep[x]%mod);
else if(op==1)ff(x);
else T.modify(a[x],-dep[x]),T1.modify(a[x],-1),T2.modify(a[x],-t[a[x]]),T3.modify(a[x],-t[a[x]]*dep[x]%mod);;
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];if(y==fa)continue;
calc(y,x,op);
}
}
void dsu(int x,int fa,int op)
{
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];if(y==fa||y==son[x])continue;
dsu(y,x,0);
}
if(son[x])dsu(son[x],x,1);
dt=dep[x];ff(x);
T.modify(a[x],dep[x]),T1.modify(a[x],1),T2.modify(a[x],t[a[x]]),T3.modify(a[x],t[a[x]]*dep[x]%mod);
for(int i=head[x];i;i=nxt[i])
{
int y=ver[i];if(y==fa||y==son[x])continue;
calc(y,x,1),calc(y,x,0);
}
Ans[x]=ans;
if(!op)calc(x,fa,-1);
ans=0;
}
signed main()
{
int n=read(),m=n;
for(int i=1;i<=n;i++)t[i]=a[i]=read();
for(int i=1;i<n;i++){int u=read(),v=read();add(u,v),add(v,u);}
sort(t+1,t+m+1),m=unique(t+1,t+m+1)-t-1;
for(int i=1;i<=n;i++)a[i]=lower_bound(t+1,t+m+1,a[i])-t;
// for(int i=1;i<=n;i++)printf("a[%d]=%d\n",i,a[i]);
dfs(1,0),dsu(1,0,1);
// for(int i=1;i<=n;i++)printf("dep[%d]=%d\n",i,dep[i]);
int sum=0;
for(int i=1;i<=n;i++)sum+=Ans[i],sum%=mod;
printf("%lld",sum*2%mod);
return 0;
}