POJ 1741 Tree 点分治
题意:给定一棵带边权的树,求有多少对点之间的距离 <= k。
先吐槽一下这道题:题目中给出了边长 l 和点数 n 的范围,我还以为 k 的范围也小于 1001,结果 k 的范围并没有给出。我直接写了一个树状数组,结果疯狂 RE(运行时错误)……
题解:很裸的点分治。
1.找重心。
2. 统计经过重心且距离 <= k 的点对数:先求出当前连通块内所有点到重心的距离并排序,用双指针统计距离之和 <= k 的点对;再对重心的每棵子树做同样的统计并减去(这些点对的真实路径并不经过重心)。
3.递归处理所有子树。
代码:
// POJ 1741 "Tree": count unordered vertex pairs whose path length is <= k,
// using centroid decomposition (point divide-and-conquer).
//
// Input: several test cases, each "n k" followed by n-1 lines "u v w"
// (an undirected edge u-v of weight w, vertices numbered 1..n); the input
// ends with "0 0". Output: one line per test case with the pair count.
#include<cstdio>
#include<algorithm>
#include<vector>
#include<queue>
#include<iostream>
#include<cstring>

const int inf = 0x3f3f3f3f;  // sentinel larger than any component size
const int N = 1e5 + 100;     // capacity: maximum vertex count plus slack

int sz[N];   // subtree sizes computed by the most recent traversal
int vis[N];  // vis[u] == 1 once u has been used as a centroid (removed)

// Adjacency list in "forward star" form: head[u] is u's first edge index
// (-1 = none); edge e leads to to[e] with weight ct[e], and nt[e] is the next
// edge leaving the same vertex.
int head[N], to[N << 1], ct[N << 1], nt[N << 1], tot;
int n, m;       // vertex count and the distance bound k
long long ans;  // pair count; n*(n-1)/2 exceeds int range for large n

// Appends the directed edge u -> v with weight w.
void add(int u, int v, int w) {
    to[tot] = v;
    ct[tot] = w;
    nt[tot] = head[u];
    head[u] = tot++;
}

int rtsz, rt;  // smallest "largest remaining part" seen so far, and its vertex

// Searches the component containing u (num vertices, removed vertices
// excluded) for its centroid, writing the result to rt. o is u's parent in
// this traversal (0 = none). The caller must reset rtsz to inf beforehand.
void get_rt(int o, int u, int num) {
    sz[u] = 1;
    int mxnum = 0;  // size of the largest part left if u were removed
    for (int i = head[u]; ~i; i = nt[i]) {
        const int v = to[i];
        if (vis[v] || v == o) continue;
        get_rt(u, v, num);
        sz[u] += sz[v];
        mxnum = std::max(mxnum, sz[v]);
    }
    if (o) mxnum = std::max(mxnum, num - sz[u]);  // the part "above" u
    if (mxnum < rtsz) {
        rtsz = mxnum;
        rt = u;
    }
}

long long d[N];  // distances collected by dfs, stored in d[1..dcnt]
int dcnt;

// Appends to d the distance of every vertex reachable from u without passing
// through parent o or a removed vertex, where u itself is at distance w.
// Also recomputes sz[] for the visited subtree.
void dfs(int o, int u, long long w) {
    d[++dcnt] = w;
    sz[u] = 1;
    for (int i = head[u]; ~i; i = nt[i]) {
        const int v = to[i];
        if (vis[v] || v == o) continue;
        dfs(u, v, w + ct[i]);
        sz[u] += sz[v];
    }
}

// Sorts d[1..dcnt] and returns the number of index pairs l < r with
// d[l] + d[r] <= m, using a two-pointer sweep.
long long cal() {
    std::sort(d + 1, d + 1 + dcnt);
    long long ret = 0;
    int l = 1, r = dcnt;
    while (l < r) {
        if (d[l] + d[r] <= m) {
            ret += r - l;  // d[l] pairs with every one of d[l+1..r]
            ++l;
        } else {
            --r;
        }
    }
    return ret;
}

// Adds to ans the number of valid pairs inside the component containing u,
// which has num vertices, then recurses into the pieces left after removing
// that component's centroid.
void solve(int u, int num) {
    if (num <= 1) return;
    rtsz = inf;
    get_rt(0, u, num);
    const int root = rt;  // rt is a global that recursive calls overwrite
    vis[root] = 1;

    // Pairs whose distance measured through root is <= k. This also counts
    // pairs lying in the same child subtree, whose real path avoids root...
    dcnt = 0;
    dfs(0, root, 0);
    ans += cal();

    // ...so subtract those per child (inclusion-exclusion). These dfs calls
    // also leave sz[v] equal to the size of child v's component.
    for (int i = head[root]; ~i; i = nt[i]) {
        const int v = to[i];
        if (vis[v]) continue;
        dcnt = 0;
        dfs(0, v, ct[i]);
        ans -= cal();
    }

    for (int i = head[root]; ~i; i = nt[i]) {
        const int v = to[i];
        if (vis[v]) continue;
        solve(v, sz[v]);
    }
}

int main() {
    int u, v, w;
    // Stop on "0 0", on EOF, or on malformed input (scanf != 2).
    while (std::scanf("%d%d", &n, &m) == 2 && n + m != 0) {
        tot = 0;
        ans = 0;
        for (int i = 1; i <= n; i++) {
            vis[i] = 0;
            head[i] = -1;
        }
        for (int i = 1; i < n; i++) {
            std::scanf("%d%d%d", &u, &v, &w);
            add(u, v, w);
            add(v, u, w);
        }
        solve(1, n);
        std::printf("%lld\n", ans);
    }
    return 0;
}