// POJ 1987 Distance Statistics
// http://poj.org/problem?id=1987
// Problem: given a weighted tree, count the unordered pairs of nodes whose distance is at most K.
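// Approach: centroid decomposition. For each centroid, collect the distance of every node in its component, sort the distances, and count pairs whose sum is <= K with a single two-pointer scan (never pairing a node with itself). Pairs whose two nodes lie in the same child subtree do not really pass through the centroid, so they are subtracted here and counted correctly at a deeper level. Complexity: O(n log^2 n).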
#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>

typedef long long ll;

// Capacity: a tree with MAXN-1 nodes has 2*(MAXN-2) directed edges, which fit
// in MAXE. This supports at least as many nodes as the original 1e6-edge arrays.
const int MAXN = 500005;
const int MAXE = 2 * MAXN;

// Adjacency list with 1-based edge ids; edge id 0 terminates a list.
int head[MAXN];
int nxt[MAXE];
int to[MAXE];
int len[MAXE];
int edgeCount;

bool removed[MAXN];  // node has already been used as a centroid
int  que[MAXN];      // BFS queue, reused by every traversal
int  par[MAXN];      // BFS parent of each queued node (0 = none; nodes are 1-based)
int  subSize[MAXN];  // subtree size inside the BFS tree of the current component
int  heaviest[MAXN]; // size of the largest child subtree inside that BFS tree
ll   dist[MAXN];     // dist[i] = distance of que[i] from the current centroid

int K;      // distance threshold read from the input
ll  answer; // number of unordered pairs at distance <= K (can exceed int range)

// Appends a directed edge u -> v of length w.
void addEdge(int u, int v, int w) {
    ++edgeCount;
    to[edgeCount] = v;
    len[edgeCount] = w;
    nxt[edgeCount] = head[u];
    head[u] = edgeCount;
}

// Adds the undirected edge {u, v} of length w.
void addUndirected(int u, int v, int w) {
    addEdge(u, v, w);
    addEdge(v, u, w);
}

// BFS from `start` over nodes not yet removed, where `start` is at distance
// `startDist` from the current centroid. Fills que[0..count), par[] and
// dist[0..count) (dist indexed by queue position) and returns count.
// Iterative on purpose: recursion could be O(n) deep on path-shaped trees.
int collectComponent(int start, ll startDist) {
    int qh = 0, qt = 0;
    que[qt] = start;
    dist[qt] = startDist;
    par[start] = 0;
    ++qt;
    while (qh < qt) {
        const int u = que[qh];
        const ll du = dist[qh];
        ++qh;
        for (int e = head[u]; e; e = nxt[e]) {
            const int v = to[e];
            if (v == par[u] || removed[v]) continue;
            par[v] = u;
            que[qt] = v;
            dist[qt] = du + len[e];
            ++qt;
        }
    }
    return qt;
}

// Returns a centroid of the component of non-removed nodes containing `start`:
// the node minimising the size of the largest piece left after removing it.
int findCentroid(int start) {
    const int total = collectComponent(start, 0);
    for (int i = 0; i < total; ++i) {
        subSize[que[i]] = 1;
        heaviest[que[i]] = 0;
    }
    // Reverse BFS order visits every child before its parent.
    for (int i = total - 1; i > 0; --i) {
        const int u = que[i];
        const int p = par[u];
        subSize[p] += subSize[u];
        heaviest[p] = std::max(heaviest[p], subSize[u]);
    }
    int best = start;
    int bestScore = total;  // every node scores at most total-1, so best is always set
    for (int i = 0; i < total; ++i) {
        const int u = que[i];
        // The part "above" u in the BFS tree has total - subSize[u] nodes.
        const int score = std::max(heaviest[u], total - subSize[u]);
        if (score < bestScore) {
            bestScore = score;
            best = u;
        }
    }
    return best;
}

// Collects the distances reachable from `start` (see collectComponent) and
// returns the number of unordered pairs {a, b}, a != b, among them whose
// distance sum is <= K.
ll countPairs(int start, ll startDist) {
    const int count = collectComponent(start, startDist);
    std::sort(dist, dist + count);
    ll pairs = 0;
    int lo = 0, hi = count - 1;
    // Two-pointer scan: if dist[lo] + dist[hi] <= K, then lo pairs with every
    // index in (lo, hi]; otherwise hi pairs with nothing left of it.
    while (lo < hi) {
        if (dist[lo] + dist[hi] <= K) {
            pairs += hi - lo;
            ++lo;
        } else {
            --hi;
        }
    }
    return pairs;
}

// Counts every qualifying pair in the component rooted at `centroid`,
// then recurses into the pieces left after removing it.
void decompose(int centroid) {
    removed[centroid] = true;
    // Pairs whose path goes through the centroid (including the centroid itself)...
    answer += countPairs(centroid, 0);
    // ...minus pairs inside one child subtree: their real path does not use the
    // centroid, and they are counted correctly deeper in the recursion.
    for (int e = head[centroid]; e; e = nxt[e]) {
        const int v = to[e];
        if (!removed[v]) answer -= countPairs(v, len[e]);
    }
    for (int e = head[centroid]; e; e = nxt[e]) {
        const int v = to[e];
        if (!removed[v]) decompose(findCentroid(v));
    }
}

// Input: "n m", then m lines "x y length direction", then K.
// Output: the number of unordered node pairs at distance <= K.
int main() {
    int n, m;
    if (scanf("%d%d", &n, &m) != 2) return 1;  // malformed input
    for (int i = 0; i < m; ++i) {
        int x, y, z;
        char direction[20];  // compass letter; irrelevant to distances
        if (scanf("%d%d%d%19s", &x, &y, &z, direction) != 4) return 1;  // malformed input
        addUndirected(x, y, z);
    }
    if (scanf("%d", &K) != 1) return 1;  // malformed input
    // Decompose every component, so a forest is handled as well as a tree.
    for (int v = 1; v <= n; ++v) {
        if (!removed[v]) decompose(findCentroid(v));
    }
    // iostream avoids the %lld portability problem of old MinGW-based judges.
    std::cout << answer << '\n';
    return 0;
}