codeforces 161D Distance in Tree 树上点分治

链接:https://codeforces.com/contest/161/problem/D

题意:给一个树,求距离恰好为$k$的点对是多少

题解:对于一个树,距离为$k$的点对要么经过根节点,要么跨过子树的根节点,于是考虑树分治

用类似poj1741的想法,可以推出:

对于任意一棵子树,其根节点记为$C$,其子树中:

记距离$C$距离之和为$k$的点对数量$S_{c}$

记$C$儿子节点$C_1...C_n$的子树中,距离$C_i$距离为$k-2$的点对数量为$S'_{c_i}$

其符合条件的点对数量即为$S_{c}-\sum_1^n S'_{c_i}$

(网上这题,主流的树分治写法好像不是这个...有些看不懂啊....)

树上点分治参考我之前的题解:https://www.cnblogs.com/nervendnig/p/10106333.html

速度还是很可以的

相比dp的话,dp收到$K$大小的限制,如果$K$的大小和N同级别,就很难朴素的DP了,可能就要考虑树上倍增DP(实际上好像不能倍增)

而分治显然并不受限制

具体参见代码:

#include <bits/stdc++.h>
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mp make_pair
#define pii pair<int,int>
#define all(x) x.begin(),x.end()
#define IO ios::sync_with_stdio(false)
#define rep(ii,a,b) for(int ii=a;ii<=b;++ii)
#define per(ii,a,b) for(int ii=b;ii>=a;--ii)
#define forn(x,i) for(int i=head[x];i;i=e[i].next)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#define inline inline __attribute__(                                   \
    (always_inline, __gnu_inline__, __artificial__))                   \
        __attribute__((optimize("Ofast"))) __attribute__((target("sse"))) __attribute__((target("sse2"))) __attribute__((target("mmx")))
using namespace std;
#define tpyeinput int
char nc() {static char buf[1000000],*p1=buf,*p2=buf;return p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++;}
void read(tpyeinput &sum) {register char ch=nc();int flag=1;sum=0;while(ch<'0'||ch>'9') {if(ch=='-') flag=-1;ch=nc();}while(ch>='0'&&ch<='9') sum=(sum<<3)+(sum<<1)+(ch-48),ch=nc();sum*=flag;}
void read(tpyeinput &num1,tpyeinput &num2) {read(num1);read(num2);}
const int maxn=1e5+10,maxm=2e5+10;
const int INF=0x3f3f3f3f;
const int mod=1e9+7;
const double PI=acos(-1.0);
//head
int casn,n,m,k,mid,allnode;
struct node {int to,next;}e[maxm];int head[maxn],nume;
void add(int a,int b){e[++nume]=(node){b,head[a]};head[a]=nume;}
int sz[maxn],maxt,deep[maxn],vis[maxn],cnt;
ll ans;
void getc(int now,int pre){
	sz[now]=1;
	for(int i=head[now];i;i=e[i].next){
		if(e[i].to==pre||vis[e[i].to])continue;
		getc(e[i].to,now);
		sz[now]+=sz[e[i].to];
	}
	int tmp=max(sz[now]-1,allnode-sz[now]);
	if(maxt>tmp) maxt=tmp,mid=now;
}
void dfs(int now,int pre,int len,int dis){
	deep[++cnt]=dis;
	if(dis>=len)return;
	for(int i=head[now];i;i=e[i].next){
		if(e[i].to==pre||vis[e[i].to])continue;
		dfs(e[i].to,now,len,dis+1);
	}
}
ll cal(int rt,int pre,int len){
	if(len<=0) return len==0;
	cnt=0;
	dfs(rt,pre,len,0);
	ll res=0;
	int num[507]{};
	rep(i,1,cnt) num[deep[i]]++;
	rep(i,1,cnt) res+=num[len-deep[i]];
	return res;
}
void dc(int rt){
	vis[rt]=1;
	ans+=cal(rt,0,k);
	for(int i=head[rt];i;i=e[i].next){
		if(vis[e[i].to]) continue;
		ans-=cal(e[i].to,rt,k-2);
		allnode=sz[e[i].to],maxt=n;
		getc(e[i].to,rt);dc(mid);
	}
}
int main() {
//#define test
#ifdef test
	auto _start = chrono::high_resolution_clock::now();
	freopen("in.txt","r",stdin);freopen("out.txt","w",stdout);
#endif
	read(n,k);
	int a,b;
	rep(i,1,n-1){
		read(a,b);
		add(a,b);add(b,a);
	}
	allnode=n;
	maxt=INF;
	getc(1,0);
	dc(mid);
	printf("%lld",ans/2);
#ifdef test
	auto _end = chrono::high_resolution_clock::now();
  cerr << "elapsed time: " << chrono::duration<double, milli>(_end - _start).count() << " ms\n";
	fclose(stdin);fclose(stdout);system("out.txt");
#endif
	return 0;
}

  

posted @ 2018-12-12 06:29  nervending  阅读(375)  评论(0编辑  收藏  举报