[atAGC062E]Overlap Binary Tree

\(m=\frac{n+1}{2}\),即二叉树的叶子个数

对于合法序列,按以下方式生成其对应的二叉树:

(此处二叉树指无标号以一个点为根每个非叶节点恰有两个儿子的树)

  • 恰存在一个区间与其余区间均有交,将其作为根并(在序列中)删除
  • 恰存在一个\(i\in [1,n)\)使得\(\max_{1\le j\le i}R_{j}<L_{i+1}\),将\([1,i]\)\((i,n]\)作为左右子树并递归处理

另一方面,考虑一棵二叉树所对应的合法序列数:

从底向上依次确定\(L,R\)的相对顺序,考虑加入一个节点\(k\)

  • 对于\(k\)的左右子树(点集分别为\(ls\)\(rs\)),直接拼接即可

  • 对于\(k\)自身,即要求\(L_{k}<\min_{i\in ls}R_{i}\)\(R_{k}>\max_{i\in rs}L_{i}\),两者独立

    在相对顺序下,前者方案数也即\(\sum_{i\in ls}[L_{i}<\min_{i\in ls}R_{i}]+1\),而这个值对于\(k\)\(+1\)(后者同理)

    换言之,两者分别为\(k\)左链和右链的长度(指边数)\(+1\)

另外,若钦定\(L_{i}+1=R_{i}\),即使得其以\(i\)为左链或右链的末尾节点时不\(+1\)

\(A\)为所有叶子向上"同方向"次数所构成的可重集,答案即\(\sum_{S\subseteq A,|S|=k}\prod_{i\in S}i!\prod_{i\in A-S}(i+1)!\)

对于二叉树,按以下方式生成其对应的新树:

(此处新树指无标号以一条边为根根边端点有序每个点儿子间有序的树)

  • 点集为所有叶子,并在(二叉树的)每个非叶节点左链和右链的末尾节点间连边
  • 以二叉树根节点所对应的边为根,每个点的儿子按该边对应的非叶节点深度排序

注意到两者一一对应,且每个点的度数即为对应的叶子在\(A\)中的元素

换言之,问题即转换为以下形式:

对于所有\(m\)个点的新树,记\(D\)所有点度数所构成的可重集,求\(\sum_{S\subseteq D,|S|=k}\prod_{i\in S}i!\prod_{i\in D-S}(i+1)!\)之和

不妨对该树标号,记\(d_{i}\)为点\(i\)的度数,钦定\(S=\{d_{1},d_{2},...,d_{k}\}\),并将最终答案除以\(k!(m-k)!\)即可

  • 每次选择一个点及其一个儿子(编号),除其所在连通块的根外,其余无父亲节点均可作为该儿子

上述过程共\(m-2\)轮,第\(i\)轮方案数为\((m-i-1)(m-i)\),最终根边端点有\(2\)种选法

同时,由于加边顺序并不影响,每棵树被计算\((m-2)!\)次,即\(\frac{2\prod_{i=1}^{m-2}(m-i-1)(m-i)}{(m-2)!}=2(m-1)!\)

