UOJ633 【UR #21】你将如闪电般归来 【ODE,整式递推】
设 \(F_k(x)\) 表示答案的生成函数,则 \(\displaystyle F_k(x)=\int F_{k-1}(x)\cosh x\text{ d}x\),\(F_0(x)=x\)。
设 \(\mathcal F(x,t):=\sum_{k\ge 0}F_k(x)t^k\),则 \(\dfrac{\partial}{\partial x}\mathcal F=1+t\mathcal F\cosh x\),解得 \(\displaystyle\mathcal F=\text e^{t\sinh x}\int\text e^{-t\sinh x}\text{ d}x\),从而 \(\displaystyle[t^k]\mathcal F=\sum_{i=0}^k\frac{(-1)^i}{i!(k-i)!}(\sinh x)^{k-i}\int(\sinh x)^i\text{ d}x\) 可以表示为 \(\text e^{ax}\) 和 \(x\text e^{ax}\) 的线性组合,其中 \(-k\le a\le k\)。
换元 \(G_k(\sinh x):=F_k(x)\),则 \(\displaystyle G_k(x)=\int G_{k-1}(x)\text{ d}x\),\(G_0(x)\) 是 \(\sinh^{-1} x:=\ln(x+\sqrt{1+x^2})\)。
接下来是经典的 D-Finite 复合有理分式 \(\implies\) 整式递推。
考虑求出 \(G_k\) 的 ODE:已知 \(G_0'(x)=(1+x^2)^{-1/2}\),设 \(a_n=[x^n]G_0'(x)\),则 \(na_n=(1-n)a_{n-2}\),设 \(b_n=[x^n]G_k(x)=(n-k-1)!a_{n-k-1}/n!\),则 \(n(n-1)b_n+(n-k-2)^2b_{n-2}=[n=k+1]/(k-1)!\),从而 \((1+x^2)G_k''(x)+(1-2k)xG_k'(x)+k^2G_k(x)=x^{k-1}/(k-1)!\)。
设 \(y:=\dfrac{x-x^{-1}}{2}\),\(P(x):=G_k(y)\),则由 \(P'(x)=\dfrac{1+x^{-2}}2\cdot G_k'(y)\) 和 \(P''(x)=\left(\dfrac{1+x^{-2}}2\right)^2G_k''(y)-x^{-3}G_k'(y)\) 解得 \(G_k'(y)=\dfrac{2}{1+x^{-2}}\cdot P'(x)\) 和 \(G''_k(y)=\left(\dfrac 2{1+x^{-2}}\right)^2P''(x)+x^{-3}\left(\dfrac 2{1+x^{-2}}\right)^3P'(x)\),代入 \(G_k\) 的 ODE 得到 \(k^2 (1+x^2) P(x) + ((1+2k)x + (1-2k)x^3) P'(x) + ( x^2+x^4) P''(x) = \dfrac{y^{k-1}}{(k-1)!} \cdot (1+x^2)\)。
由 \(P(x)=F_k(\ln x)\) 以及先前推导知 \(P(x)\) 是 \(x^n\) 和 \(x^n\ln x\) 的线性组合,设 \(f_n=[x^n]P(x)\),\(g_n=[x^n\ln x]P(x)\),在 \(P(x)\) 的 ODE 中取 \([x^n\ln x]\) 得到 \((n+k)^2g_n+(n-k-2)^2g_{n-2}=0\),取 \([x^n]\) 得到 \((n+k)^2 f_n + (n-k-2)^2 f_{n-2} + 2(n+k)g_n + 2(n-k-2)g_{n-2} = [x^n] \dfrac{y^{k-1}}{(k-1)!}\cdot(1+x^2)\),由 \([t^k]\mathcal F\) 的柿子知边界情况 \(g_k=\dfrac 1{2^kk!}\) 和 \(f_k=-\dfrac{H_k}{2^kk!}\)。进一步观察可以知道 \(i\) 与 \(-i\) 次项对答案的贡献相同可以卡卡常。
时间复杂度 \(\mathcal O(k\log_kp)\)。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10000010, M = 666666, mod = 998244353;
int ksm(int a, int b){
int res = 1;
for(;b;b >>= 1, a = (LL)a * a % mod)
if(b & 1) res = (LL)res * a % mod;
return res;
}
int SQ(int x){return (LL)x * x % mod;}
int k, fac[N], ifac[N], inv[N], pw[N], f[N], g[N], pri[M], tot, ans;
LL n;
bitset<N> notp;
int main(){
ios::sync_with_stdio(false);
cin >> k >> n; -- k;
if(!k){printf("%d\n", n == 1); return 0;}
notp.set(0); notp.set(1); pw[1] = *fac = 1;
for(int i = 2;i <= k;++ i){
if(!notp.test(i)){pri[tot++] = i; pw[i] = ksm(i, (n - 1) % (mod - 1));}
for(int j = 0;j < tot && i * pri[j] <= k;++ j){
notp.set(i * pri[j]);
pw[i * pri[j]] = (LL)pw[i] * pw[pri[j]] % mod;
if(!(i % pri[j])) break;
}
}
n %= mod;
for(int i = 1;i <= k;++ i) fac[i] = (LL)fac[i - 1] * i % mod;
ifac[k] = ksm(fac[k], mod - 2);
for(int i = k;i;-- i){
ifac[i - 1] = (LL)ifac[i] * i % mod;
inv[i] = (LL)ifac[i] * fac[i - 1] % mod;
}
int pw2k = ksm(2, mod - 1 - k);
f[k] = (LL)pw2k * ifac[k] % mod;
for(int i = k - 2;i >= 0;i -= 2)
f[i] = mod - (LL)f[i + 2] * SQ(inv[k - i] * (i + 2ll + k) % mod) % mod;
for(int i = k;i >= 1;i -= 2)
ans = (ans + (LL)f[i] * pw[i]) % mod;
ans = (LL)ans * n % mod;
for(int i = 0;i <= (k - 1 >> 1);++ i)
g[k - 1 - i * 2] = 2ll * ((i & 1) ? mod - pw2k : pw2k) % mod * ifac[i] % mod * ifac[k - 1 - i] % mod;
for(int i = k + 1;i >= 2;i -= 2){
g[i] += g[i - 2];
if(g[i] >= mod) g[i] -= mod;
}
for(int i = 2;i <= k + 1;++ i)
g[i] = (g[i] + 2ll * ((k + 2ll - i) * f[i - 2] + ((LL)mod - k - i) * f[i])) % mod;
f[k] = 0;
for(int i = 1;i <= k;++ i){
f[k] += inv[i];
if(f[k] >= mod) f[k] -= mod;
}
f[k] = mod - (LL)f[k] * pw2k % mod * ifac[k] % mod;
for(int i = k - 1;i >= 0;-- i)
f[i] = (g[i + 2] - (LL)f[i + 2] * SQ(i + 2 + k) % mod + mod) * SQ(inv[k - i]) % mod;
for(int i = k;i >= 1;-- i)
ans = (ans + (LL)f[i] * pw[i] % mod * i) % mod;
ans <<= 1; if(ans >= mod) ans -= mod;
printf("%d\n", ans);
}