POJ1471 Tree/洛谷P4178 Tree
点分治板子。
点分治就是直接找树的重心进行暴力计算,每次树的深度不会超过子树深度的\(\frac{1}{2}\),计算完就消除影响,找下一个重心。
所以伪代码:
void solve(int u)
{
calc(u);
used[u]=true;
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(!used[v])
{
getroot(v)
solve(root);
}
}
}
calc因题而异,主要靠思维。
这两题仅数据范围不同,这里放POJ的代码。
用个值域树状数组可以快速计算出距离不超过一个数的路径个数。
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int N=40010;
const int inf=10000007;
struct edge {
int to,nxt,val;
} e[N<<1];
int head[N],num_edge,rt,k,ans,a[N],b[N],d[N],mn,t[inf],n;
bool used[N];
int max(const int &a,const int &b){return a>b?a:b;}
inline void add(int from,int to,int val) {
++num_edge;
e[num_edge].nxt=head[from];
e[num_edge].val=val;
e[num_edge].to=to;
head[from]=num_edge;
}
#define lt(x) (x&(-x))
void add(int i,int x) {
if(i<=0)return;
while(i<=k) {
t[i]+=x;
i+=lt(i);
}
}
int ask(int i) {
if(i<=0)return 0;
int res=0;
if(i>k)i=k;
while(i) {
res+=t[i];
i-=lt(i);
}
return res;
}
int mx[N],size[N],sum;
void getrt(int u,int fa)
{
mx[u]=0,size[u]=1;
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa||used[v])continue;
getrt(v,u);
size[u]+=size[v];
mx[u]=max(mx[u],size[v]);
}
mx[u]=max(sum-mx[u],mx[u]);
if(mx[u]<mx[rt])rt=u;
}
void getdis(int u,int fa,int dis) {
if(dis>k)return;
a[++a[0]]=dis;b[++b[0]]=dis;
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(v==fa||used[v])continue;
getdis(v,u,dis+e[i].val);
}
}
void calc(int u) {
b[0]=0;
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(used[v])continue;
a[0]=0;getdis(v,u,e[i].val);
for(int j=1; j<=a[0]; ++j) {
if(a[j]>k)continue;
ans+=ask(k-a[j]);
}
for(int j=1; j<=a[0]; ++j) {
if(a[j]>k)continue;
add(a[j],1);
++ans;
}
}
for(int i=1; i<=b[0]; ++i) {
if(b[i]>k)continue;
add(b[i],-1);
}
}
void solve(int u) {
used[u]=true,calc(u);
for(int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to;
if(!used[v])
{
rt=0;
sum=size[v];
getrt(v,u);
solve(rt);
}
}
}
void clear()
{
num_edge=0;ans=0;
memset(head,0,sizeof(head));
memset(used,false,sizeof(used));
}
int main() {
while(scanf("%d%d",&n,&k)!=EOF)
{
if(n==0&&k==0)return 0;
clear();
for(int i=1,x,y,z; i<n; ++i) {
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
sum=mx[rt=0]=n;
getrt(1,0);
solve(rt);
printf("%d\n",ans);
}
}
路漫漫其修远兮,吾将上下而求索