CF1761D Carry Bit
Description
设 \(f(x,y)\) 是 \(x+y\) 的二进制进位数(即 \(f(x,y)=g(x)+g(y)-g(x+y)\) ,其中 \(g(x)\) 是 \(x\) 的二进制表示中 \(1\) 的个数)。
给定两个整数 \(n\) 和 \(k\) ,求出满足\(0 \leq a,b < 2^n\) 且 \(f(a,b) = k\) 的有序对 \((a,b)\) 的个数。注意,对于\(a\neq b\),\((a,b)\) 和 \((b,a)\) 被认为是两个不同的对。
由于这个数可能很大,对它取模 \(10^9+7\) 输出。
\(0\le k<n\le 10^6\)。
Solution
考虑什么时候两个数 \(x+y\) 在 \(i\) 处有进位:
- \(x_i=y_i=1\),则必然有进位
- \(x_i\ \text{xor}\ y_i=1\),并且之前一位有进位,也会有进位。
但是之前一位是否有进位你是不知道的,所以以上分类讨论还不能发挥什么作用。
那我们就连着上一位一同分类:
- 上一位没有进位,这一位也没有进位,那么 \(x,y\) 所对应的位可以是 \((0,0),(0,1),(1,0)\)。
- 上一位有进位,这一位没有进位:\(x,y\) 只能对应 \((0,0)\)。
- 上一位没有进位,这一位有进位:\(x,y\) 只能对应 \((1,1)\)。
- 上一位有进位,这一位也有进位:\(x,y\) 能够对应 \((1,0),(0,1),(1,1)\)。
我们发现当且仅当上一位和这一位的进位状态相同时,这种方案对答案的贡献才能增加。为了方便,不妨使用一个长度 \(n+1\) 的 \(01\) 序列 \(t\) 记录 \(0\) 到 \(n\) 位的进位状态。显然有 \(t_{n}=0\) 并且这一位其实是不需要的,但是为了方便我们可以加入这一位,这样的话我们知道当 \(f_i\neq f_{i+1}\) 的个数一定时必然是 \(0\) 的连续段更多。
于是我们设有 \(i\) 个 \(f_i\neq f_{i+1}\) 的地方,那么一共有 \(i+1\) 个段,那么 \(0\) 的连续段数为 \(\left\lceil\dfrac{i+1}{2}\right\rceil\),\(1\) 的连续段个数就是 \(\left\lfloor\dfrac{i+1}{2}\right\rfloor\)。
由于有 \(k\) 次进位,所以有 \(k\) 个 \(1\);由于一共 \(n+1\) 个位,所以 \(0\) 就有 \(n+1-k\) 个。
于是变成长度为 \(n\) 的序列分成 \(m\) 段的方案数,显然是 \(\dbinom{n-1}{m-1}\)。
由于有 \(i\) 个 \(f_i\neq f_{i+1}\),那么有 \(n-i\) 个 \(f_i=f_{i+1}\),那么有 \(3^{n-i}\) 的贡献。
于是答案就是:
直接 \(O(n\log n)\) 做即可,预处理 \(3\) 的幂优化至 \(O(n)\)。
#include <bits/stdc++.h>
#define int long long
namespace mystd {
inline int read() {
char c = getchar();
int x = 0, f = 1;
while (c < '0' || c > '9') f = (c == '-') ? -1 : f, c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + c - '0', c = getchar();
return x * f;
}
inline void write(int x) {
if (x < 0) x = ~(x - 1), putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace std;
using namespace mystd;
const int maxn = 1e6 + 100;
const int mod = 1e9 + 7;
int n, k, fac[maxn], ifac[maxn], inv[maxn];;
int qpow(int p, int q) {
int res = 1;
while (q) {
if (q & 1) res = res * p % mod;
p = p * p % mod, q >>= 1;
}
return res;
}
int C(int n, int m) { return n < m ? 0 : fac[n] * ifac[m] % mod * ifac[n - m] % mod; }
signed main() {
n = read(), k = read();
fac[0] = ifac[0] = inv[1] = 1;
for (int i = 1; i <= n; i++) {
if (i > 1) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
fac[i] = fac[i - 1] * i % mod, ifac[i] = ifac[i - 1] * inv[i] % mod;
}
if (k == 0) return write(qpow(3, n)), 0;
int ans = 0;
for (int i = 0; i <= n; i++) {
(ans += qpow(3, n - i) * C(n - k, (i + 2) / 2 - 1) % mod * C(k - 1, i + 1 - ((i + 2) / 2) - 1) % mod) %= mod;
}
write(ans);
return 0;
}