POJ1741:Tree

浅谈树分治:https://www.cnblogs.com/AKMer/p/10014803.html

题目传送门:http://poj.org/problem?id=1741

这是一道树分治的模板题。

我们考虑当前经过联通块的重心\(rt\)的路径小于等于\(k\)的有多少条,不经过\(rt\)的等把\(rt\)这个点删了之后再递归去统计。

我们可以记录一下从\(rt\)出发到每个点的距离\(dis\),然后把\(dis\)从小到大排序,用两个指针\(l\)\(r\)分别指向\(dis\)数组的开头和结尾。每次以\(l\)为基准,找能与\(l\)匹配起来小于等于\(k\)的另一条路径。假设\(l\)\(dis\)加上\(r\)\(dis\)小于等于\(k\),且\(r\)是满足这个条件的最长的路径,那么区间\([l,r]\)里的所有\(dis\)加上\(l\)\(dis\)都会小于等于\(k\)。所以对于\(l\),能与它匹配的路径条数是\(r-l+1-num\)\(num\)表示区间\([l,r]\)中与\(l\)在同一个子树里的路径条数,这个我们可以开个桶动态维护。然后对于\(l\)不断增加,\(r\)必然是递减的,所以复杂度是\(O(n)\)的。因为每个点只会被递归\(log\)次,所以总复杂度是\(O(nlog^2n)\)的。

边分治做法差不太多,不过空间会因为要重建树而翻一倍。

时间复杂度:\(O(nlog^2n)\)

空间复杂度:\(O(n)\)

点分治版代码如下:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int maxn=1e4+5;

bool vis[maxn];
int siz[maxn],num[maxn];
int n,limit,tot,ans,mx,rt,cnt,N;
int now[maxn],son[maxn*2],pre[maxn*2],val[maxn*2];

int read() {
	int x=0,f=1;char ch=getchar();
	for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
	for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
	return x*f;
}

struct road {
	int bel,dis;

	road() {}

	road(int _bel,int _dis) {
		bel=_bel,dis=_dis;
	}

	bool operator<(const road &a)const {
		return dis<a.dis;
	}
}tmp[maxn];

void add(int a,int b,int c) {
	pre[++tot]=now[a];
	now[a]=tot,son[tot]=b,val[tot]=c;
}

void clear() {
	ans=tot=0;
	memset(now,0,sizeof(now));
	memset(vis,0,sizeof(vis));
}

void find_root(int fa,int u) {
	int res=0;siz[u]=1;
	for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
		if(!vis[v]&&v!=fa)
			find_root(u,v),siz[u]+=siz[v],res=max(res,siz[v]);
	res=max(res,N-siz[u]);
	if(res<mx)mx=res,rt=u;
}

void solve(int belong,int fa,int u,int dis) {
	tmp[++cnt]=road(belong,dis);siz[u]=1;//之前求出的siz是以u为根的,现在顺便把以rt为根时的siz求出来,递归下去就直接可以调用了。
	for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
		if(!vis[v]&&v!=fa)
			solve(belong,u,v,dis+val[p]),siz[u]+=siz[v];
}

void point_division(int u,int size) {//size表示当前联通块总大小
	N=size,mx=rt=n+1,find_root(0,u);
	u=rt,cnt=0,num[0]=1,vis[u]=1;
	for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
		if(!vis[v])solve(v,u,v,val[p]),num[v]=siz[v];
	sort(tmp+1,tmp+cnt+1);int l=0,r=cnt;
	while(l<r) {
		while(r>l&&tmp[l].dis+tmp[r].dis>limit)num[tmp[r].bel]--,r--;
		if(l<r)ans+=(r-l+1)-num[tmp[l].bel];num[tmp[l].bel]--,l++;
	}
	for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
		if(!vis[v])point_division(v,siz[v]);
}

int main() {
	while(1) {
		n=read(),limit=read();
		if(!n)break; clear();
		for(int i=1;i<n;i++) {
			int x=read(),y=read(),v=read();
			add(x,y,v),add(y,x,v);
		}
		point_division(1,n);
		printf("%d\n",ans);
	}
	return 0;
}

