CF Gym102538A Airplane Cliques

cf

一个自然的想法是在一个点集里选出一个特定的点,在该点处计入点集贡献.由于点集中所有两点间路径的并是个连通块

一个想法就是枚举连通块中深度最浅的点,然后认为在它子树内的距离\(\le x\)的点都可以在点集内.不过这是错的,因为你很轻松就可以找到这个点两棵不同子树内到他距离为\(x\)的点,而这两个点距离为\(2x\)

所以现在就是要选择一个点集中的特定点,满足所有到它距离\(\le x\)的点满足两两距离\(\le x\).反过来考虑(),我们枚举点集中深度最深的点,深度相同就按照编号排序(其实就是找bfs序最大的点),这时候,所有满足bfs序更小的,到这个点距离\(\le x\)的点都是满足两两距离限制的.假设当前枚举的bfs序最大的点为\(a\),现在考虑\(p,q\)两点,路径\([a,p]\)和路径\([a,q]\)的分叉点为\(b\).可以发现\(p,q\)之中最多有一个在分叉点上方

qwq

  • 如果两个点都在\(b\)下方,由于\(a\)为当前深度最深的点,那么一定有\(\max(dis(b,p),dis(b,q))\le dis(a,b)\),所以\(dis(b,p)+dis(b,q)=\max(dis(b,p),dis(b,q))+\min(dis(b,p),dis(b,q))\le dis(a,b)+\min(dis(b,p),dis(b,q))=\min(dis(a,p),dis(a,q))\le x\)

  • 如果有一个点都在\(b\)上方(假设为\(q\)),因为\(dis(b,q)\le dis(a,b)\),所以\(dis(b,p)+dis(b,q)\le dis(a,b)+dis(b,q)=dis(a,q)\le x\)

所以对于每个点\(a\),如果统计出bfs序比它小的,到它距离\(\le x\)的点个数\(cn_a\),那对于\(ans_i\)\(\binom{cn_a-1}{i-1}\)的贡献,这个可以把组合数拆开后ntt计算卷积的值

至于\(cn_a\)的计算可以一个log或两个log,如果是一个log,那么可以先算出\(f_i\)表示以某个点(或一条边上的中点)\(i\)为中点,半径为\(\lfloor\frac{x}{2}\rfloor\)的连通块内点数,然后按照bfs序的逆序枚举点\(u\),到\(u\)距离\(\le x\)且深度不大于\(u\)的连通块点数就是\(u\)往上跳\(\lfloor\frac{x}{2}\rfloor\)距离到的点\(v\)\(f_v\)的值,再考虑bfs序要\(\le u\)的bfs序的话,就每找到一个\(v\)就给\(f_v\)减掉1即可,这样在后面就不会统计到bfs序更大的点了

#include<bits/stdc++.h>
#define LL long long

