题解:「SNOI2017」礼物
本文记录一种递推优化套路。
可以发现,若将答案表示成 \(\sum_{i=1}^{n} \alpha_i i^k\) 的形式时,系数 \(\alpha_i\) 满足以下性质。
\[\alpha_i = 1 + 1 + 2 + 4 + \dots + 2^{n-i-2}
\]
所以求得:
\[\begin{cases}
\ \alpha_n = \displaystyle{1} & \\
\ \alpha_i = \displaystyle{2^{n-i-1}} & i\neq n\\
\end{cases}
\]
所以题目转化为求式子 \(n^k + \sum_{i=1}^{n-1} 2^{n-i-1} i^k\) 的值。
先转换式子成:
\[n^k + 2^{n-1} \sum_{i=1}^{n-1} \frac{i^k}{2^i}
\]
发现瓶颈在于如何快速求解 \(\mathcal{O}(n)\) 求和式 \(F_n = \sum_{i=1}^{n} \frac{i^k}{2^i}\) 。
为了方便,以下部分默认设 \(t=\frac{1}{2}\) 。
这里用到了一种倍增的方法解题,具体的说,就是通过求解 \(F_{n+1}\) , \(F_{2n}\) ,并利用其与 \(F_n\) 之间的关系进行快速运算。
先处理 \(F_{n+1}\) 与 \(F_n\) 之间的关系。
\[F_{n+1} = F_n + t^{n+1} (n+1)^k
\]
再处理 \(F_{2n}\) 与 \(F_n\) 。
\[\begin{aligned}
F_{2n} - F_{n} & = \sum_{i=n+1}^{2n} t^i i^k \\
& = t^n \sum_{i=1}^{n} t^i (i+n)^k \\
& = t^n \sum_{i=1}^{n} t^i \sum_{j=0}^{k} {k \choose j} i^{k-j}n^j \\
& = t^n \sum_{j=0}^{k} {k\choose j} n^j \sum_{i=1}^n t^i i^{k-j} \\
\end{aligned}\]
(所用芝士:二项式定理,求和顺序变换)
终于推出了与 \(F\) 定义式相似的式子 \(\sum_{i=1}^n t^i i^{k-j}\) ,由于 \(k\) 很小,不妨拓展 \(F\) 的定义。
设 \(F(n, k) = \sum_{i=1}^n t^i i^k\) ,得出倍增递推式:
\[\begin{cases}
\ F(n+1, k) = F(n, k) + t^{n+1} (n+1)^k \\
\ F(2n, k) = F(n, k) + t^n \sum_{j=0}^{k} {k \choose j} n^j F(n, k-j) \\
\end{cases}\]
可以用记忆化搜索解决,用 std::map
维护 \(F\) 数组。
时间复杂度 \(\mathcal{O}(k\log n \log (k \log n))\)
Code(C++):
#include<bits/stdc++.h>
#define forn(i,s,t) for(register int i=(s);i<=(t);++i)
#define form(i,s,t) for(register int i=(s);i>=(t);--i)
#define mkp make_pair
typedef long long LL;
using namespace std;
const int Mod = 1e9+7, T = Mod+1 >> 1, M = 11;
LL fac[M], inv[M];
inline LL Q_pow(LL p, LL k) {
LL Ans = 1; p = p%Mod;
while(k) {if(k & 1) Ans = Ans * p %Mod; p = p*p %Mod, k >>= 1;}
return Ans;
}
/*-------预处理 O(1) 组合数----------*/
inline void init(int k) {
fac[0] = fac[1] = inv[0] = 1;
forn(i,2,k) fac[i] = fac[i-1] * 1ll * i %Mod;
inv[k] = Q_pow(fac[k], Mod - 2);
form(i,k-1,1) inv[i] = inv[i+1] * 1ll * (i+1) %Mod;
}
inline LL C(int n, int m) {return fac[n] * (inv[m] * inv[n - m] %Mod) %Mod;}
map<pair<LL, int>, int> f;
int F(LL n, int k) {
if(n == 1) return T;
if(n == 0) return 0;
if(f.count(mkp(n, k))) return f[mkp(n, k)];
if(n & 1) return f[mkp(n, k)] = (1ll * F(n-1, k) + Q_pow(1ll*T, n) * Q_pow(n, k) %Mod) %Mod;
else {
LL m = n/2, res = 0;
forn(j,0,k) res = (res + C(k, j) * Q_pow(m, j)%Mod * 1ll * F(m, k-j) %Mod) %Mod;
return f[mkp(n, k)] = (F(m, k) + res * Q_pow(T, m) %Mod) %Mod;
}
}
LL n; int k;
int main() {
scanf("%lld%d", &n, &k);
init(k);
printf("%lld\n", (Q_pow(n, k) + Q_pow(2ll, n-1) * 1ll * F(n-1, k)) %Mod);
return 0;
}