【XSY3927】二叉树(多项式加速DP,递推)

题面

二叉树

题解

\(f_n\) 表示叶子数为 \(n\) 的答案,容易得出以下式子:

\[\begin{aligned} &f_0=0,f_1=1\\ &f_n=\sum_{i=1}^n c_if_if_{n-i} \end{aligned} \]

其中 \(c_{s_i}=v_i\),其余的 \(c_i\) 均为 \(A\)

注意到 \(k\leq 10\) 很小,\(c_i\) 中只有个别值不同,所以考虑将所有的 \(v_i\) 减去 \(A\),然后得到下面这个递推式:

\[f_n=\sum_{i=1}^nAf_if_{n-i}+\sum_{i=1}^kv_if_{s_i}f_{n-s_i} \]

\(f_i\) 的生成函数为 \(F(x)\),设 \(P(x)=\sum\limits_{i=1}^kv_if_{s_i}x^{s_i}\),那么有:

\[F(x)\equiv AF(x)^2+P(x)F(x)+x\pmod {x^{n+1}} \]

\(+x\) 是因为 \(AF(x)^2+P(x)F(x)\) 只计算了 \(n\geq 2\) 的系数,需要初始化 \(f_1=1\)

注意到 \(F(x)\) 的零次项为 \(0\),所以 \(P(x)\)\(n\) 次项系数对 \(F(x)\)\(n\) 次项系数没有贡献。所以我们只保留 \(P(x)\)\(n-1\) 次项系数再代入等号右边的 \(P(x)\) 其实是等价的。其本质就是递推。

所以我们记 \(P'(x)=P(x) \bmod x^n\),也会有:

\[F(x)\equiv AF(x)^2+P'(x)F(x)+x\pmod{x^{n+1}} \]

解得:

\[F(x)\equiv \dfrac{1-P'(x)\pm\sqrt{\big(1-P'(x)\big)^2-4Ax}}{2A}\pmod {x^{n+1}} \]

\(Q(x)=\big(1-P'(x)\big)^2-4Ax\)\(G(x)=\sqrt{Q(x)}\)。注意到 \(q_0=1\),那么 \(g_0=1\)。又由于 \(F(x)\) 常数项为 \(0\),所以应该取负号,故:

\[F(x)\equiv \dfrac{1-P'(x)-\sqrt{\big(1-P'(x)\big)^2-4Ax}}{2A}\pmod {x^{n+1}} \]

注意这条式子里 \(p_n\) 看似会对 \(f_n\) 的取值有影响,但我们推的式子是正确的,说明 \(p_n\) 实际上被抵消掉了,它对 \(f_n\) 的取值没有影响。

我们只需要得到 \(F(x)\)\(n\) 次项,那我们就需要 \(g_n\),考虑推导 \(G(x)=\sqrt{Q(x)}\) 实现快速算 \(g_n\),两边求导得:

