[HDOJ5877]Weak Pair(DFS,线段树,离散化)
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5877
题意:给一棵树和各点的权值a,求点对(u,v)个数,满足:1.u是v的祖先,2.a(u)*a(v)<=k。
对于这棵树,我们先存好树的结构。再离散化,最后dfs的时候往线段树里插点,那对应idx的值就是1。然后二分找不大于k/a[v]的下标,线段树统计计数就行了。换儿子的时候记得抹去上一个兄弟。
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 #define lrt rt << 1 5 #define rrt rt << 1 | 1 6 typedef long long LL; 7 const int maxn = 200100; 8 int n, rt; 9 int in[maxn]; 10 LL k, ret, a[maxn]; 11 vector<int> G[maxn]; 12 LL h[maxn]; 13 int hcnt; 14 LL sum[maxn<<3]; 15 16 inline int getid(LL x) { 17 return lower_bound(h, h+hcnt, x) - h + 1; 18 } 19 20 void pushUP(int rt) { 21 sum[rt] = sum[lrt] + sum[rrt]; 22 } 23 24 void build(int l, int r, int rt) { 25 sum[rt] = 0; 26 if(l == r) return; 27 int mid = (l + r) >> 1; 28 build(l, mid, lrt); 29 build(mid+1, r, rrt); 30 pushUP(rt); 31 } 32 33 void update(int l, int r, int rt, int pos, LL val) { 34 if(l == r) { 35 sum[rt] += val; 36 return; 37 } 38 int mid = (l + r) >> 1; 39 if(pos <= mid) update(l, mid, lrt, pos, val); 40 else update(mid+1, r, rrt, pos, val); 41 pushUP(rt); 42 } 43 44 LL query(int L, int R, int l, int r, int rt) { 45 if(l >= L && R >= r) return sum[rt]; 46 int mid = (l + r) >> 1; 47 LL ret = 0; 48 if(L <= mid) ret += query(L, R, l, mid, lrt); 49 if(mid < R) ret += query(L, R, mid+1, r, rrt); 50 return ret; 51 } 52 53 void dfs(int u) { 54 int uu = getid(a[u]); 55 int vv = getid(k/a[u]); 56 ret += query(1, vv, 1, hcnt, 1); 57 update(1, hcnt, 1, uu, 1); 58 for(int i = 0; i < G[u].size(); i++) dfs(G[u][i]); 59 update(1, hcnt, 1, uu, -1); 60 } 61 62 int main() { 63 //freopen("in", "r", stdin); 64 int T, u, v; 65 scanf("%d", &T); 66 while(T--) { 67 scanf("%d %I64d",&n,&k); 68 memset(in, 0, sizeof(in)); hcnt = 0; ret = 0; 69 for(int i = 1; i <= n; i++) { 70 scanf("%I64d", &a[i]); 71 G[i].clear(); 72 h[hcnt++] = a[i]; h[hcnt++] = k / a[i]; 73 } 74 sort(h, h+hcnt); hcnt = unique(h, h+hcnt) - h; 75 build(1, hcnt, 1); 76 for(int i = 0; i < n-1; i++) { 77 scanf("%d %d",&u,&v); 78 G[u].push_back(v); 79 in[v]++; 80 } 81 for(int i = 1; i <= n; i++) if(!in[i]) dfs(i); 82 printf("%I64d\n", ret); 83 } 84 return 0; 85 }