CF662C Binary Table 枚举 FWT

题面

洛谷题面 (虽然洛谷最近有点慢)

题解

观察到行列的数据范围相差悬殊,而且行的数量仅有20,完全可以支持枚举,因此我们考虑枚举哪些行会翻转。
对于第i列,我们将它代表的01串提取出来,表示为\(v[i]\)
然后我们假设有第0列,其中的第i行如果是1,表示这行将会翻转。
那么可以发现,执行完对行的操作时,每一列的状态为\(x = v[i] \oplus v[0]\),此时我们只需要考虑对列的操作,令\(cnt[i]\)表示状态为\(i\)时01串中1的个数。
显然为了使得1的个数尽可能少,对于状态为\(x\)的列,产生的贡献为\(s[x] = min(cnt[x], n - cnt[x])\)
\(ans[b]\)表示\(v[0] = b\)时的最优解。
那么有

\[ans[b] = \sum_{i = 1}^{m} s[v[i] \oplus b] \]

考虑换一种枚举方式,我们枚举\(s[i]\),然后就只需要再找到使得\(v[j] \oplus b = i\)\(j\)有多少个就可以快速算出贡献了。
\(v[j] \oplus b = i \Longrightarrow v[j] = i \oplus b\)
因此我们只需要找到有多少个\(v[j] = i \oplus b\)即可,令\(p[i]\)表示有多少列的状态为\(i\),
那么我们要求的个数即为\(p[i \oplus b]\)
因此答案就是:

\[ans[b] = \sum_{i = 0}^{2^n - 1} s[i] p[i \oplus b] \]

观察到\(i \oplus i \oplus b = b\),是一个定值。
因此上式等效于

\[ans[b] = \sum_{i \oplus j = b} s[i] p[j] \]

直接上FWT即可

#include<bits/stdc++.h>
using namespace std;
#define R register int
#define LL long long
#define AC 22
#define ac 100100
#define N 1050000

int n, m, maxn;
LL s[N], p[N], ans[N];
char ss[AC][ac];

inline void upmin(LL &a, LL b){if(b < a) a = b;}

inline int cal(int x)
{
    int rnt = 0;
    while(x) rnt += x & 1, x >>= 1;
    return rnt;
}

void pre()
{
    scanf("%d%d", &n, &m), maxn = 1 << n;
    for(R i = 1; i <= n; i ++) scanf("%s", ss[i] + 1);
    for(R i = 0; i <= maxn; i ++) s[i] = min(cal(i), n - cal(i));
    for(R i = 1; i <= m; i ++)
    {
        int x = 0;
        for(R j = 1; j <= n; j ++) x <<= 1, x += (ss[j][i] == '1');
        ++ p[x];
    }
/*	for(R i = 0; i < maxn; i ++) printf("%lld ", s[i]);
    printf("\n");
    for(R i = 0; i < maxn; i ++) printf("%lld ", p[i]);
    printf("\n"); */
}

void fwt(LL *A, int opt)
{
    for(R i = 2; i <= maxn; i <<= 1)
        for(R r = i >> 1, j = 0; j < maxn; j += i)
            for(R k = j; k < j + r; k ++)
            {
                LL x = A[k], y = A[k + r];
                A[k] = x + y, A[k + r] = x - y;
                if(opt < 0) A[k] >>= 1, A[k + r] >>= 1;
            }
}

void work()
{
    fwt(s, 1), fwt(p, 1);
    for(R i = 0; i < maxn; i ++) ans[i] = s[i] * p[i];
    fwt(ans, -1);
    LL rnt = n * m;
    for(R i = 0; i < maxn; i ++) upmin(rnt, ans[i]);
    printf("%lld\n", rnt);
}

int main()
{
//	freopen("in.in", "r", stdin);
    pre();
    work();
//	fclose(stdin);
    return 0;
}
posted @ 2019-02-12 07:52  ww3113306  阅读(123)  评论(0编辑  收藏  举报
知识共享许可协议
本作品采用知识共享署名-非商业性使用-禁止演绎 3.0 未本地化版本许可协议进行许可。