\[\begin{aligned} G'(x)&=\dfrac{Q'(x)}{2\sqrt{Q(x)}}\\ G'(x)Q(x)&=\dfrac{Q'(x)\sqrt{Q(x)}}{2}=\dfrac{Q'(x)G(x)}{2} \end{aligned} \]

提取 \(x^n\) 的系数:

\[\begin{aligned} \sum_{i=0}^n(n-i+1)g_{n-i+1}q_i&=\dfrac{1}{2}\sum_{i=0}^n(i+1)q_{i+1}g_{n-i}\\ (n+1)g_{n+1}q_0+\sum_{i=1}^n(n-i+1)g_{n-i+1}q_i&=\dfrac{1}{2}\sum_{i=1}^{n+1}iq_ig_{n-i+1}\\ (n+1)g_{n+1}&=\sum_{i=1}^{n+1}\dfrac{1}{2}iq_ig_{n-i+1}-(n-i+1)g_{n-i+1}q_i\\ &=\sum_{i=1}^{n+1}q_ig_{n-i+1}\left(\dfrac{3}{2}i-n-1\right) \end{aligned} \]

所以 \(ng_n=\sum\limits_{i=1}^nq_ig_{n-i}\left(\dfrac{3}{2}i-n\right)\)

注意到 \(Q(x)\) 只有 \(O(k^2)\) 项有值,所以如果知道 \(g_1\sim g_{n-1}\)\(g_n\) 就可以暴力算。

那么我们考虑递推:

  1. 假设我们已经知道了 \(P'(x)=P(x) \bmod {x^{n}}\),即已经知道了 \(p_1\sim p_{n-1}\)
  2. 我们用 \(P'(x)\) 暴力计算出 \(Q(x)\),那么我们就知道了 \(q_1\sim q_{n-1}\)\(q'_n\)。(单次时间复杂度 \(O(k^2\log k^2)\)
  3. 利用 \(q_1\sim q_{n-1}\)\(q'_n\) 计算出 \(g'_n\),再通过 \(g'_n\) 得到 \(f_n\)。(单次时间复杂度 \(O(k^2)\)
  4. 通过 \(f_n\) 更新 \(P(x)\),然后得到 \(p_1\sim p_n\),注意记得更新 \(q_n\)\(g_n\)。(单次时间复杂度 \(O(k^2\log k^2)\)

\(q'_n\)\(g_n'\) 的意思是它们并不是真正的 \(q_n\)\(g_n\),但是通过 \(q'_n\)\(g'_n\) 也能算出 \(f_n\),记得最后要用 \(f_n\) 重新得到真正的 \(q_n\)\(g_n\)

注意 \(P(x)\) 只有 \(O(k)\) 次更新,所以上述 2,4 步骤实际上只会执行 \(O(k)\) 次。

所以总时间复杂度为 \(O(nk^2+k^3\log k^2)\)

感觉这道题还是有点绕的,需要自己手推。

代码如下:

#include<bits/stdc++.h>

#define K 15
#define N 1000010
#define re register

using namespace std;

namespace modular
{
	const int mod=1000000007,inv2=500000004;
	inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
	inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
	inline int mul(int x,int y){return 1ll*x*y%mod;}
	inline void Add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
	const int cc=mul(3,inv2);
}using namespace modular;

inline int poww(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1) ans=mul(ans,a);
		a=mul(a,a);
		b>>=1;
	}
	return ans;
}

inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^'0');
		ch=getchar();
	}
	return x*f;
}

struct data
{
	int p,v;
	data(){};
	data(int a,int b){p=a,v=b;}
};

typedef vector<data> poly;

int n,k,A,val[N];
int f[N],g[N];

poly p,q;

inline void work(const poly &a,poly &ans)
{
	static map<int,int>mp;
	mp.clear();
	mp[1]=dec(0,mul(4,A));
	for(int i=0,sa=a.size();i<sa;i++)
		for(int j=0;j<sa;j++)
			Add(mp[a[i].p+a[j].p],mul(a[i].v,a[j].v));
	ans.clear();
	for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
		ans.push_back(data(it->first,it->second));
}

inline void getg(int n)
{
	int ans=0;
	for(re int i=0,s=q.size();i<s;i++)
	{
		if(q[i].p<1) continue;
		if(q[i].p>n) break;
		ans=add(ans,mul(dec(mul(q[i].p,cc),n),mul(q[i].v,g[n-q[i].p])));
	}
	g[n]=mul(ans,poww(n,mod-2));
}

int main()
{
	n=read(),k=read(),A=read();
	memset(val,-1,sizeof(val));
	for(int i=1;i<=k;i++)
	{
		int s=read(),v=read();
		val[s]=dec(v,A);
	}
	f[1]=g[0]=1;
	p.push_back(data(0,dec(0,1)));
	if(~val[1]) p.push_back(data(1,val[1]));
	work(p,q);
	getg(1);
	const int c3=poww(mul(2,A),mod-2);
	for(re int now=2;now<=n;now++)
	{
		getg(now);
		f[now]=mul(dec(0,g[now]),c3);
		if(~val[now])
		{
			p.push_back(data(now,mul(val[now],f[now])));
			work(p,q);
			getg(now);
		}
	}
	printf("%d\n",f[n]);
	return 0;
}
/*
5 1 1
2 2
*/
posted @ 2022-10-30 14:26  ez_lcw  阅读(23)  评论(0编辑  收藏  举报