codeforces 161D Distance in Tree 树上点分治
链接:https://codeforces.com/contest/161/problem/D
题意:给一个树,求距离恰好为$k$的点对是多少
题解:对于一个树,距离为$k$的点对要么经过根节点,要么跨过子树的根节点,于是考虑树分治
用类似poj1741的想法,可以推出:
对于任意一棵子树,其根节点记为$C$,其子树中:
记距离$C$距离之和为$k$的点对数量$S_{c}$
记$C$儿子节点$C_1...C_n$的子树中,距离$C_i$距离为$k-2$的点对数量为$S'_{c_i}$
其符合条件的点对数量即为$S_{c}-\sum_1^n S'_{c_i}$
(网上这题,主流的树分治写法好像不是这个...有些看不懂啊....)
树上点分治参考我之前的题解:https://www.cnblogs.com/nervendnig/p/10106333.html
速度还是很可以的
相比dp的话,dp收到$K$大小的限制,如果$K$的大小和N同级别,就很难朴素的DP了,可能就要考虑树上倍增DP(实际上好像不能倍增)
而分治显然并不受限制
具体参见代码:
#include <bits/stdc++.h> #define endl '\n' #define ll long long #define ull unsigned long long #define fi first #define se second #define mp make_pair #define pii pair<int,int> #define all(x) x.begin(),x.end() #define IO ios::sync_with_stdio(false) #define rep(ii,a,b) for(int ii=a;ii<=b;++ii) #define per(ii,a,b) for(int ii=b;ii>=a;--ii) #define forn(x,i) for(int i=head[x];i;i=e[i].next) #pragma GCC optimize("Ofast") #pragma GCC optimize("unroll-loops") #define inline inline __attribute__( \ (always_inline, __gnu_inline__, __artificial__)) \ __attribute__((optimize("Ofast"))) __attribute__((target("sse"))) __attribute__((target("sse2"))) __attribute__((target("mmx"))) using namespace std; #define tpyeinput int char nc() {static char buf[1000000],*p1=buf,*p2=buf;return p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++;} void read(tpyeinput &sum) {register char ch=nc();int flag=1;sum=0;while(ch<'0'||ch>'9') {if(ch=='-') flag=-1;ch=nc();}while(ch>='0'&&ch<='9') sum=(sum<<3)+(sum<<1)+(ch-48),ch=nc();sum*=flag;} void read(tpyeinput &num1,tpyeinput &num2) {read(num1);read(num2);} const int maxn=1e5+10,maxm=2e5+10; const int INF=0x3f3f3f3f; const int mod=1e9+7; const double PI=acos(-1.0); //head int casn,n,m,k,mid,allnode; struct node {int to,next;}e[maxm];int head[maxn],nume; void add(int a,int b){e[++nume]=(node){b,head[a]};head[a]=nume;} int sz[maxn],maxt,deep[maxn],vis[maxn],cnt; ll ans; void getc(int now,int pre){ sz[now]=1; for(int i=head[now];i;i=e[i].next){ if(e[i].to==pre||vis[e[i].to])continue; getc(e[i].to,now); sz[now]+=sz[e[i].to]; } int tmp=max(sz[now]-1,allnode-sz[now]); if(maxt>tmp) maxt=tmp,mid=now; } void dfs(int now,int pre,int len,int dis){ deep[++cnt]=dis; if(dis>=len)return; for(int i=head[now];i;i=e[i].next){ if(e[i].to==pre||vis[e[i].to])continue; dfs(e[i].to,now,len,dis+1); } } ll cal(int rt,int pre,int len){ if(len<=0) return len==0; cnt=0; dfs(rt,pre,len,0); ll res=0; int num[507]{}; rep(i,1,cnt) num[deep[i]]++; rep(i,1,cnt) res+=num[len-deep[i]]; return res; } void dc(int rt){ vis[rt]=1; ans+=cal(rt,0,k); for(int i=head[rt];i;i=e[i].next){ if(vis[e[i].to]) continue; ans-=cal(e[i].to,rt,k-2); allnode=sz[e[i].to],maxt=n; getc(e[i].to,rt);dc(mid); } } int main() { //#define test #ifdef test auto _start = chrono::high_resolution_clock::now(); freopen("in.txt","r",stdin);freopen("out.txt","w",stdout); #endif read(n,k); int a,b; rep(i,1,n-1){ read(a,b); add(a,b);add(b,a); } allnode=n; maxt=INF; getc(1,0); dc(mid); printf("%lld",ans/2); #ifdef test auto _end = chrono::high_resolution_clock::now(); cerr << "elapsed time: " << chrono::duration<double, milli>(_end - _start).count() << " ms\n"; fclose(stdin);fclose(stdout);system("out.txt"); #endif return 0; }