[HNOI/AHOI2018]排列


题解

堆 + 贪心
首先题面写的肥肠的隐晦
说白了题面的意思就是给你一棵树
\(a[u]\)代表每个点的父节点\(fa[u]\)
然后题目的意思就是选择这个点必须选择先选择了这个点的父亲,第\(i\)个被选择的点的贡献是\(i\times w\)
设当前已经选择了\(x\)个点了,有两个选择\(a,b\)
那么我们考虑先选择哪一个
\(W_{a,b}=\sum_{i=1}^{sz_1}(x+i)\times w_{1,i}+\sum_{i=1}^{sz_2}(x+sz_1+i)\times w_{2,i}\\ W_{b,a}=\sum_{i=1}^{sz_2}(x+i)\times w_{2,i}+\sum_{i=1}^{sz_1}(x+sz_2+i)\times w_{1,i}\\ W_{a,b} > W_{b,a} \to W_{2}\times sz_1 > W_{1}\times sz_2\)

所以我们可以按照\(\frac{W_i}{sz_i}\)从小到大来选择贪心
用一个堆来维护
考虑当前堆中权值最小的点\(i\)
如果\(i\)没有父亲,那么直接选\(i\)
否则就说明不能立即选择i
那么先考虑ta的父亲
因为ta的父亲之前也可能被一些更优的点依赖
所以一旦选择了ta的父亲就会先处理那些更优的依赖
然后处理完更优的依赖就立即来处理点\(i\)
这个可以用并查集实现

代码

#include<queue>
#include<cstdio>
#include<iostream>
#include<algorithm>
# define LL long long
const int M = 500005 ;
using namespace std ;

inline int read() {
    char c = getchar() ; int x = 0 , w = 1 ;
    while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
    while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
    return x*w ;
}

bool suc , vis[M] ;
int n , f[M] , fa[M] ;
LL ans , w[M] , sz[M] ;
vector < int > vec[M] ; 
struct Node { int u , sz ; LL w ; } ;
inline bool operator < (Node a , Node b) {
    return a.sz * b.w < b.sz * a.w ;
}
int find(int x) {
    if(f[x] != x) f[x] = find(f[x]) ;
    return f[x] ;
}
priority_queue < Node > q ;
void dfs(int u) {
    if(vis[u]) return void(suc = true) ; vis[u] = true ;
    for(int i = 0 , v ; i < vec[u].size() ; i ++) {
        v = vec[u][i] ;
        dfs(v) ;
    }
}
int main() {
    n = read() ;
    for(int i = 1 ; i <= n ; i ++) {
        fa[i] = read() ;
        vec[fa[i]].push_back(i) ;
    }
    for(int i = 1 ; i <= n ; i ++) w[i] = read() ;
    dfs(0) ;
    for(int i = 0 ; i <= n ; i ++)  
        if(!vis[i])	
            suc = true ;
    if(suc) { printf("-1\n") ; return 0 ; }
    for(int i = 0 ; i <= n ; i ++)
        f[i] = i , sz[i] = 1 ;
    for(int i = 1 ; i <= n ; i ++)
        q.push((Node) { i , 1 , w[i] }) ;
    while(!q.empty()) {
        Node x = q.top() ; q.pop() ;
        int u = x.u , ff = find(fa[u]) ;
        if(find(u) == ff) continue ;
        f[u] = ff ;
        ans += 1LL * w[u] * sz[ff] ;
        sz[ff] += sz[u] ; w[ff] += w[u] ;
        if(ff)
            q.push((Node) { ff , sz[ff] , w[ff] }) ;
    }
    cout << ans << endl ;
    return 0 ;
}
posted @ 2019-04-24 16:49  beretty  阅读(207)  评论(0编辑  收藏  举报