「CSP-S 2019」树的重心

题目

考场上送\(75pts\)真实良心,正解不难;考虑直接对于每一个点算割掉多少条边能使得这个点成为重心,不难发现对于一个不是重心的点,我们要割掉的那条边一定在那个大于\(\lfloor \frac{n}{2} \rfloor\)的子树里面,而最大子树割掉之后可能就不是最大的了,但新的最大子树只可能是原来的次大子树,推一下柿子要割掉的子树大小需要在\([2A-n,n-2B]\)之间,其中\(A\)为最大子树,\(B\)为次大子树

于是我们先求一个重心作为根,这样所有非重心节点的最大子树就会跨过这个根,在dfs的过程就能更新子树的大小,用树状数组维护一下就好了,由于需要排除子树内部的情况,所以还需要一个线段树合并;至于重心节点不超过两个,暴力求一下就好

代码

#include<bits/stdc++.h>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
    char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=3e5+5;
const int M=maxn*25;
int l[M],r[M],d[M];
struct E{int v,nxt;}e[maxn<<1];
int T,n,num,__,cnt;LL ans=0;
int head[maxn],sum[maxn],rt[maxn],mx[maxn],c[maxn],t[maxn],sz[maxn],col[maxn];
inline void add(int x,int v) {
    for(re int i=x;i<=n;i+=i&(-i)) c[i]+=v;
}
inline int ask(int x) {
    int nw=0;
    for(re int i=x;i;i-=i&(-i)) nw+=c[i];
    return nw;
}
inline void add_E(int x,int y) {
    e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void Dfs(int x,int fa) {
    sum[x]=1;mx[x]=0;
    for(re int i=head[x];i;i=e[i].nxt) {
	if(e[i].v==fa) continue;
	Dfs(e[i].v,x);sum[x]+=sum[e[i].v];mx[x]=max(mx[x],sum[e[i].v]);
    }
    mx[x]=max(n-sum[x],mx[x]);
}
int ins(int nw,int x,int y,int pos) {
    if(!nw) nw=++cnt,d[nw]=l[nw]=r[nw]=0;d[nw]++;
    if(x==y) return nw;
    int mid=x+y>>1;
    if(pos<=mid) l[nw]=ins(l[nw],x,mid,pos);
    else r[nw]=ins(r[nw],mid+1,y,pos);
    return nw;
}
int merge(int a,int b,int x,int y) {
    if(!a||!b) return a|b;
    if(x==y) {
	d[a]+=d[b];
	return a;
    }
    int mid=x+y>>1;
    l[a]=merge(l[a],l[b],x,mid);r[a]=merge(r[a],r[b],mid+1,y);
    d[a]=d[l[a]]+d[r[a]];return a;
}
int query(int nw,int x,int y,int lx,int ry) {
    if(!nw||lx>ry) return 0;
    if(lx<=x&&ry>=y) return d[nw];
    int mid=x+y>>1,h=0;
    if(lx<=mid) h+=query(l[nw],x,mid,lx,ry);
    if(ry>mid) h+=query(r[nw],mid+1,y,lx,ry);
    return h;
}
void dfs(int x,int fa) {
	rt[x]=ins(rt[x],1,n,sz[x]);
    if(fa) add(sz[fa],-1),add(n-sz[x],1);
    for(re int i=head[x];i;i=e[i].nxt)
	if(e[i].v!=fa) dfs(e[i].v,x),rt[x]=merge(rt[x],rt[e[i].v],1,n);
    if(mx[x]+mx[x]>n) {
	int k=0;
	if(mx[x]-t[x]>=2*mx[x]-n) k=ask(mx[x]-t[x])-ask(2*mx[x]-n-1);
	if(mx[x]-t[x]<n-2*t[x]) k+=ask(n-2*t[x])-ask(mx[x]-t[x]);
	k-=query(rt[x],1,n,2*mx[x]-n,mx[x]-t[x]);
	k-=query(rt[x],1,n,mx[x]-t[x]+1,n-2*t[x]);
	ans+=1ll*k*x;
    }
    if(fa) add(sz[fa],1),add(n-sz[x],-1);
}
void Dfs_(int x,int fa) {
    sz[x]=1,t[x]=0;
    for(re int i=head[x];i;i=e[i].nxt) {
	if(e[i].v==fa) continue;
	Dfs_(e[i].v,x);sz[x]+=sz[e[i].v];t[x]=max(t[x],sz[e[i].v]);
    }
}
void DFs(int x,int fa,int cm) {
    col[x]=cm;sz[x]=1;
    for(re int i=head[x];i;i=e[i].nxt) {
	if(e[i].v==fa) continue;
	DFs(e[i].v,x,cm);sz[x]+=sz[e[i].v];
    } 
}
void solve(int Rt) {
    int col_num=1,A=0,B=0;
    for(re int i=head[Rt];i;i=e[i].nxt,col_num++) {
	 DFs(e[i].v,Rt,col_num);
	 if(sz[e[i].v]>=sz[A]) B=A,A=e[i].v;
	 else if(sz[e[i].v]>sz[B]) B=e[i].v;
    }
    for(re int i=1;i<=n;i++) {
	if(i==Rt) continue;
	if(col[i]!=col[A]&&2*sz[A]<=(n-sz[i])) ans+=Rt;
	if(col[i]==col[A]&&2*max(sz[A]-sz[i],sz[B])<=(n-sz[i])) ans+=Rt;
    }
}
int main() {
    T=read();
    for(re int Rt;T;--T) {
	cnt=0;ans=0;n=read(),num=0,__=0;memset(head,0,sizeof(head));memset(rt,0,sizeof(rt));memset(c,0,sizeof(c));
	for(re int x,y,i=1;i<n;i++) x=read(),y=read(),add_E(x,y),add_E(y,x);
	Dfs(1,0);for(re int i=1;i<=n;i++) if(mx[i]+mx[i]<=n) Rt=i;
	Dfs_(Rt,0);
	for(re int i=1;i<=n;i++) add(sz[i],1);dfs(Rt,0);
	for(re int i=1;i<=n;i++) if(mx[i]+mx[i]<=n) solve(i);
	printf("%lld\n",ans);
    }
    return 0;
}
posted @ 2019-12-05 16:02  asuldb  阅读(486)  评论(0编辑  收藏  举报