【UOJ424】count(笛卡尔树,DP,生成函数,矩阵快速幂)

首先可以发现两个序列 \(A,B\) 同构当且仅当它们的笛卡尔树同构。

那么可以考虑枚举笛卡尔树,然后判断它能否构成满足题目条件的序列。

发现一棵笛卡尔树满足条件当且仅当它有 \(n\) 个节点(废话),而且它的最长左链长度不超过 \(m-1\)

定义一条链的左链长度为这条链上向左的边数,一棵树的最长左链长度为这棵树所有从根到叶子的链的左链长度的最大值。

那么就可以 DP 了:设 \(f_{n,m}\) 表示 \(n\) 个点、最长左链长度+1不超过 \(m\) 的笛卡尔树个数。

这里定义状态的时候把最长左链长度+1是为了方便初始状态定义,因为理论上只有 \(1\) 个点的笛卡尔树的最长左链长度为 \(0\),只有 \(0\) 个点的笛卡尔树的最长左链长度为 \(-1\)

枚举根节点的左子树大小,容易得到转移:

\[\begin{aligned} f_{n,0}&=[n=0]\\ f_{n,1}&=1\\ f_{n,m}&=\sum_{i=0}^{n-1}f_{i,m-1}f_{n-1-i,m} \end{aligned} \]

暴力转移是 \(O(n^3)\) 的。

考虑优化,设 \(F_m\)\(f_{*,m}\) 的生成函数,有:

\[\begin{aligned} F_m&=xF_{m-1}F_m\\ F_m&=\dfrac{1}{1-xF_{m-1}} \end{aligned} \]

初始状态:

\[\begin{aligned} F_0&=1\\ F_1&=\dfrac{1}{1-x} \end{aligned} \]

于是得到了 \(O(nm\log n)\) 的做法。

观察递推式,发现可以写成矩阵快速幂的形式:设 \(F_{m-1}=\dfrac{A}{B}\),那么 \(F_{m}=\dfrac{B}{B-xA}\)

于是有了 \(O(n\log n\log m)\) 的做法。

进一步地,设 \(F_m=\dfrac{A_m}{B_m}\),可以归纳证明 \(A_m\)\(m-1\) 次多项式,\(B_m\)\(m\) 次多项式。

这意味着我们可以先 NTT 一遍,把 \(O(n)\) 个点值带进去求矩阵快速幂,再把求出来的点值 NTT 回来,而这过程中不会发生循环卷积的 BUG。

时间复杂度 \(O(n\log n+n\log m)\)

#include<bits/stdc++.h>

#define LN 20
#define N 100010

using namespace std;

namespace modular
{
	const int mod=998244353;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
	inline void Add(int &x,int y){x=x+y>=mod?x+y-mod:x+y;}
}using namespace modular;

inline int poww(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=mul(ans,a);
		a=mul(a,a);
		b>>=1;
	}
	return ans;
}

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

int n,m;
int A[N<<2],B[N<<2],C[N<<2],D[N<<2],invD[N<<2];
int rev[N<<2],w[LN][N<<2][2];

void init(int limit)
{
	for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
	{
		int len=mid<<1;
		int gn=poww(3,(mod-1)/len);
		int ign=poww(gn,mod-2);
		int g=1,ig=1;
		for(int j=0;j<mid;j++,g=mul(g,gn),ig=mul(ig,ign))
			w[bit][j][0]=g,w[bit][j][1]=ig;
	}
}

void NTT(int *a,int limit,int opt)
{
	opt=(opt<0);
	for(int i=0;i<limit;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
	for(int i=0;i<limit;i++)
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
	{
		for(int i=0,len=mid<<1;i<limit;i+=len)
		{
			for(int j=0;j<mid;j++)
			{
				int x=a[i+j],y=mul(w[bit][j][opt],a[i+mid+j]);
				a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
			}
		}
	}
	if(opt)
	{
		int tmp=poww(limit,mod-2);
		for(int i=0;i<limit;i++)
			a[i]=mul(a[i],tmp);
	}
}

void getinv(int *f,int *g,int n)
{
    static int ff[N<<2];
    assert(f[0]);
    g[0]=poww(f[0],mod-2);
    int now=2;
    for(;now<(n<<1);now<<=1)
    {
        int limit=now<<1;
        for(int i=0;i<now;i++) ff[i]=f[i];
        NTT(ff,limit,1),NTT(g,limit,1);
        for(int i=0;i<limit;i++)
            g[i]=mul(g[i],dec(2,mul(ff[i],g[i])));
        NTT(g,limit,-1);
        for(int i=now;i<limit;i++) g[i]=0;
    }
    for(int i=n;i<now;i++) g[i]=0;
    for(int i=0;i<now;i++) ff[i]=0;
}

struct Matrix
{
	int a[2][2];
	Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=0;};
	void init(){a[0][0]=a[1][1]=1;}
}trans[N<<2],st[N<<2];

Matrix mul(Matrix a,Matrix b)
{
	Matrix c;
	for(int i=0;i<2;i++)
		for(int j=0;j<2;j++)
			for(int k=0;k<2;k++)
				Add(c.a[i][j],mul(a.a[i][k],b.a[k][j]));
	return c;
}

Matrix poww(Matrix a,int b)
{
	Matrix ans;
	ans.init();
	while(b)
	{
		if(b&1) ans=mul(ans,a);
		a=mul(a,a);
		b>>=1;
	}
	return ans;
}

int main()
{
	n=read(),m=read();
	if(n<m)
	{
		puts("0");
		return 0;
	}
	int limit=1;
	while(limit<=(n<<1)) limit<<=1;
	init(limit);
	A[1]=1;
	NTT(A,limit,1);
	B[0]=1;
	NTT(B,limit,1);
	for(int i=0;i<limit;i++)		
	{
		trans[i].a[0][0]=0,trans[i].a[0][1]=1,trans[i].a[1][0]=dec(0,A[i]),trans[i].a[1][1]=1;
		st[i].a[0][0]=st[i].a[1][0]=B[i];
		Matrix res=mul(poww(trans[i],m),st[i]);
		C[i]=res.a[0][0],D[i]=res.a[1][0];
	}
	NTT(D,limit,-1);
	getinv(D,invD,n+1);
	NTT(invD,limit,1);
	for(int i=0;i<limit;i++) C[i]=mul(C[i],invD[i]);
	NTT(C,limit,-1);
	printf("%d\n",C[n]);
	return 0;
}
posted @ 2022-10-30 10:08  ez_lcw  阅读(127)  评论(0编辑  收藏  举报