[HDOJ5877]Weak Pair(DFS,线段树,离散化)

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5877

题意:给一棵树和各点的权值a,求点对(u,v)个数,满足:1.u是v的祖先,2.a(u)*a(v)<=k。

对于这棵树,我们先存好树的结构。再离散化,最后dfs的时候往线段树里插点,那对应idx的值就是1。然后二分找不大于k/a[v]的下标,线段树统计计数就行了。换儿子的时候记得抹去上一个兄弟。

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 #define lrt rt << 1
 5 #define rrt rt << 1 | 1
 6 typedef long long LL;
 7 const int maxn = 200100;
 8 int n, rt;
 9 int in[maxn];
10 LL k, ret, a[maxn];
11 vector<int> G[maxn];
12 LL h[maxn];
13 int hcnt;
14 LL sum[maxn<<3];
15 
16 inline int getid(LL x) {
17   return lower_bound(h, h+hcnt, x) - h + 1;
18 }
19 
20 void pushUP(int rt) {
21   sum[rt] = sum[lrt] + sum[rrt];
22 }
23 
24 void build(int l, int r, int rt) {
25   sum[rt] = 0;
26   if(l == r) return;
27   int mid = (l + r) >> 1;
28   build(l, mid, lrt);
29   build(mid+1, r, rrt);
30   pushUP(rt);
31 }
32 
33 void update(int l, int r, int rt, int pos, LL val) {
34   if(l == r) {
35     sum[rt] += val;
36     return;
37   }
38   int mid = (l + r) >> 1;
39   if(pos <= mid) update(l, mid, lrt, pos, val);
40   else update(mid+1, r, rrt, pos, val);
41   pushUP(rt);
42 }
43 
44 LL query(int L, int R, int l, int r, int rt) {
45   if(l >= L && R >= r) return sum[rt];
46   int mid = (l + r) >> 1;
47   LL ret = 0;
48   if(L <= mid) ret += query(L, R, l, mid, lrt);
49   if(mid < R) ret += query(L, R, mid+1, r, rrt);
50   return ret;
51 }
52 
53 void dfs(int u) {
54   int uu = getid(a[u]);
55   int vv = getid(k/a[u]);
56   ret += query(1, vv, 1, hcnt, 1);
57   update(1, hcnt, 1, uu, 1);
58   for(int i = 0; i < G[u].size(); i++) dfs(G[u][i]);
59   update(1, hcnt, 1, uu, -1);
60 }
61 
62 int main() {
63   //freopen("in", "r", stdin);
64   int T, u, v;
65   scanf("%d", &T);
66   while(T--) {
67     scanf("%d %I64d",&n,&k);
68     memset(in, 0, sizeof(in)); hcnt = 0; ret = 0;
69     for(int i = 1; i <= n; i++) {
70       scanf("%I64d", &a[i]);
71       G[i].clear();
72       h[hcnt++] = a[i]; h[hcnt++] = k / a[i];
73     }
74     sort(h, h+hcnt); hcnt = unique(h, h+hcnt) - h;
75     build(1, hcnt, 1);
76     for(int i = 0; i < n-1; i++) {
77       scanf("%d %d",&u,&v);
78       G[u].push_back(v);
79       in[v]++;
80     }
81     for(int i = 1; i <= n; i++) if(!in[i]) dfs(i);
82     printf("%I64d\n", ret);
83   }
84   return 0;
85 }

 

posted @ 2016-10-20 14:39  Kirai  阅读(143)  评论(0编辑  收藏  举报