HDOJ5877 (DFS order + discretization + binary indexed tree)

Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 2081    Accepted Submission(s): 643


Problem Description
You are given a rooted tree of N nodes, labeled from 1 to N. To the i-th node a non-negative value a_i is assigned. An ordered pair of nodes (u,v) is said to be weak if
  (1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
  (2) a_u × a_v ≤ k.

Can you find the number of weak pairs in the tree?
 

 

Input
There are multiple cases in the data set.
  The first line of input contains an integer T denoting number of test cases.
  For each case, the first line contains two space-separated integers, N and k, respectively.
  The second line contains N space-separated integers, denoting a1 to aN.
  Each of the subsequent N-1 lines contains two space-separated integers defining an edge connecting nodes u and v, where node u is the parent of node v.

  Constraints:
  1 ≤ N ≤ 10^5
  0 ≤ a_i ≤ 10^9
  0 ≤ k ≤ 10^18
 

 

Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
 

 

Sample Input
1
2 3
1 2
1 2
 
Sample Output
1
 
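Approach: rewrite the condition a_u × a_v ≤ k as a_u ≤ k / a_v (integer division, valid for a_v > 0; when a_v = 0 every u qualifies). While traversing the tree, for each node v count the nodes u with a_u ≤ k / a_v. The values are discretized and stored in a binary indexed tree, so each count is a prefix-sum query.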
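The rewritten inequality can be sketched as a small helper. This is only an illustration: the names maxPartner and isWeakValue are hypothetical and do not appear in the solution. It assumes non-negative inputs within the stated limits, and it treats a value of 0 as having no upper bound on its partner, since the product is then 0.

```cpp
#include <cassert>
#include <limits>
typedef long long LL;

// Hypothetical helper: largest x such that a * x <= k, for a >= 0 and k >= 0.
// When a == 0 the product is always 0, so every x qualifies.
LL maxPartner(LL a, LL k)
{
    assert(a >= 0 && k >= 0);
    if (a == 0) return std::numeric_limits<LL>::max();
    return k / a;   // floor division: a * x <= k  <=>  x <= k / a
}

// Hypothetical helper: decides x * a <= k without multiplying.
bool isWeakValue(LL x, LL a, LL k)
{
    return x <= maxPartner(a, k);
}
```

With the sample input (k = 3, values 1 and 2), maxPartner(2, 3) is 1, so the value 1 qualifies as a partner of 2 and the value 2 would not.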
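#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAXN = 100005;
// Sentinel threshold for a value of 0: the product is 0 <= k for every partner.
const LL INF = 4000000000000000000LL;
int n, bit[MAXN+MAXN], deg[MAXN], len;
LL val[MAXN], k;
vector<int> arc[MAXN];   // arc[u] holds the children of u
LL res;
LL buf[MAXN+MAXN];       // sorted pool of all values and all thresholds
// Largest x with val[u] * x <= k; avoids dividing by zero when val[u] == 0.
LL limit(int u)
{
    return val[u] == 0 ? INF : k / val[u];
}
// Binary indexed tree: add x at position i (1-based).
void add(int i, int x)
{
    while(i < MAXN + MAXN)
    {
        bit[i] += x;
        i += (i & (-i));
    }
}
// Binary indexed tree: sum of positions 1..i.
int sum(int i)
{
    int s = 0;
    while(i > 0)
    {
        s += bit[i];
        i -= (i & (-i));
    }
    return s;
}
// Counts the descendants v of u with val[v] <= limit(u), then inserts val[u].
void dfs(int u)
{
    // Rank of the threshold; every value <= limit(u) has a rank <= id.
    int id = lower_bound(buf, buf + len, limit(u)) - buf + 1;
    int pre = sum(id);
    for(int i = 0, size = arc[u].size(); i < size; i++)
    {
        dfs(arc[u][i]);
    }
    int post = sum(id);
    // Only nodes in the subtree of u were inserted between pre and post.
    res += (post - pre);

    int index = lower_bound(buf, buf + len, val[u]) - buf + 1;
    add(index, 1);
}
int main()
{
  //  freopen("input.in", "r", stdin);
    int T;
    scanf("%d", &T);
    while(T--)
    {
        // HDOJ's judge has traditionally required %I64d; use %lld with other toolchains.
        scanf("%d %I64d", &n, &k);
        memset(bit, 0, sizeof(bit));
        memset(deg, 0, sizeof(deg));
        len = 0;
        res = 0;
        for(int i = 1; i <= n; i++) arc[i].clear();
        for(int i = 1; i <= n; i++)
        {
            scanf("%I64d", &val[i]);
            // Discretize both the values and the thresholds they are compared against.
            buf[len++] = val[i];
            buf[len++] = limit(i);
        }
        for(int i = 0; i < n - 1; i++)
        {
            int u, v;
            scanf("%d %d", &u, &v);
            arc[u].push_back(v);
            deg[v]++;   // in-degree; only the root stays at 0
        }
        sort(buf, buf + len);
        // The root is the only node that never appears as a child.
        for(int i = 1; i <= n; i++)
        {
            if(deg[i] == 0)
            {
                dfs(i);
                break;
            }
        }
        printf("%I64d\n", res);
    }
    return 0;
}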
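To sanity-check the solution on small inputs, the definition of a weak pair can also be applied directly. The sketch below is a hypothetical reference counter, not part of the submitted code. It assumes a 1-indexed parent array with 0 marking the root, and it walks from every node up to the root, so it is far too slow for the full limits on deep trees.

```cpp
#include <cassert>
#include <vector>
typedef long long LL;

// Hypothetical brute-force counter. parent[v] is the parent of v (0 for the root),
// a[v] is the value of node v; both vectors are 1-indexed and index 0 is unused.
LL countWeakPairsBrute(const std::vector<int>& parent,
                       const std::vector<LL>& a, LL k)
{
    assert(parent.size() == a.size());
    LL pairs = 0;
    for (int v = 1; v < (int)a.size(); ++v)
    {
        // Walk from the parent of v up to the root; each node seen is a proper ancestor.
        for (int u = parent[v]; u != 0; u = parent[u])
        {
            // Under the stated limits a[u] * a[v] <= 10^18, which fits in long long.
            if (a[u] * a[v] <= k) ++pairs;
        }
    }
    return pairs;
}
```

For the sample tree (node 1 is the parent of node 2, values 1 and 2, k = 3) this counts the single pair (1, 2).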

 

 

 

posted on 2016-09-15 10:39  vCoders