using namespace std;
const int N=6e5+10,M=(1<<20)+10,mod=998244353;
int rd()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+(ch^48);ch=getchar();}
    return x*w;
}
void ad(int &x,int y){x+=y,x-=x>=mod?mod:0;}
int fpow(int a,int b){int an=1;while(b){if(b&1) an=1ll*an*a%mod;a=1ll*a*a%mod,b>>=1;}return an;}
int ginv(int a){return fpow(a,mod-2);}
int to[N<<1],nt[N<<1],hd[N],tot=1;
void adde(int x,int y)
{
	++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;
	++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot;
}
int n,m,lm,sz[N],f[N],g[N],mx,nsz,rt;
bool ban[N];
void fdrt(int x,int ffa)
{
	sz[x]=1;
	int nx=0;
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]||y==ffa) continue;
		fdrt(y,x),sz[x]+=sz[y],nx=max(nx,sz[y]);
	}
	nx=max(nx,nsz-sz[x]);
	if(mx>nx) mx=nx,rt=x;
}
void d1(int x,int ffa,int de)
{
	m=max(m,de),g[de]+=x<=n;
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]||y==ffa) continue;
		d1(y,x,de+1);
	}
}
void d2(int x,int ffa,int de)
{
	if(de>lm) return;
	f[x]+=g[min(m,lm-de)];
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]||y==ffa) continue;
		d2(y,x,de+1);
	}
}
void wk(int x)
{
	mx=nsz+1,fdrt(x,0);
	x=rt,ban[x]=1,fdrt(x,0);
	d1(x,0,0);
	for(int i=1;i<=m;++i) g[i]+=g[i-1];
	f[x]+=g[min(lm,m)];
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]) continue;
		d2(y,x,1);
	}
	memset(g,0,sizeof(int)*(m+1));
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]) continue;
		m=0,d1(y,x,1);
		for(int i=1;i<=m;++i) g[i]+=g[i-1];
		for(int i=0;i<=m;++i) g[i]=-g[i];
		d2(y,x,1),memset(g,0,sizeof(int)*(m+1));
	}
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(ban[y]) continue;
		nsz=sz[y],wk(y);
	}
}
int st[N],tp,dp[N],ff[N],sq[N];
void d3(int x,int ffa)
{
	st[++tp]=x,ff[x]=st[max(1,tp-lm)];
	dp[x]=tp;
	for(int i=hd[x];i;i=nt[i])
	{
		int y=to[i];
		if(y==ffa) continue;
		d3(y,x);
	}
	--tp;
}
int fac[N],iac[N],W[21],iW[21],rdr[M],aa[M],bb[M];
void ntt(int *a,int n,bool op)
{
	int l=0,y;
	while((1<<l)<n) ++l;
	for(int i=0;i<n;++i)
	{
		rdr[i]=(rdr[i>>1]>>1)|((i&1)<<(l-1));
		if(i<rdr[i]) swap(a[i],a[rdr[i]]);
	}
	for(int i=1,p=0;i<n;i<<=1,++p)
	{
		int ww=op?W[p]:iW[p];
		for(int j=0;j<n;j+=i<<1)
			for(int k=0,w=1;k<i;++k,w=1ll*w*ww%mod)
			{
				y=1ll*a[j+k+i]*w%mod;
				a[j+k+i]=(a[j+k]-y+mod)%mod;
				a[j+k]=(a[j+k]+y)%mod;
			}
	}
	if(!op) for(int i=0,w=ginv(n);i<n;++i) a[i]=1ll*a[i]*w%mod;
}

int main()
{
	freopen("1.in","r",stdin); 
	freopen("1.out","w",stdout);
	for(int i=1,p=0;p<=20;i<<=1,++p)
		W[p]=fpow(3,(mod-1)/(i<<1)),iW[p]=ginv(W[p]);
	fac[0]=1;
	for(int i=1;i<=N-5;++i) fac[i]=1ll*fac[i-1]*i%mod;
	iac[N-5]=ginv(fac[N-5]);
	for(int i=N-5;i;--i) iac[i-1]=1ll*iac[i]*i%mod;
	n=rd(),lm=rd();
	for(int i=1;i<n;++i)
	{
		int x=rd(),y=rd();
		adde(x,i+n),adde(y,i+n);
	}
	nsz=n+n-1,wk(1);
	d3(1,0);
	for(int i=1;i<=n;++i) sq[i]=i;
	sort(sq+1,sq+n+1,[&](int aa,int bb){return dp[aa]>dp[bb];});
	for(int i=1;i<=n;++i)
	{
		int x=sq[i];
		--f[ff[x]],++aa[f[ff[x]]];
	}
	for(int i=0;i<=n;++i) aa[i]=1ll*aa[i]*fac[i]%mod;
	for(int i=0;i<=n;++i) bb[i]=iac[n-i];
	int len=1;
	while(len<=n+n+2) len<<=1;
	ntt(aa,len,1),ntt(bb,len,1);
	for(int i=0;i<len;++i) aa[i]=1ll*aa[i]*bb[i]%mod;
	ntt(aa,len,0);
	for(int i=0;i<n;++i) printf("%d ",(int)(1ll*aa[n+i]*iac[i]%mod));
	return 0;
}
posted @ 2020-04-14 17:13  ✡smy✡  阅读(351)  评论(3编辑  收藏  举报