【POJ1741】 tree
树分治入门题,感觉树分治的思路很劲啊
原题:
给一颗n个节点的树,每条边上有一个距离v(v<=1000)。
定义d(u,v)为u到v的最小距离;
给定k值,求有多少点对(u,v)使u到v的距离小于等于k。
n<=10000,k<2^31,v<=1000
感觉最近几天状态一直不是很好,树分治的总结再写几道题再说吧= =
树分治核心思路是每次的到本节点的某个子树中的最优(复杂度)根,然后以这个根为起始点向这个子树的其它点扩张
这题的思路是每次得到以本节点为根的子树所有节点的树上前缀和,排一下序,从l=1和r=top判断,如果r的前缀和-l的前缀和<=k,答案就+=r-l+1并l++,否则r--
因为上面get答案的时候是以本节点为根的所有子树,所以还要去掉两个节点在本节点的同一个子树中的情况,这个时候只需要给本节点的每一个儿子前缀和设为本节点到儿子的权值(上面的设为0),然后再给儿子做上面的操作,就可以清掉在本节点的同一个子树中的情况辣(看不懂可以自己手玩
树分治几个函数嵌套有点多,感觉思路好劲啊(不过倒是不容易出bug(相比搞基数据结构(也许是因为我是抄chty的代码写的
代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cstring> 5 #include<cmath> 6 #include<vector> 7 using namespace std; 8 const int oo=168430090; 9 int rd(){int z=0,mk=1; char ch=getchar(); 10 while(ch<'0'||ch>'9'){if(ch=='-')mk=-1; ch=getchar();} 11 while(ch>='0'&&ch<='9'){z=(z<<3)+(z<<1)+ch-'0'; ch=getchar();} 12 return z*mk; 13 } 14 struct ddd{int y,v;}; vector <ddd> e[11000]; 15 inline void ist(int x,int y,int z){ e[x].push_back((ddd){y,z});} 16 int n,m; 17 int ans,rt; 18 int sz[11000],f[11000],dp[11000]; 19 bool vst[11000]; int dst[11000]; 20 int cnt; 21 void clear(){ 22 memset(vst,0,sizeof(vst)); 23 ans=rt=0; 24 for(int i=1;i<=n;++i) e[i].clear(); 25 } 26 void gtrt(int x,int fa){ 27 sz[x]=1; f[x]=0; 28 for(int i=0;i<e[x].size();++i)if(e[x][i].y!=fa && !vst[e[x][i].y]){ 29 gtrt(e[x][i].y,x); sz[x]+=sz[e[x][i].y]; 30 f[x]=max(f[x],sz[e[x][i].y]); 31 } 32 f[x]=max(f[x],cnt-sz[x]); 33 if(f[x]<f[rt]) rt=x; 34 } 35 void gtdp(int x,int fa){ 36 dp[++dp[0]]=dst[x]; 37 for(int i=0;i<e[x].size();++i)if(e[x][i].y!=fa && !vst[e[x][i].y]) 38 dst[e[x][i].y]=dst[x]+e[x][i].v,gtdp(e[x][i].y,x); 39 } 40 int gtans(int x,int y){ 41 dst[x]=y,dp[0]=0,gtdp(x,0); 42 sort(dp+1,dp+dp[0]+1); 43 int l=1,r=dp[0],bwl=0; 44 while(l<r){ 45 if(dp[l]+dp[r]<=m) bwl+=r-l,++l; 46 else --r; 47 } 48 return bwl; 49 } 50 void ptt(int x){ 51 ans+=gtans(x,0); vst[x]=true; 52 for(int i=0;i<e[x].size();++i)if(!vst[e[x][i].y]){ 53 ans-=gtans(e[x][i].y,e[x][i].v); 54 rt=0,cnt=sz[e[x][i].y]; 55 gtrt(e[x][i].y,0),ptt(rt); 56 } 57 } 58 int main(){//freopen("ddd.in","r",stdin); 59 for(;;){ 60 cnt=n=rd(),m=rd(); 61 if(n==0 && m==0) break; 62 clear(); 63 int l,r,v; 64 for(int i=1;i<n;++i){ 65 l=rd(),r=rd(),v=rd(); 66 ist(l,r,v),ist(r,l,v); 67 } 68 f[0]=oo; gtrt(1,0),ptt(rt); 69 printf("%d\n",ans); 70 } 71 return 0; 72 }