CROC 2016 - Final Round [Private, For Onsite Finalists Only] C. Binary Table FWT
C. Binary Table
题目连接:
http://codeforces.com/problemset/problem/662/C
Description
You are given a table consisting of n rows and m columns. Each cell of the table contains either 0 or 1. In one move, you are allowed to pick any row or any column and invert all values, that is, replace 0 by 1 and vice versa.
What is the minimum number of cells with value 1 you can get after applying some number of operations?
Input
The first line of the input contains two integers n and m (1 ≤ n ≤ 20, 1 ≤ m ≤ 100 000) — the number of rows and the number of columns, respectively.
Then n lines follows with the descriptions of the rows. Each line has length m and contains only digits '0' and '1'.
Output
Output a single integer — the minimum possible number of ones you can get after applying some sequence of operations.
Sample Input
3 4
0110
1010
0111
Sample Output
2
Hint
题意
给你一个nm的01矩阵,然后每次操作:你可以挑选任意的某一行或者某一列翻转,然后你需要使得整个矩阵的1的数量尽可能少,问你最少数量是多少。
题解:
首先2^nm这个算法很简单:暴力枚举横着怎么翻转,然后每一列O(1)判断就好了。
然后正解怎么做呢?
我们令ans[i]是异或i之后的1的个数是多少,那么ans[i] = sigma(cnt[i]*num[i^j),cnt[i]表示列那个二进制为i的个数,num[i]表示二进制为i这个数的1的数量是多少。
这个很显然发现 i(ij) = i,这就是一个异或卷积的形式,用FWT加速计算就好了。
代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = (1<<20)+6;
int n,m,cnt[maxn];
long long x1[maxn],x2[maxn],ans[maxn];
string s[maxn];
long long t[maxn];
void utfxor(long long a[], int n) {
if(n == 1) return;
int x = n >> 1;
for(int i = 0; i < x; ++ i) {
t[i] = (a[i] + a[i + x]) >> 1;
t[i + x] = (a[i + x] - a[i]) >> 1;
}
memcpy(a, t, n * sizeof(long long));
utfxor(a, x); utfxor(a + x, x);
}
long long tmp[maxn];
void tfxor(long long a[], int n) {
if(n == 1) return;
int x = n >> 1;
tfxor(a, x); tfxor(a + x, x);
for(int i = 0; i < x; ++ i) {
tmp[i] = a[i] - a[i + x];
tmp[i + x] = a[i] + a[i + x];
}
memcpy(a, tmp, n * sizeof(long long));
}
void solve(long long a[],long long b[],int n)
{
tfxor(a,n);
tfxor(b,n);
for(int i=0;i<n;i++) a[i]=1LL*a[i]*b[i];
utfxor(a,n);
}
int main()
{
for(int i=0;i<maxn;i++){
int tmp = i;
while(tmp){
if(tmp&1)cnt[i]++;
tmp>>=1;
}
}
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)
cin>>s[i];
for(int i=0;i<m;i++){
int tmp = 0;
for(int j=0;j<n;j++){
if(s[j][i]=='1')tmp+=1<<j;
}
x1[tmp]++;
}
for(int i=0;i<(1<<n);i++)
x2[i]=min(cnt[i],n-cnt[i]);
solve(x1,x2,1<<n);
long long ans = 1e15;
for(int i=0;i<(1<<n);i++)
ans=min(ans,x1[i]);
cout<<ans<<endl;
}