CF623E Transforming Sequence 题解
【分析】
不难想到,令 \(f_{n, k}\) 表示前 \(n\) 个数,使得 \(b_n\) 有 \(k\) 个 \(1\) 的方案数
则很容易列出转移方程,由于 \(a_i>0\) ,故 \(\displaystyle f_{n, k}=\sum_{i=0}^{k-1} \binom k if_{n-1, i}\cdot 2^i\)
变形得到 \(\displaystyle {f_{n, k}\over k!}x^k=\sum_{i+j=k} {f_{n-1,i}\cdot 2^i\over i!}x^i\cdot {[j>0]\over j!}x^j\)
因此得到生成函数的转移关系 \(\displaystyle \hat F_{n}(x)=\hat F_{n-1}(2x)\cdot (e^x-1)\)
而由 \(f_{0, k}=[k=0]\) 得到 \(\hat F_0(x)=1\)
列举几项得到 \(\hat F_1(x)=e^x-1, \hat F_2(x)=(e^{2x}-1)(e^x-1), \hat F_3(x)=(e^{4x}-1)(e^{2x}-1)(e^x-1)\)
猜想 \(\displaystyle \hat F_n(x)=\prod_{i=0}^{n-1} (e^{2^ix}-1)\) ,若对 \(n\leq m\) 时成立
则 \(\displaystyle \hat F_{m+1}(x)=\hat F_m(x)(2x)\cdot (e^x-1)=\prod_{i=0}^{m-1}(e^{2^i\cdot 2x}-1)\cdot (e^x-1)=\prod_{i=0}^m (e^{2^ix}-1)\)
由数学归纳法证明猜想正确
设答案为 \(A_{n, k}\) ,则 \(\displaystyle A_{n, k}=\sum_{i=0}^k \binom k i f_{n, i}\Rightarrow \hat A_n(x)=\hat F_n(x)\cdot e^x\)
问题化为如何求解 \(\hat F_n(x)\) 不超时
考虑到 \(\displaystyle \hat F_{2n}(x)=\prod_{i=0}^{2n-1}(e^{2^ix}-1)=\prod_{i=0}^{n-1}(e^{2^ix}-1)\cdot \prod_{i=n}^{2n-1}(e^{2^ix}-1)=\prod_{i=0}^{n-1}(e^{2^ix}-1)\cdot \prod_{i=0}^{n-1}(e^{2^{i+n}x}-1)=\hat F_n(x)\cdot \hat F_n(2^n x)\)
而 \(\displaystyle F(x)=\sum_{i=0}^\infty f_ix^i\Rightarrow F(Ax)=\sum_{i=0}^\infty f_iA^i\cdot x^i\)
故已知 \(\hat F_n(x)\) ,可以先 \(O(n)\) 求解 \(\hat F(2^n x)\) ,再多项式卷积 \(O(n\log n)\) 求出 \(\hat F_{2n}(x)\)
因而求解 \(\hat F_n(x)\) 时,先递归求解 \(\hat F_{n/2}(x)\) ,若 \(n\) 为奇数,再利用递推式 \(O(n\log n)\) 算出结果
最后再卷上 \(e^x\) 即可得到答案的生成函数
总复杂度 \(\displaystyle T(n)=T({n\over 2})+O(n\log n)=O(n\log n)\) ,常数约为 \(20\) 次 FFT
【代码】
由于需要任意模数,考虑使用 MTT 。但这题貌似精度需求非常高,用 double 会算出问题,得 long double
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef long double db;
#define fi first
#define se second
struct vir{
db r, i;
vir(db r_=0, db i_=0):r(r_), i(i_) {}
};
inline vir operator + (const vir &a, const vir &b) { return vir(a.r+b.r, a.i+b.i); }
inline vir operator - (const vir &a, const vir &b) { return vir(a.r-b.r, a.i-b.i); }
inline vir operator * (const vir &a, const vir &b) { return vir(a.r*b.r-a.i*b.i, a.r*b.i+a.i*b.r); }
inline vir operator / (const vir &a, db b) { return vir(a.r/b, a.i/b); }
inline vir operator ! (const vir &z) { return vir(z.r, -z.i); }
inline vir operator / (const vir &a, const vir &b) { return a*!b/(b*!b); }
const db pi=acos(-1), eps=1e-6;
const int LimBit=18, M=1<<LimBit;
int a[M], b[M], c[M];
struct NTT{
int Stk[M], curStk;
int N, na, nb, P, m, N_;
inline ll kpow(ll a,ll x) { ll ans=1; for(;x;x>>=1,a=a*a%P) if(x&1) ans=ans*a%P; return ans; }
NTT(int P_=1e9+7):P(P_){
m=1<<15;
curStk=0;
N_=-1;
}
vir w[2][M], invN;
int rev[M];
void work(){
if(N_==N) return ; N_=N;
int d = __builtin_ctz(N);
w[0][0] = w[1][0] = vir(1, 0);
vir x = vir(cos(2*pi/N), sin(2*pi/N)), y = !x;
for (int i = 1; i < N; i++) {
rev[i] = (rev[i>>1] >> 1) | ((i&1) << (d-1));
w[0][i] = x * w[0][i-1] , w[1][i] = y * w[1][i-1];
}
invN=vir(1.0/N, 0);
}
inline void fft(vir *a,int f){
static vir x;
for(int i=0;i<N;++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<N;i<<=1)
for(int j=0,t=N/(i<<1);j<N;j+=i<<1)
for(int k=0,l=0;k<i;++k,l+=t)
x=w[f][l]*a[j+k+i], a[j+k+i]=a[j+k]-x, a[j+k]=a[j+k]+x;
if(f) for(int i=0;i<N;++i) a[i]=a[i]*invN;
}
inline int mergeNum(ll a1,ll a0,ll b1,ll b0){
// return ((a1*m+a0+b1)%p*m+b0+p)%p;
return (( ((a1<<15)+a0+b1)%P<<15 )+b0)%P;
}
inline void doit(int *a,int *b,int na,int nb){
static vir x, y, p0[M], p1[M];
for(N=1;N<na+nb-1;N<<=1);
for(int i=na;i<N;++i) a[i]=0;
for(int i=nb;i<N;++i) b[i]=0;
for(int i=0;i<N;++i)
// p1[i]=pdd(a[i]/m,a[i]%m),p0[i]=pdd(b[i]/m,b[i]%m);
p1[i]=vir(a[i]>>15, a[i]&(m-1)), p0[i]=vir(b[i]>>15, b[i]&(m-1));
work(); fft(p1,0); fft(p0,0);
for(int i=0,j;i+i<=N;++i){
j=(N-1)&(N-i);
x=p0[i], y=p0[j];
p0[i]=(x-!y)*p1[i]*vir(0, -0.5);
p1[i]=(x+!y)*p1[i]*vir(0.5, 0);
if(i==j) continue;
p0[j]=(y-!x)*p1[j]*vir(0, -0.5);
p1[j]=(y+!x)*p1[j]*vir(0.5, 0);
}
fft(p1,1); fft(p0,1);
for(int i=0;i<N;++i)
a[i]=mergeNum(p1[i].r+0.5, p1[i].i+0.5, p0[i].r+0.5, p0[i].i+0.5);
}
}ntt;
ll n;
int k, f[M];
inline void init() {
b[1]=1;
for(int i=2, P=ntt.P; i<=3e4; ++i)
b[i]=P-(ll)b[P%i]*(P/i)%P;
for(int i=2, P=ntt.P; i<=3e4; ++i)
b[i]=(ll)b[i-1]*b[i]%P;
}
inline void work(ll n) {
if(n==0) {
for(int i=0; i<=k; ++i) f[i]=0;
f[0]=1;
return ;
}
work(n>>1);
for(int i=0, x=ntt.kpow(2, n>>1), y=1, P=ntt.P; i<=k; ++i)
a[i]=(ll)f[i]*y%P, y=(ll)y*x%P;
ntt.doit(f, a, k+1, k+1);
if(n+1&1) return ;
for(int i=0, x=2, y=1, P=ntt.P; i<=k; ++i)
f[i]=(ll)f[i]*y%P, y=(ll)y*x%P;
ntt.doit(f, b, k+1, k+1);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
init();
cin>>n>>k;
work(n);
++b[0];
ntt.doit(f, b, k+1, k+1);
for(int i=0, x=1, P=ntt.P; i<=k; ++i)
f[i]=(ll)f[i]*x%P, x=(ll)x*(i+1)%P;
cout<<f[k];
cout.flush();
return 0;
}