HDU5877 Weak Pair dfs + 线段树/树状数组 + 离散化

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5877


题意:

weak pair的要求:

1.u是v的祖先(注意不一定是父亲)

2.val[u]*val[k] <=k;



题解:

1.将val(以及k/val)离散化,可用map 或者 用数组。 只要能将val与树状数组或线段树的下标形成映射就可以了。

2.从根节点开始搜索(题目中的树,不是线段树或树状数组的树),先统计当前节点与祖先能形成多少对weak pair,然后将其插入到树状数组或线段树中。

3.递归其子树。递归完子树后,再把当前节点从树状数组或线段树中删去。因为:根据递归的特性,如果不删除,这个值将会残留在c数组, 那么他的堂兄弟,堂叔伯,堂侄子等(后一步 递归的)会误认为这个值是他的祖先的。所以要及时删除。


类似的题(边查询边更新):http://blog.csdn.net/dolfamingo/article/details/71001021



树状数组(map离散)

 

#include<cstdio>//hdu5877 树状数组 map离散 dfs
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<map>
#include<vector>
#define LL long long
#define INF 2e18

using namespace std;

LL sval[200020], val[200020];
int n,fa[100010],c[200010];
map<LL,int> m;
vector<int> son[100010];
LL k,ans;

int lowbit(int x)
{
    return x & (-x);
}

void add(int x, int d)
{
    for(; x<=2*n; x += lowbit(x))
    {
        c[x] += d;
    }
}

int sumc(int x)
{
    int s = 0;
    for(;x>0; x -= lowbit(x))
    {
        s += c[x];
    }
    return s;
}

void dfs(int rt)//c数组中的下标与val[i],k/val[i]映射 且k/v[i]的下标-i的下边等于n(自己定)
{
    ans += sumc(m[val[n+rt]]);//统计<=k/val[rt]的个数,为什么不直接 m[k/val[rt]]? 因为val[rt]可能为0
    add(m[val[rt]],1);
    for(int i = 0; i<son[rt].size(); i++)
    {
        dfs(son[rt][i]);
    }
    add(m[val[rt]],-1);
}

int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%lld",&n,&k);

        for(int i = 1; i<=n; i++)
        {
            //存k/val[i]是因为 val[i]*(k/val[i])<= k. 到时可以直接从树状数组中统计<=k/val[i]的个数
            scanf("%lld",&val[i]);
            if(val[i])
                val[n+i] = k/val[i];
            else            //0为特殊情况
                val[n+i] = INF;
        }

         //sval的作用是将v值按从小到大,一一与c数组的下标形成映射
        for(int i = 1; i<=2*n; i++)
            sval[i] = val[i];
        sort(sval+1,sval+2*n+1);

        int cnt = 0;
        m.clear();
        for(int i = 1; i<=2*n; i++)
        {
            //map的作用是将v值与c数组的下标形成映射
            if(!m[sval[i]]) m[sval[i]] = ++cnt;
        }

        for(int i = 1; i<=n; i++)
            fa[i] = 0, son[i].clear();
        for(int i = 1,u,v; i<n; i++)
        {
            scanf("%d%d",&u,&v);
            son[u].push_back(v);
            fa[v] = u;
        }

        ans = 0;
        memset(c,0,sizeof(c));
        for(int i = 1; i<=n; i++)
        {
            if(!fa[i])
            {
                dfs(i);
                break;
            }
        }
        printf("%lld\n",ans);

    }
    return 0;
}




线段树(map离散):

 

#include<cstdio>//hdu5877 线段树 map离散 dfs
#include<cstring>//注意区分题目的树和线段树的树
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
#define LL long long
#define INF 2e18

using namespace std;

int n,len;//n是题目给出的树的结点个数,len是线段树的线段长度。
LL k,ans;
LL val[200100],sval[200100];//val记录原始值,sval记录经过排序,删除重复操作的值,用于线段树的操作
int fa[100100],sum[800100];
vector<int>son[100100];
map<LL,int>m;

int query(int root, int le, int ri, int x, int y)
{
    if(x<=le && y>=ri)
        return sum[root];

    int mid = (le+ri)/2, ret = 0;
    if(x<=mid) ret += query(root*2,le,mid,x,y);
    if(y>=mid+1) ret += query(root*2+1,mid+1,ri,x,y);
    return ret;
}

void update(int root, int le, int ri, int pos, int d)
{
    if(le==ri)
    {
        sum[root] += d;
        return;
    }

    int mid = (le+ri)/2;
    if(pos<=mid) update(root*2,le,mid,pos,d);
    else update(root*2+1,mid+1,ri,pos,d);
    sum[root] = sum[root*2] + sum[root*2+1];
}

