[HNOI/AHOI2018]排列
题目大意:
给定$n(n\le5\times10^5)$个正整数$a_1,a_2,\ldots,a_n(0\le a_i\le n)$,及$n$个正整数$w_1,w_2,\ldots,w_n$。称$a$的一个排列$a_{p[1]},a_{p[2]},\ldots,a_{p[n]}$为合法排列当且仅当该排列满足:对于任意的$k$和$j$,若$j\le k$,$a_{p[j]}\ne p[k]$。定义这个合法排列的权值为$\sum w_{p[i]}\times i$。问是否存在合法排列。如果有,求最大权值。
思路:
原题是HDU1055。
连边$a_i\to i$,显然一个排列是合法的当且仅当这个排列是该图的一个拓扑序。即若存在环则合法排列不存在。
对于存在合法排列的情况,每次贪心地选取$w_i$最小的点。若去掉图中已被选择的点后,$i$的入度为$0$,则此时选择$i$一定最优。若现在还不能选择$i$,则优先考虑$a_i$,若后面$a_i$被选择后,马上选择$i$一定更优,可以用并查集将$i$合并到$a_i$上。注意到$a_i$可能也依赖于别的结点,也可能有结点依赖于$i$,因此合并时需要维护整个含有依赖关系的连通块。合并后结点的优先级需要相应调整,可以证明用块内元素平均值进行比较是正确的。这显然可以用堆来维护,时间复杂度$O(n\log n)$。
1 #include<queue> 2 #include<cstdio> 3 #include<cctype> 4 #include<functional> 5 #include<ext/pb_ds/priority_queue.hpp> 6 typedef long long int64; 7 inline int getint() { 8 register char ch; 9 while(!isdigit(ch=getchar())); 10 register int x=ch^'0'; 11 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 12 return x; 13 } 14 const int N=5e5+1; 15 int n,a[N],h[N],sz,size[N],cnt; 16 int64 w[N]; 17 struct Edge { 18 int to,next; 19 }; 20 Edge e[N]; 21 inline void add_edge(const int &u,const int &v) { 22 e[++sz]=(Edge){v,h[u]};h[u]=sz; 23 } 24 bool vis[N]; 25 bool dfs(const int &x) { 26 if(vis[x]) return false; 27 vis[x]=true; 28 cnt++; 29 for(int i=h[x];i;i=e[i].next) { 30 const int &y=e[i].to; 31 if(!dfs(y)) return false; 32 } 33 return true; 34 } 35 inline bool check() { 36 dfs(0); 37 return cnt==n+1; 38 } 39 struct Node { 40 int id; 41 bool operator > (const Node &another) const { 42 return w[id]*size[another.id]>w[another.id]*size[id]; 43 } 44 }; 45 __gnu_pbds::priority_queue<Node,std::greater<Node> > q; 46 __gnu_pbds::priority_queue<Node,std::greater<Node> >::point_iterator p[N]; 47 struct DisjointSet { 48 int anc[N]; 49 int find(const int &x) { 50 return x==anc[x]?x:anc[x]=find(anc[x]); 51 } 52 void reset(const int &n) { 53 for(register int i=0;i<=n;i++) anc[i]=i; 54 } 55 void merge(const int &x,const int &y) { 56 anc[find(x)]=find(y); 57 } 58 }; 59 DisjointSet s; 60 inline int64 solve() { 61 if(!check()) return -1; 62 for(register int i=size[0]=1;i<=n;i++) { 63 size[i]=1; 64 p[i]=q.push((Node){i}); 65 } 66 s.reset(n); 67 int64 ret=0; 68 for(register int i=1;i<=n;i++) { 69 const int x=q.top().id,par=s.find(a[x]); 70 s.merge(x,par); 71 q.pop(); 72 ret+=w[x]*size[par]; 73 w[par]+=w[x]; 74 size[par]+=size[x]; 75 if(par) q.modify(p[par],(Node){par}); 76 } 77 return ret; 78 } 79 int main() { 80 n=getint(); 81 for(register int i=1;i<=n;i++) { 82 add_edge(a[i]=getint(),i); 83 } 84 for(register int i=1;i<=n;i++) { 85 w[i]=getint(); 86 } 87 printf("%lld\n",solve()); 88 return 0; 89 }