[Codeforces]662C - Binary Table(FWT)
一些套路的整合题,是一个好题。
题意:
给定一个\(n\times m\)的01矩阵,每次可以选择一行或者一列进行取反,问任意进行操作后,矩阵中剩下的1最少有几个。
\(n\le 20, m\le 10^5\)
先进行一下转化,首先注意到\(n\)是很小的,有一个贪心策略是,确定了行的取反状态后,列的取反方案其实确定了,每一列,假如取反后1比较少就取反。
当每一列和行反转状态用二进制数表达之后,令行的翻转状态为\(x\),答案就变成
\[\sum_{i = 1}^{m}f(a_i \oplus x)
\]
其中
\[f(x) = min(popcount(x), n - popcount(x))
\]
怎么继续优化?
利用FWT常用的一个套路,\(a\oplus b=c\)推出\(a\oplus c=b\)
令
\[a_i\oplus x=j
\]
则
\[a_i\oplus j=x
\]
这时我们已经把等式左边的一个变量凑到外面去了。
这个\(j\)其实是一个任意数,\(f(j)\)对答案的贡献其实就跟\(a_i\)的数量有关。
枚举\(a_i\)的值可以得到
\[f(x) = \sum_{i\oplus j = x}f(j)\times cnt(i)
\]
这玩意就可以进行异或卷积了。
#include <bits/stdc++.h>
#define pt(x) cout << x << endl;
#define Mid ((l + r) / 2)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
int read() {
char c; int num, f = 1;
while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
return f * num;
}
const int N = (1 << 21) + 1009;
const int M = 2e5 + 1009;
const int mod = 998244353;
int f[N], cnt[N], n, m, g[29][M];
int Pow(int a, int p) {
int ans = 1;
for( ; p; p >>= 1, a = 1ll * a * a % mod)
if(p & 1)
ans = 1ll * ans * a % mod;
return ans % mod;
}
void FWT_xor(int *A, int n, int type) {
int inv_2 = Pow(2, mod - 2);
for(int m = 1; m < n; m <<= 1) {
for(int i = 0; i < n; i += 2 * m) {
for(int j = 0; j < m; j++) {
int x = A[i + j], y = A[i + j + m];
A[i + j] = (1ll * x + y) * (type == 1 ? 1 : inv_2) % mod;
A[i + j + m] = (1ll * x - y + mod) * (type == 1 ? 1 : inv_2) % mod;
}
}
}
}
signed main()
{
n = read(); m = read();
for(int i = 1; i <= n; i++)
for(int j = 1; j <= m; j++)
scanf("%1d", &g[i][j]);
for(int j = 1; j <= m; j++) {
int a = 0;
for(int i = 1; i <= n; i++)
a = a * 2 + g[i][j];
cnt[a]++;
}
int lim = 1 << n;
for(int i = 0; i < lim; i++) {
f[i] = min(__builtin_popcount(i), n - __builtin_popcount(i));
}
FWT_xor(f, lim, 1);
FWT_xor(cnt, lim, 1);
for(int i = 0; i < lim; i++) f[i] = 1ll * f[i] * cnt[i] % mod;
FWT_xor(f, lim, -1);
int ans = 0x3f3f3f3f;
for(int i = 0; i < (1 << n); i++) {
ans = min(ans, f[i]);
}
printf("%d\n", ans);
return 0;
}