noi.ac #543 商店


思路：对每个节点维护一个大根堆（可并堆），存放其子树中尚未被取走的节点编号；从深到浅（自底向上）处理，每个商店贪心地取走其子树堆中剩余的最大编号。

用 priority_queue 做启发式合并（每次把较小的堆并入较大的堆），可以拿到 60 分：

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
// Capacity of the per-node tables below (and, doubled, of the edge table).
// Presumably sized for the problem's limit of roughly 3*10^6 nodes/shops — limits not shown here, confirm.
#define MAXN 3000010
using namespace std;
// Reads the next integer from stdin.
// Characters before the first digit are skipped; a '-' among them makes the result negative.
// The character that ends the number is consumed.
// Returns 0 if the input ends before any digit is found, instead of spinning forever on EOF.
inline int read()
{
    int x=0,f=1;
    int ch=getchar(); // int, not char: EOF must stay distinguishable from byte 0xFF and from unsigned-char platforms
    while(ch<'0'||ch>'9')
    {
        if(ch==EOF) return 0;
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=getchar();}
    return x*f;
}
// n: number of nodes, m: number of shops read from the input.
int n,m;
// t: edge slots handed out so far by add(); tot: heap slots handed out so far by solve().
int t,tot;
// head[x]: first adjacency entry of x (0 = none); id[x]: heap slot currently owned by x;
// fa[x]: parent of x (1-indexed); c[x]: number of shops located at x.
int head[MAXN],id[MAXN],fa[MAXN],c[MAXN];
// Sum of the values popped for shops.
long long ans;
// One adjacency-list entry.
struct Edge
{
    int nxt; // next entry of the same list (0 ends the list)
    int to;  // neighbouring node
};
Edge edge[MAXN<<1];
// q[k]: max-heap of still-unclaimed values for heap slot k.
priority_queue<int> q[MAXN];
// Prepends a directed edge from -> to onto from's adjacency list.
// Slots are numbered from 1 so that 0 can mark the end of a list.
inline void add(int from,int to)
{
    const int slot=++t;
    edge[slot].to=to;
    edge[slot].nxt=head[from];
    head[from]=slot;
}
inline void solve(int x,int pre)
{
    id[x]=++tot;
    q[tot].push(x-1);
    for(int i=head[x];i;i=edge[i].nxt)
    {
        int v=edge[i].to;
        if(v==pre) continue;
        solve(v,x);
        if(q[id[x]].size()<q[id[v]].size()) swap(id[x],id[v]);
        while(!q[id[v]].empty())
        {
            int cur=q[id[v]].top();q[id[v]].pop();
            q[id[x]].push(cur);
        }
    }
    for(int i=1;i<=c[x];i++)
    {
        ans+=q[id[x]].top();
        q[id[x]].pop();
        if(q[id[x]].empty()) break;
    }
}
// Reads n, m, the 0-indexed parent of every node 2..n and the 0-indexed position of every shop.
// It then builds the undirected tree, runs the heap greedy from the root (node 1) and prints the total.
int main()
{
    // Local runs read from ce.in; judges that define ONLINE_JUDGE keep the real stdin.
    #ifndef ONLINE_JUDGE
    freopen("ce.in","r",stdin);
    #endif
    n=read();
    m=read();
    // Shift every label by one so nodes are 1..n.
    for(int v=2;v<=n;++v)
    {
        fa[v]=read()+1;
        add(fa[v],v);
        add(v,fa[v]);
    }
    for(int k=0;k<m;++k)
    {
        c[read()+1]++;
    }
    solve(1,0);
    printf("%lld\n",ans);
    return 0;
}

改用并查集维护则可以 AC：按编号从大到小枚举节点，把它分配给离它最近、且还有剩余名额的祖先（含自身）商店；某个商店名额用完后，用并查集把它"跳过"，指向它上方的可用商店：

#include<iostream>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<algorithm>
#define MAXN 3000010
// n: number of nodes, m: number of shops.
int n,m;
// cnt[x]: shops at node x that have not taken a node yet (nodes are 1-indexed).
// fa[x]:  parent of x, shifted to 1-indexed. fa[1] is left 0, a sentinel above the root.
// f[x]:   union-find link. As maintained by main(), find(x) yields the nearest
//         ancestor-or-self of x that still has a free shop, or 0 if there is none.
int cnt[MAXN],fa[MAXN],f[MAXN];
// Sum of the 0-based labels taken by all shops.
long long ans=0;
// Union-find lookup with full path compression.
// It follows f[] links from x to the node that links to itself, then repoints every node
// on the walked path directly at that node.
// It is iterative because the initial f[] chain can be as long as the tree is tall
// (up to about n nodes), and a recursive find of that depth can overflow the call stack.
inline int find(int x)
{
    int root=x;
    while(f[root]!=root) root=f[root];
    while(x!=root)
    {
        int next=f[x];
        f[x]=root;
        x=next;
    }
    return root;
}
// Reads the tree and the shop positions, then hands out node labels greedily.
// Nodes are visited from n (label n-1) down to 1. Each label goes to the nearest
// ancestor-or-self that still has a shop without a purchase.
// Prints the total of the assigned labels.
int main()
{
    // Local runs read from ce.in; judges that define ONLINE_JUDGE keep the real stdin.
    #ifndef ONLINE_JUDGE
    freopen("ce.in","r",stdin);
    #endif
    scanf("%d%d",&n,&m);
    // Input is 0-indexed. Shift parents by one so node 1 is the root and fa[1] stays 0.
    for(int v=2;v<=n;++v)
    {
        scanf("%d",&fa[v]);
        ++fa[v];
    }
    for(int k=0;k<m;++k)
    {
        int pos;
        scanf("%d",&pos);
        ++cnt[pos+1];
    }
    // A node that has shops represents itself; any other node defers to its parent.
    for(int v=1;v<=n;++v)
    {
        if(cnt[v]) f[v]=v;
        else f[v]=fa[v];
    }
    for(int v=n;v>0;--v)
    {
        const int owner=find(v);
        if(!cnt[owner]) continue; // no shop with spare capacity on the path to the root
        ans+=v-1;
        if(--cnt[owner]==0) f[owner]=find(fa[owner]); // owner is exhausted: route later lookups past it
    }
    printf("%lld\n",ans);
    return 0;
}
posted @ 2019-07-10 14:49  风浔凌  阅读(177)  评论(0)  编辑  收藏  举报