POJ1741 Tree(点分治)
嘟嘟嘟
没错,这一道最经典的点分治模板题。
题意:求树上两点间距离\(\leqslant k\)的点对个数。
点分治这东西我好早就听说了,然后一两个月前也学了一下,不过只是刷了个模板,没往深处学。
对于这道题,就说说大概的步骤吧。
1.找重心:一遍\(dfs\)即可。
2.求出每一个子树中的点到重心的距离。并且记录这个点属于哪一棵子树。
3.把上述的点存下来,按距离从小到大排序。
4.统计答案。采用双指针,\(i\)从头开始,\(j\)从尾开始。这样每一个\(i\)对答案的贡献是\(j - i - num[point_i]\),\(num[point_i]\)表示的是\(i\)所在的子树有多少个。(为了减去属于相同子树的贡献)
5.递归到每一个子树中统计答案。
20.10.18更新:
上了大学打算打acm,然后又学了一遍点分治。
在统计答案时有一个更简单的做法:在每一次从重心开始往每一个子树dfs得到距离序列s后,我们把s用双指针扫一下,求出距离小于等于\(k\)的点的对数\(x\)。然后把\(x\)从答案中减去。最后再加上整个重心的距离序列即可。
这是利用类似容斥的思想,从而避免了记录上述\(num[i]\),还减少了代码难度。
先给出方法一的代码:
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 4e4 + 5;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n, k;
struct Edge
{
int nxt, to, w;
}e[maxn << 1];
int head[maxn], ecnt = -1;
void addEdge(int x, int y, int w)
{
e[++ecnt] = (Edge){head[x], y, w};
head[x] = ecnt;
}
bool out[maxn];
int Siz, siz[maxn], Max[maxn];
void dfs1(int now, int _f, int &cg)
{
siz[now] = 1; Max[now] = -1;
for(int i = head[now], v; i != -1; i = e[i].nxt)
{
if(!out[v = e[i].to] && v != _f)
{
dfs1(v, now, cg);
siz[now] += siz[v];
Max[now] = max(Max[now], siz[v]);
}
}
Max[now] = max(Max[now], Siz - siz[now]);
if(!cg || Max[now] < Max[cg]) cg = now;
}
struct Node
{
int dis, bel;
bool operator < (const Node& oth)const
{
return dis < oth.dis;
}
}a[maxn];
int cnt = 0;
void dfs2(int now, int _f, int dis, int x, int cg)
{
siz[now] = 1;
a[++cnt] = (Node){dis, x};
for(int i = head[now], v; i != -1; i = e[i].nxt)
{
if(!out[v = e[i].to] && v != _f)
{
dfs2(v, now, dis + e[i].w, now == cg ? v : x, cg);
siz[now] += siz[v];
}
}
}
int num[maxn], ans = 0;
void solve(int now)
{
int cg = 0; cnt = 0;
dfs1(now, 0, cg);
dfs2(cg, 0, 0, 0, cg);
sort(a + 1, a + cnt + 1);
for(int i = head[cg]; i != -1; i = e[i].nxt)
if(!out[e[i].to]) num[e[i].to] = 0;
for(int i = 1; i <= cnt; ++i) num[a[i].bel]++;
for(int i = 1, j = cnt; i <= j; ++i)
{
num[a[i].bel]--;
while(a[i].dis + a[j].dis > k && i <= j) num[a[j--].bel]--;
if(i > j) break;
ans += j - i - num[a[i].bel];
}
out[cg] = 1;
for(int i = head[cg], v; i != -1; i = e[i].nxt)
if(!out[v = e[i].to]) Siz = siz[v], solve(v);
}
int main()
{
Mem(head, -1);
n = read();
for(int i = 1; i < n; ++i)
{
int x = read(), y = read(), w = read();
addEdge(x, y, w); addEdge(y, x, w);
}
k = read();
ans = 0; Siz = n;
solve(1);
write(ans), enter;
return 0;
}
然后是方法2的主要代码
int a[maxn], cnt = 0;
In void dfs2(int now, int _f, int d)
{
if(d > K) return;
a[++cnt] = d;
forE(i, now, v) if(v != _f && !out[v]) dfs2(v, now, d + e[i].w);
}
int ans;
In int calc(int* a, int cnt)
{
int ret = 0;
for(int i = 1, j = cnt; i < j; ++ i)
{
while(i < j && a[i] + a[j] > K) --j;
ret += j - i;
}
return ret;
}
int st[maxn], top = 0;
In void solve(int now)
{
cg = 0; st[top = 1] = 0;
dfs1(now, 0, cg); //求重心cg
forE(i, cg, x)
{
if(out[x]) continue;
cnt = 0, dfs2(x, cg, e[i].w);
sort(a + 1, a + cnt + 1);
ans -= calc(a, cnt);
for(int j = 1; j <= cnt; ++j) st[++top] = a[j];
}
sort(st + 1, st + top + 1);
ans += calc(st, top);
out[cg] = 1;
forE(i, cg, x) if(!out[x]) Siz = siz[x], solve(x);
}