边分治版代码如下:

#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef pair<int,int> pii;
#define fr first
#define sc second

const int maxn=2e4+5;

bool vis[maxn];
int siz[maxn],num[maxn];
int n,limit,tot=1,ans,mx,id,cnt,N;
int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];

vector<pii>to[maxn];
vector<pii>::iterator it;

int read() {
    int x=0,f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    return x*f;
}

struct road {
    int bel,dis;

    road() {}

    road(int _bel,int _dis) {
        bel=_bel,dis=_dis;
    }

    bool operator<(const road &a)const {
        return dis<a.dis;
    }
}tmp[maxn];

void clear() {
    tot=ans=0;
    memset(vis,0,sizeof(vis));
    memset(now,0,sizeof(now));
}

void add(int a,int b,int c) {
    pre[++tot]=now[a];
    now[a]=tot,son[tot]=b,val[tot]=c;
}

void find_son(int fa,int u) {
    to[u].clear();
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(v!=fa)to[u].push_back(make_pair(v,val[p])),find_son(u,v);
}

void rebuild() {
    tot=1,memset(now,0,sizeof(now));
    for(int i=1;i<=cnt;i++) {
        int size=to[i].size();
        if(size<=2) {
            for(it=to[i].begin();it!=to[i].end();it++) {
                pii tmp=*it;
                add(i,tmp.fr,tmp.sc),add(tmp.fr,i,tmp.sc);
            }
        }
		else {
			pii u1=make_pair(++cnt,0),u2;
			if(size==3)u2=to[i].front();else u2=make_pair(++cnt,0);
			add(i,u1.fr,u1.sc),add(u1.fr,i,u1.sc);
			add(i,u2.fr,u2.sc),add(u2.fr,i,u2.sc);
			if(size>3) {
				to[cnt-1].clear();to[cnt].clear();int tmp=0;
				for(it=to[i].begin();it!=to[i].end();it++) {
					if(!tmp)to[cnt].push_back(*it);
					else to[cnt-1].push_back(*it);tmp^=1;
				}
			}
			else {
				to[cnt].clear();
				for(int j=1;j<=2;j++)
					to[cnt].push_back(to[i].back()),to[i].pop_back();
			}
		}
    }
}

void find_edge(int fa,int u) {
    siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa) {
            find_edge(u,v),siz[u]+=siz[v];
            if(abs(N-2*siz[v])<mx)mx=abs(N-2*siz[v]),id=p>>1;
        }
}

void solve(int bel,int fa,int u,int dis) {
    if(u<=n)tmp[++tot]=road(bel,dis),num[bel]++;siz[u]=1;
    for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
        if(!vis[p>>1]&&v!=fa)solve(bel,u,v,dis+val[p]),siz[u]+=siz[v];
}

void work(int u,int size) {
    N=size,mx=id=cnt+1,find_edge(0,u);
    tot=0,vis[id]=1;if(id==cnt+1)return;
    int u1=son[id<<1],u2=son[id<<1|1];
    num[u1]=num[u2]=0;
    solve(u1,0,u1,0),solve(u2,0,u2,0);
    sort(tmp+1,tmp+tot+1);int l=1,r=tot;
    while(l<r) {
        while(r>l&&tmp[r].dis+tmp[l].dis>limit-val[id<<1])num[tmp[r].bel]--,r--;
        if(l<r)ans+=(r-l+1)-num[tmp[l].bel];num[tmp[l].bel]--,l++;
    }
    work(u1,siz[u1]),work(u2,siz[u2]);
}

int main() {
    while(1) {
		cnt=n=read();limit=read();
		if(!n)break; clear();
		for(int i=1;i<n;i++) {
			int x=read(),y=read(),v=read();
			add(x,y,v),add(y,x,v);
		}
		find_son(0,1);rebuild();
		work(1,cnt);printf("%d\n",ans);
    }
    return 0;
}
posted @ 2018-11-28 22:47  AKMer  阅读(156)  评论(0编辑  收藏  举报