【XSY3927】二叉树(多项式加速DP,递推)
题面
题解
设 \(f_n\) 表示叶子数为 \(n\) 的答案,容易得出以下式子:
其中 \(c_{s_i}=v_i\),其余的 \(c_i\) 均为 \(A\)。
注意到 \(k\leq 10\) 很小,\(c_i\) 中只有个别值不同,所以考虑将所有的 \(v_i\) 减去 \(A\),然后得到下面这个递推式:
设 \(f_i\) 的生成函数为 \(F(x)\),设 \(P(x)=\sum\limits_{i=1}^kv_if_{s_i}x^{s_i}\),那么有:
(\(+x\) 是因为 \(AF(x)^2+P(x)F(x)\) 只计算了 \(n\geq 2\) 的系数,需要初始化 \(f_1=1\))
注意到 \(F(x)\) 的零次项为 \(0\),所以 \(P(x)\) 的 \(n\) 次项系数对 \(F(x)\) 的 \(n\) 次项系数没有贡献。所以我们只保留 \(P(x)\) 的 \(n-1\) 次项系数再代入等号右边的 \(P(x)\) 其实是等价的。其本质就是递推。
所以我们记 \(P'(x)=P(x) \bmod x^n\),也会有:
解得:
令 \(Q(x)=\big(1-P'(x)\big)^2-4Ax\),\(G(x)=\sqrt{Q(x)}\)。注意到 \(q_0=1\),那么 \(g_0=1\)。又由于 \(F(x)\) 常数项为 \(0\),所以应该取负号,故:
注意这条式子里 \(p_n\) 看似会对 \(f_n\) 的取值有影响,但我们推的式子是正确的,说明 \(p_n\) 实际上被抵消掉了,它对 \(f_n\) 的取值没有影响。
我们只需要得到 \(F(x)\) 的 \(n\) 次项,那我们就需要 \(g_n\),考虑推导 \(G(x)=\sqrt{Q(x)}\) 实现快速算 \(g_n\),两边求导得:
提取 \(x^n\) 的系数:
所以 \(ng_n=\sum\limits_{i=1}^nq_ig_{n-i}\left(\dfrac{3}{2}i-n\right)\)。
注意到 \(Q(x)\) 只有 \(O(k^2)\) 项有值,所以如果知道 \(g_1\sim g_{n-1}\),\(g_n\) 就可以暴力算。
那么我们考虑递推:
- 假设我们已经知道了 \(P'(x)=P(x) \bmod {x^{n}}\),即已经知道了 \(p_1\sim p_{n-1}\)。
- 我们用 \(P'(x)\) 暴力计算出 \(Q(x)\),那么我们就知道了 \(q_1\sim q_{n-1}\) 和 \(q'_n\)。(单次时间复杂度 \(O(k^2\log k^2)\))
- 利用 \(q_1\sim q_{n-1}\) 和 \(q'_n\) 计算出 \(g'_n\),再通过 \(g'_n\) 得到 \(f_n\)。(单次时间复杂度 \(O(k^2)\))
- 通过 \(f_n\) 更新 \(P(x)\),然后得到 \(p_1\sim p_n\),注意记得更新 \(q_n\) 和 \(g_n\)。(单次时间复杂度 \(O(k^2\log k^2)\))
\(q'_n\) 和 \(g_n'\) 的意思是它们并不是真正的 \(q_n\) 和 \(g_n\),但是通过 \(q'_n\) 和 \(g'_n\) 也能算出 \(f_n\),记得最后要用 \(f_n\) 重新得到真正的 \(q_n\) 和 \(g_n\)。
注意 \(P(x)\) 只有 \(O(k)\) 次更新,所以上述 2,4 步骤实际上只会执行 \(O(k)\) 次。
所以总时间复杂度为 \(O(nk^2+k^3\log k^2)\)。
感觉这道题还是有点绕的,需要自己手推。
代码如下:
#include<bits/stdc++.h>
#define K 15
#define N 1000010
#define re register
using namespace std;
namespace modular
{
const int mod=1000000007,inv2=500000004;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
const int cc=mul(3,inv2);
}using namespace modular;
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
struct data
{
int p,v;
data(){};
data(int a,int b){p=a,v=b;}
};
typedef vector<data> poly;
int n,k,A,val[N];
int f[N],g[N];
poly p,q;
inline void work(const poly &a,poly &ans)
{
static map<int,int>mp;
mp.clear();
mp[1]=dec(0,mul(4,A));
for(int i=0,sa=a.size();i<sa;i++)
for(int j=0;j<sa;j++)
Add(mp[a[i].p+a[j].p],mul(a[i].v,a[j].v));
ans.clear();
for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
ans.push_back(data(it->first,it->second));
}
inline void getg(int n)
{
int ans=0;
for(re int i=0,s=q.size();i<s;i++)
{
if(q[i].p<1) continue;
if(q[i].p>n) break;
ans=add(ans,mul(dec(mul(q[i].p,cc),n),mul(q[i].v,g[n-q[i].p])));
}
g[n]=mul(ans,poww(n,mod-2));
}
int main()
{
n=read(),k=read(),A=read();
memset(val,-1,sizeof(val));
for(int i=1;i<=k;i++)
{
int s=read(),v=read();
val[s]=dec(v,A);
}
f[1]=g[0]=1;
p.push_back(data(0,dec(0,1)));
if(~val[1]) p.push_back(data(1,val[1]));
work(p,q);
getg(1);
const int c3=poww(mul(2,A),mod-2);
for(re int now=2;now<=n;now++)
{
getg(now);
f[now]=mul(dec(0,g[now]),c3);
if(~val[now])
{
p.push_back(data(now,mul(val[now],f[now])));
work(p,q);
getg(now);
}
}
printf("%d\n",f[n]);
return 0;
}
/*
5 1 1
2 2
*/