CF915F Imbalance Value of a Tree (并查集)
题目大意:给你一棵树,每个点有点权a_{i},求$\sum _{i=1}^{n} \sum _{j=i}^{n} f(i,j)$,$f(i,j)$表示i,j,路径上的点的最大权值-最小权值
正解的思路好神啊
正解:
首先,原式可以拆成$\sum _{i=1}^{n} \sum _{j=i}^{n} max(i,j) \; - \; \sum _{i=1}^{n} \sum _{j=i}^{n} min(i,j)$
max的求法和min类似,这里只讨论min的求法
把点按照从大到小排序,依次加入树里
感性理解成以当前点的点权作为最小值,那么这个点会向它周围已经被加入树里的联通块"扩散"去更新答案
答案就是这个点周围(剩余联通块的点数-当前联通块的点数)*点权,然后 剩余点数-当前联通块点数
也可以省去减法步骤,把最终答案除以二
并查集维护联通块即可,算上排序,总时间约为$O(nlogn)$
1 #include <queue> 2 #include <vector> 3 #include <cstdio> 4 #include <cstring> 5 #include <algorithm> 6 #define N 1001000 7 #define ll long long 8 using namespace std; 9 10 int gint() 11 { 12 int ret=0,f=1;char c=getchar(); 13 while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();} 14 while(c>='0'&&c<='9'){ret=ret*10+c-'0';c=getchar();} 15 return ret*f; 16 } 17 int n,cte,tim; 18 int w[N],head[N],use[N]; 19 int fa[N],sz[N]; 20 void init(){for(int i=1;i<=n;i++)fa[i]=i,sz[i]=1;} 21 int find_fa(int x){ 22 int y=x,pre;while(fa[y]!=y)y=fa[y]; 23 while(fa[x]!=y){ 24 pre=fa[x],fa[x]=y,x=pre; 25 }return y; 26 } 27 struct node{int id,val;}a[N]; 28 struct Edge{int to,nxt;}edge[N*2]; 29 void ae(int u,int v){ 30 cte++;edge[cte].nxt=head[u]; 31 edge[cte].to=v,head[u]=cte; 32 } 33 int cmp(node s1,node s2){return s1.val<s2.val;} 34 35 ll solve1() 36 { 37 sort(a+1,a+n+1,cmp); 38 int x,y,fx,fy; 39 init(); ll ans=0; 40 for(int i=1;i<=n;i++) 41 { 42 ll sum=0;x=a[i].id; 43 for(int j=head[x];j;j=edge[j].nxt){ 44 int v=edge[j].to; 45 if(use[v]){ 46 fy=find_fa(v); 47 sum+=sz[fy]; 48 } 49 }sum++; 50 for(int j=head[x];j;j=edge[j].nxt){ 51 int v=edge[j].to; 52 if(use[v]){ 53 fy=find_fa(v); 54 ans+=1ll*(sum-sz[fy])*sz[fy]*a[i].val; 55 sum-=sz[fy]; 56 fa[fy]=x,sz[x]+=sz[fy]; 57 } 58 } 59 use[x]=1; 60 }return ans; 61 } 62 ll solve2() 63 { 64 memset(use,0,sizeof(use)); 65 int x,y,fx,fy; 66 init(); ll ans=0; 67 for(int i=n;i>=1;i--) 68 { 69 ll sum=0;x=a[i].id; 70 for(int j=head[x];j;j=edge[j].nxt){ 71 int v=edge[j].to; 72 if(use[v]){ 73 fy=find_fa(v); 74 sum+=sz[fy]; 75 } 76 }sum++; 77 for(int j=head[x];j;j=edge[j].nxt){ 78 int v=edge[j].to; 79 if(use[v]){ 80 fy=find_fa(v); 81 ans+=1ll*(sum-sz[fy])*sz[fy]*a[i].val; 82 sum-=sz[fy]; 83 fa[fy]=x,sz[x]+=sz[fy]; 84 } 85 } 86 use[x]=1; 87 }return ans; 88 } 89 90 91 int main() 92 { 93 scanf("%d",&n); 94 int x,y; 95 for(int i=1;i<=n;i++) 96 w[i]=a[i].val=gint(),a[i].id=i; 97 for(int i=1;i<n;i++) 98 x=gint(),y=gint(),ae(x,y),ae(y,x); 99 ll ans1=solve1(); 100 ll ans2=solve2(); 101 printf("%I64d\n",ans1-ans2); 102 return 0; 103 }