void dfs(int rt)
{
    int last = m[val[n+rt]];
    int pos = m[val[rt]];

    ans += query(1,1,len,1,last);

    update(1,1,len,pos,1);
    for(int i = 0; i<son[rt].size(); i++)
    {
        dfs(son[rt][i]);
    }
    update(1,1,len,pos,-1);
}

int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%lld",&n,&k);
        for(int i = 1; i<=n; i++)
        {
            scanf("%lld",&val[i]);
            if(val[i])
                val[n+i] = k/val[i];
            else
                val[n+i] = INF;
        }

        for(int i = 1; i<=2*n; i++)
            sval[i] = val[i];
        sort(sval+1,sval+2*n+1);

        m.clear();
        len = 0;
        for(int i = 1; i<=2*n; i++)
        {
            if(!m[sval[i]]) m[sval[i]] = ++len;
        }

        for(int i = 1; i<=n; i++)
            fa[i] = 0, son[i].clear();
        for(int i = 1,u,v; i<n; i++)
        {
            scanf("%d%d",&u,&v);
            son[u].push_back(v);
            fa[v] = u;
        }

        ans = 0;
        memset(sum,0,sizeof(sum));
        for(int i = 1; i<=n; i++)
        {
            if(!fa[i])
            {
                dfs(i);
                break;
            }
        }

        printf("%lld\n",ans);
    }
    return 0;
}


 


线段树(数组离散):

 

#include<cstdio>//hdu5877 线段树 dfs 普通数组进行离散  
#include<cstring>//注意区分题目的树和线段树的树  
#include<cstdlib>  
#include<cmath>  
#include<algorithm>  
#include<vector>  
#define LL long long  
#define INF 2e18  
using namespace std;  
  
int n,m;//n是题目给出的树的结点个数,m是线段树的线段长度。  
LL k,ans;  
LL val[200100],sval[200100];//val记录原始值,sval记录经过排序,删除重复操作的值,用于线段树的操作  
int fa[100100],sum[800100];  
vector<int>son[100100];  
  
int query(int root, int le, int ri, int x, int y)  
{  
    if(x<=le && y>=ri)  
        return sum[root];  
  
    int mid = (le+ri)/2, ret = 0;
    if(x<=mid) ret += query(root*2,le,mid,x,y);  
    if(y>=mid+1) ret += query(root*2+1,mid+1,ri,x,y);  
    return ret;
}  
  
void update(int root, int le, int ri, int pos, int d)  
{  
    if(le==ri)  
    {  
        sum[root] += d;  
        return;  
    }  
  
    int mid = (le+ri)/2;  
    if(pos<=mid) update(root*2,le,mid,pos,d);  
    else update(root*2+1,mid+1,ri,pos,d);  
    sum[root] = sum[root*2] + sum[root*2+1];  
}  
  
void dfs(int rt)  
{  
    int last = lower_bound(sval+1, sval+m+1,val[n+rt]) - sval;  
    int pos = lower_bound(sval+1,sval+m+1,val[rt])  - sval;  
  
    ans += query(1,1,m,1,last);  
  
    update(1,1,m,pos,1);  
    for(int i = 0; i<son[rt].size(); i++)  
    {  
        dfs(son[rt][i]);  
    }  
    update(1,1,m,pos,-1);  
}  
  
int main()  
{  
    int T;  
    scanf("%d",&T);  
    while(T--)  
    {  
        scanf("%d%lld",&n,&k);  
        for(int i = 1; i<=n; i++)  
        {  
            scanf("%lld",&val[i]);  
            if(val[i])  
                val[n+i] = k/val[i];  
            else  
                val[n+i] = INF;  
        }  
  
        for(int i = 1; i<=2*n; i++)  
            sval[i] = val[i];  
        sort(sval+1,sval+2*n+1);  
        m = unique(sval+1,sval+2*n+1) - (sval+1);  
  
        for(int i = 1; i<=n; i++)  
            fa[i] = 0, son[i].clear();  
        for(int i = 1,u,v; i<n; i++)  
        {  
            scanf("%d%d",&u,&v);  
            son[u].push_back(v);  
            fa[v] = u;  
        }  
  
        ans = 0;  
        memset(sum,0,sizeof(sum));  
        for(int i = 1; i<=n; i++)  
        {  
            if(!fa[i])  
            {  
                dfs(i);  
                break;  
            }  
        }  
  
        printf("%lld\n",ans);  
    }  
    return 0;  
}  


 

posted on 2017-03-24 20:20  h_z_cong  阅读(226)  评论(0编辑  收藏  举报

导航