Loading

POJ1471 Tree/洛谷P4178 Tree

Tree
P4178 Tree

点分治板子。

点分治就是直接找树的重心进行暴力计算,每次树的深度不会超过子树深度的\(\frac{1}{2}\),计算完就消除影响,找下一个重心。

所以伪代码:

void solve(int u)
{
    calc(u);
    used[u]=true;
    for(int i=head[u];i;i=e[i].nxt)
    {
        int v=e[i].to;
        if(!used[v])
        {
            getroot(v)
            solve(root);
        }
    }
}

calc因题而异,主要靠思维。

这两题仅数据范围不同,这里放POJ的代码。

用个值域树状数组可以快速计算出距离不超过一个数的路径个数。

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int N=40010;
const int inf=10000007;
struct edge {
    int to,nxt,val;
} e[N<<1];
int head[N],num_edge,rt,k,ans,a[N],b[N],d[N],mn,t[inf],n;
bool used[N];
int max(const int &a,const int &b){return a>b?a:b;}
inline void add(int from,int to,int val) {
    ++num_edge;
    e[num_edge].nxt=head[from];
    e[num_edge].val=val;
    e[num_edge].to=to;
    head[from]=num_edge;
}
#define lt(x) (x&(-x))
void add(int i,int x) {
    if(i<=0)return;
    while(i<=k) {
        t[i]+=x;
        i+=lt(i);
    }
}
int ask(int i) {
    if(i<=0)return 0;
    int res=0;
    if(i>k)i=k;
    while(i) {
        res+=t[i];
        i-=lt(i);
    }
    return res;
}
int mx[N],size[N],sum;
void getrt(int u,int fa)
{
	mx[u]=0,size[u]=1;
	for(int i=head[u];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa||used[v])continue;
		getrt(v,u);
		size[u]+=size[v];
		mx[u]=max(mx[u],size[v]);
	}	
	mx[u]=max(sum-mx[u],mx[u]);
	if(mx[u]<mx[rt])rt=u;
}
void getdis(int u,int fa,int dis) {
    if(dis>k)return;
    a[++a[0]]=dis;b[++b[0]]=dis;
    for(int i=head[u]; i; i=e[i].nxt) {
        int v=e[i].to;
        if(v==fa||used[v])continue;
        getdis(v,u,dis+e[i].val);
    }
}
void calc(int u) {
	b[0]=0;
    for(int i=head[u]; i; i=e[i].nxt) {
        int v=e[i].to;
        if(used[v])continue;
        a[0]=0;getdis(v,u,e[i].val);
        for(int j=1; j<=a[0]; ++j) {
            if(a[j]>k)continue;
            ans+=ask(k-a[j]);
        }
        for(int j=1; j<=a[0]; ++j) {
            if(a[j]>k)continue;
            add(a[j],1);
            ++ans;
        }
    }
    for(int i=1; i<=b[0]; ++i) {
        if(b[i]>k)continue;
        add(b[i],-1);
    }
}
void solve(int u) {
    used[u]=true,calc(u);
    for(int i=head[u]; i; i=e[i].nxt) {
        int v=e[i].to;
        if(!used[v])
        {
        	rt=0;
        	sum=size[v];
        	getrt(v,u);
        	solve(rt);
		}
    }
}
void clear()
{
    num_edge=0;ans=0;
    memset(head,0,sizeof(head));
    memset(used,false,sizeof(used));
}
int main() {
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        if(n==0&&k==0)return 0;
        clear();
        for(int i=1,x,y,z; i<n; ++i) {
            scanf("%d%d%d",&x,&y,&z);
            add(x,y,z);
            add(y,x,z);
        }
        sum=mx[rt=0]=n;
        getrt(1,0);
        solve(rt);
        printf("%d\n",ans);
    }
}
posted @ 2020-02-23 17:37  zzctommy  阅读(198)  评论(0编辑  收藏  举报