综上,最终答案即\(\frac{2(m-1)!}{k!(m-k)!}[x^{n-1}](\sum_{i=1}^{m-1}i!x^{i})^{k}(\sum_{i=1}^{m-1}((i+1)!-i!)x^{i})^{m-k}\),时间复杂度为\(O(n\log n)\)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
const int mod=998244353;
int add(int x,int y){
	x+=y;
	return (x<mod ? x : x-mod);
}
int qpow(int n,int m){
	int s=n,ans=1;
	while (m){
		if (m&1)ans=(ll)ans*s%mod;
		s=(ll)s*s%mod,m>>=1;
	}
	return ans;
}
namespace Poly{
	const int N=19;
	int n,tn,inv[1<<N],w[N][1<<N],iw[N][1<<N];
	void init(int g){
		inv[0]=inv[1]=1;
		for(int i=2;i<(1<<N);i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
		for(int i=0;i<N;i++){
			w[i][0]=iw[i][0]=1;
			w[i][1]=qpow(g,(mod-1>>i+1));
			iw[i][1]=qpow(w[i][1],mod-2);
			for(int j=2;j<(1<<i);j++){
				w[i][j]=(ll)w[i][1]*w[i][j-1]%mod;
				iw[i][j]=(ll)iw[i][1]*iw[i][j-1]%mod;
			}
		}
	}
	void get_n(int m){
		n=1,tn=0;
		while (n<m)n<<=1,tn++;
	}
	void dft(int *a){
		for(int i=n,t=0;i>1;i>>=1,t++){
			int *W=w[tn-t-1];
			for(int j=0;j<n;j+=i)
				for(int k=0;k<(i>>1);k++){
					int x=a[j+k],y=a[j+k+(i>>1)];
					a[j+k]=add(x,y);
					a[j+k+(i>>1)]=(ll)(x-y+mod)*W[k]%mod;
				}
		}
	}
	void idft(int *a){
		for(int i=2,t=0;i<=n;i<<=1,t++){
			int *W=iw[t];
			for(int j=0;j<n;j+=i)
				for(int k=0;k<(i>>1);k++){
					int x=a[j+k],y=(ll)W[k]*a[j+k+(i>>1)]%mod;
					a[j+k]=add(x,y),a[j+k+(i>>1)]=add(x,mod-y); 
				}
		}
		int inv=qpow(n,mod-2);
		for(int i=0;i<n;i++)a[i]=(ll)inv*a[i]%mod;
	}
	vi mul(vi a,vi b,int ma,int mb,int m){
		if (ma<0)ma=a.size();
		if (mb<0)mb=b.size();
		if (m<0)m=ma+mb-1;
		ma=min(ma,m),mb=min(mb,m);
		get_n(ma+mb-1);
		a.resize(n),b.resize(n);
		for(int i=ma;i<n;i++)a[i]=0;
		for(int i=mb;i<n;i++)b[i]=0;
		dft(a.data()),dft(b.data());
		for(int i=0;i<n;i++)a[i]=(ll)a[i]*b[i]%mod;
		idft(a.data()),a.resize(m);
		return a;
	}
	vi get_inv(vi a,int m){
		if (m==1)return vi{qpow(a[0],mod-2)};
		vi s=get_inv(a,(m+1>>1)),ans;
		get_n(m<<1);
		a.resize(n),s.resize(n),ans.resize(n);
		for(int i=m;i<n;i++)a[i]=0;
		dft(a.data()),dft(s.data());
		for(int i=0;i<n;i++)ans[i]=(ll)s[i]*(mod+2-(ll)a[i]*s[i]%mod)%mod;
		idft(ans.data()),ans.resize(m);
		return ans;
	}
	vi get_ln(vi a,int m){
		if (m==1)return vi{0};
		vi ans(m-1);
		for(int i=1;i<m;i++)ans[i-1]=(ll)i*a[i]%mod;
		ans=mul(ans,get_inv(a,m),m-1,m,m);
		for(int i=m-1;i;i--)ans[i]=(ll)inv[i]*ans[i-1]%mod;
		ans[0]=0;
		return ans;
	}
	vi get_exp(vi a,int m){
		if (m==1)return vi{1};
		vi s=get_exp(a,(m+1>>1)),ans;
		s.resize(m),ans=get_ln(s,m);
		ans[0]=add(1,mod-ans[0]);
		for(int i=1;i<m;i++)ans[i]=add(a[i],mod-ans[i]);
		return mul(s,ans,(m+1>>1),m,m);
	}
	vi get_pow(vi a,int m,int k){
		if (!k){
			for(int i=0;i<m;i++)a[i]=(!i);
			return a;
		}
		int t=0;
		while ((t<m)&&(!a[t]))t++;
		if ((ll)t*k>=m)return vi(m,0);
		int s1=qpow(a[t],mod-2),s2=qpow(a[t],k);
		for(int i=t;i<m;i++)a[i-t]=(ll)s1*a[i]%mod;
		a=get_ln(a,m-t);
		for(int i=0;i<m-t;i++)a[i]=(ll)k*a[i]%mod;
		a=get_exp(a,m-t),a.resize(m);
		t*=k;
		for(int i=m-1;i>=t;i--)a[i]=(ll)s2*a[i-t]%mod;
		for(int i=0;i<t;i++)a[i]=0;
		return a;
	}
};
int n,m,k,ans;
int main(){
	Poly::init(3);
	scanf("%d%d",&n,&k);
	m=(n+1>>1);
	if (k>m){puts("0");return 0;}
	if ((n==1)&&(k==1)){puts("1");return 0;}
	int s=1;vi v0(n,0),v1(n,0);
	for(int i=1;i<m;i++){
		s=(ll)s*i%mod;
		v0[i]=s,v1[i]=(ll)s*i%mod;
	}
	vi v=Poly::mul(Poly::get_pow(v0,n,k),Poly::get_pow(v1,n,m-k),n,n,n);
	ans=2LL*s*v[n-1]%mod,s=1;
	for(int i=1;i<=k;i++)s=(ll)s*i%mod;
	for(int i=1;i<=m-k;i++)s=(ll)s*i%mod;
	ans=(ll)ans*qpow(s,mod-2)%mod;
	printf("%d\n",ans);
	return 0;
}
posted @ 2023-07-01 19:22  PYWBKTDA  阅读(130)  评论(0编辑  收藏  举报