[NOI Online 2021 提高组] 愤怒的小N
一道偏结论的题。当 \(n=2^{k+1}\) 时
\[\sum_{i=0,\text{popcount}(i)\equiv 0\pmod 2}^{2^{k+1}-1}i^k=\sum_{i=0,\text{popcount}(i)\equiv 1\pmod 2}^{2^{k+1}-1}i^k
\]
采用归纳法证明(后文为方便以 \(=0/1\) 表示 \(\equiv 0/1\))。当 \(k=0\) 时,结论显然成立;
当 \(k>0\) 时,欲证明结论,先移项:
\[\sum_{i=0,\text{popcount}(i)=0}^{2^{k+1}-1}i^k-\sum_{i=0,\text{popcount}(i)=1}^{2^{k+1}-1}i^k=0
\]
对于前缀相同的二进制 \((\cdots 0)\) 和 \((\cdots 1)\) 均出现,仅最低位不同,考虑这样枚举
\[\sum_{i=0}^{2^k-1}c_i\times \left((2i+1)^k-(2i)^k\right)=0
\]
\(c_i\) 为 \(+1\) 或 \(-1\),取决于前缀 \(i\) 中 1
的个数。不难发现 \(c_i=(-1)^{\text{popcount}(i)}\)
所以得到:
\[\sum_{i=0}^{2^k-1}(-1)^{\text{popcount}(i)}\left((2i+1)^k-(2i)^k\right)=0
\]
展开得到:
\[\sum_{i=0}^{2^k-1}(-1)^{\text{popcount}(i)}\sum_{j=0}^{k-1}{k\choose j}(2i)^j=0
\]
对于 \(\text{popcount}(i)\) 分组,就变成:
\[\sum_{j=0}^{k-1}2^j{k\choose j}\sum_{i=0,\text{popcount}(i)=0}^{2^k-1}i^j=\sum_{j=0}^{k-1}2^j{k\choose j}\sum_{i=0,\text{popcount}(i)=1}^{2^k-1}i^j\ \ (*)
\]
由归纳,当 \(k'<k\) 时:
\[\sum_{i=0,\text{popcount(i)}=0}^{2^{k'+1}-1}i^{k'}=\sum_{i=0,\text{popcount(i)}=1}^{2^{k'+1}-1}i^{k'}
\]
引理:
\[\sum_{i=0,\text{popcount(i)}=0}^{t\times 2^{k'+1}-1}i^{k'}=\sum_{i=0,\text{popcount(i)}=1}^{t\times 2^{k'+1}-1}i^{k'},\ t\in \Z
\]
证明引理,我们只需要证明下式:
\[\sum_{i=t\times 2^{k'+1},\text{popcount}(i)=0}^{(t+1)\times 2^{k'+1}-1}i^{k'}=\sum_{i=t\times 2^{k'+1},\text{popcount}(i)=1}^{(t+1)\times 2^{k'+1}-1}i^{k'}
\]
我们同样可以归纳,这里省略掉部分细节。改写一下:
\[\sum_{i=0,\text{popcount}(i)=0}^{2^{k'+1}-1}(i+t\times 2^{k'+1})^{k'}=\sum_{i=0,\text{popcount}(i)=1}^{2^{k'+1}-1}(i+t\times 2^{k'+1})^{k'}
\]
同样展开:
\[\sum_{i=0,\text{popcount}(i)=0}^{2^{k'+1}-1}\sum_{j=0}^{k'}{k'\choose j}i^j(t\times 2^{k'+1})^{k'-j}=\sum_{i=0,\text{popcount}(i)=1}^{2^{k'+1}-1}\sum_{j=0}^{k'}{k'\choose j}i^j(t\times 2^{k'+1})^{k'-j}
\]
\[\Rightarrow\sum_{j=0}^{k'}(t\times 2^{k'+1})^{k'-j}{k'\choose j}\sum_{i=0,\text{popcount}(i)=0}^{2^{k'+1}-1}i^j=\sum_{j=0}^{k'}(t\times 2^{k'+1})^{k'-j}{k'\choose j}\sum_{i=0,\text{popcount}(i)=1}^{2^{k'+1}-1}i^j
\]
由归纳假设每一项都相等,由此引理得证。
由上,对比每一项,能直接得到 \((*)\) 式成立。故结论成立
推导过程中,我们甚至能知道
\[\sum_{i=0,\text{popcount}(i)=0}^{t\times 2^{k+1}-1}i^k=\sum_{i=0,\text{popcount}(i)=1}^{t\times 2^{k+1}-1}i^k
\]
成立。这样我们就可以将复杂度从 \(\mathcal O(\log_2 nk^2)\) 降成 \(\mathcal O(k^3+..)\) 了。
#include <bits/stdc++.h>
const int N = 500005, K = 505, P = 1e9 + 7, inv2 = P + 1 >> 1;
int n, m, l, a[K], b[N], c[N], ans = 0; char str[N];
int inv[K], fac[K], ifac[K], pw[K], f0[K], f1[K], g0[K], g1[K], y[K];
int C(int n, int m) { return 1LL * fac[n] * ifac[m] % P * ifac[n - m] % P; }
int lagrange(int n, int k) {
for (int i = 0; i <= k + 1; i++) {
y[i] = 1;
for (int j = 1; j <= k; j++)
y[i] = 1LL * y[i] * i % P;
if (i) y[i] = (y[i] + y[i - 1]) % P;
}
int ans = 0;
for (int i = 0; i <= k + 1; i++) {
int tmp = y[i];
for (int j = 0; j <= k + 1; j++)
if (i != j) tmp = 1LL * tmp * (n - j + P) % P * (i > j ? inv[i - j] : P - inv[j - i]) % P;
ans = (ans + tmp) % P;
}
return ans;
}
int main() {
scanf("%s%d", str, &m);
for (int i = 0; i < m; i++) scanf("%d", &a[i]);
n = strlen(str); l = std::min(n, m);
std::reverse(str, str + n);
for (int i = n - 1; ~i; i--)
b[i] = (2 * b[i + 1] + (str[i] == '1')) % P, c[i] = c[i + 1] ^ (str[i] == '1');
inv[1] = 1;
for (int i = 2; i <= m; i++) inv[i] = 1LL * (P - P / i) * inv[P % i] % P;
fac[0] = ifac[0] = pw[0] = 1;
for (int i = 1; i <= m; i++)
fac[i] = 1LL * fac[i - 1] * i % P, ifac[i] = 1LL * ifac[i - 1] * inv[i] % P, pw[i] = 2 * pw[i - 1] % P;
f0[0] = 1;
for (int i = 0; i < l; i++) {
if (str[i] == '1') {
int x = 1LL * b[i + 1] * pw[i + 1] % P;
for (int j = i; j < m; j++)
for (int k = 0, t = a[j]; k <= j; k++)
ans = (ans + 1LL * C(j, k) * t % P * (c[i] ? f1[j - k] : f0[j - k])) % P, t = 1LL * t * x % P;
}
for (int j = 0; j < m; j++) {
g0[j] = g1[j] = 0;
for (int k = j; ~k; k--)
g0[j] = (g0[j] * 2 + 1LL * C(j, k) * f1[k]) % P,
g1[j] = (g1[j] * 2 + 1LL * C(j, k) * f0[k]) % P;
}
for (int j = 0; j < m; j++)
f0[j] = (1LL * f0[j] * pw[j] + g0[j]) % P, f1[j] = (1LL * f1[j] * pw[j] + g1[j]) % P;
}
for (int i = 0; i < m; i++)
ans = (ans + 1LL * a[i] * lagrange(1LL * b[i + 1] * pw[i + 1] % P - 1, i) % P * inv2) % P;
printf("%d", ans);
return 0;
}