Split The Tree(2018东北四省赛)

题意:给你一棵树,树上每个节点都有一个权值,问将一条边断开,形成两棵树,每棵树上权值不同的数量 的和 , 让这个和尽量大

 

很显然这是一道求区间数字种类数的问题,这道题给的是一棵树,不是区间,因此我们首先将树给区间化, 然而dfs序 就是干这种事情的,我们将树区间化后,就转化为求区间【L,R】中不同的种类数 和 除了【L,R】的不同种类数(不是这个区间的,我们可以给他在添加 n 个,类似于这样【1,2,3,4,1,2,3,4】),对于求区间数字的种类数,我们用树状数组,莫队,主席树等诸多方法就可以解决了

 

Code:

#include <bits/stdc++.h>
 
using namespace std;
const int N = 1e6 + 10;
vector<int>edge[N];
 
int st[N],ed[N],w[N],a[2*N],pos[N],last[2*N],head[2*N],c[2*N];
int tot;
 
int lowbit(int x){  return x&(-x);}
 
void update(int x,int y)
{
    while(x <= 2*tot)
    {
        c[x] += y;
        x += lowbit(x);
    }
}
 
int query(int x)
{
    int ans = 0;
    while(x)
    {
        ans += c[x];
        x -= lowbit(x);
    }
    return ans;
}
 
void dfs(int u,int pre)
{
    st[u] = ++tot;
    pos[tot] = u;
   // cout<<pre<<" "<<u<<endl;
    for(int i = 0;i < edge[u].size();i++)
    {
        int to = edge[u][i];
        if(pre != to)  dfs(to,u);
    }
    ed[u] = tot;
}
 
struct Node{
    int l,r,id;
    int ans;
}node[2*N];
 
int cmp1(Node a,Node b)
{
    return a.l == b.l?a.r < b.r:a.l < b.l;
}
int cmp2(Node a,Node b)
{
    return a.id < b.id;
}
int main()
{
    int n;
    tot = 0;
    memset(c,0,sizeof(c));
    memset(head,0,sizeof(head));
    memset(last,0,sizeof(last));
    scanf("%d",&n);
    for(int i = 2;i <= n;i++)
    {
        int pre;
        scanf("%d",&pre);
        edge[pre].push_back(i);
        edge[i].push_back(pre);
    }
    for(int i = 1;i <= n;i++)
        scanf("%d",&w[i]);
 
    dfs(1,0);

  //树状数组解决区间中数字种类的个数
for(int i = 1;i <= tot;i++) a[i + tot] = a[i] = w[pos[i]]; for(int i = 2*tot;i >= 1;i--) last[i] = head[a[i]],head[a[i]] = i; int ant = 0; for(int i = 2;i <= n;i++) { node[ant].id = i;node[ant].l = st[i];node[ant].r = ed[i];ant++; node[ant].id = i;node[ant].l = ed[i] + 1;node[ant].r = tot + st[i] - 1;ant++; } for(int i = 1;i <= 100000;i++) if(head[i]) update(head[i],1); sort(node,node+ant,cmp1); int l = 1; for(int i = 0;i < ant;i++) { while(l < node[i].l) { if(last[l])update(last[l],1); l++; } node[i].ans=query(node[i].r)-query(node[i].l-1); } sort(node,node+ant,cmp2); int maxx = 0; for(int i = 0;i < 2*tot;i += 2) maxx = max(maxx,node[i].ans + node[i+1].ans); printf("%d\n",maxx); return 0; }

 

posted @ 2018-09-06 10:30  jadelemon  阅读(518)  评论(0编辑  收藏  举报