HDU 6057 - Kanade's convolution | 2017 Multi-University Training Contest 3

/*
HDU 6057 - Kanade's convolution [ FWT ]  |  2017 Multi-University Training Contest 3
题意:
	给定两个序列 A[0...2^m-1], B[0...2^m-1]
	求 C[0...2^m-1]	,满足:
		C[k] = ∑[i&j==k] A[i^j] * B[i|j]
	m <= 19
分析:
	看C[k]的形式与集合卷积的形式接近,故转化式子时主要向普通的集合卷积式方向靠
	与三种位运算都相关的结论是 : i^j + i&j = i|j
	设 x = i^j, y = i|j,则显然 k = y-x,且 k 与 x 互成关于 y 的补集,即 k = x^y 
	
	再来关心给定(x,y),符合 x = i^j, y = i|j的(i,j)对的数目
		注意到相同的位 i&j 是确定的,x = i^j 是i和j不同的位的数目,这部分谁是 0 谁是 1 不固定
			故(i,j)对的数目为 2^bits(x)

	此时重写原式: C[k] = ∑ [k == x^y] [k == y-x] A[x]*2^bits(x) * B[y]
	
	设 A'[x] = A[x]*2^bits(x)
	由于 [k == x^y],第二个条件 [k == y-x] 等价于 bits(k) == bits(y) - bits(x)
		C[k] = ∑ [k == x^y] [bits(k) == bits(x) - bits(y)] A'[x] * B[y]
	
	将 A,B,C三个数组按 bits 划分:
		C[bits(k)][k] = ∑ [k == x^y] A[bits(x)][x]*2^bits(x) * B[bits(y)][y]
	
	最后按不同的维度(bits)做 FWT即可
*/
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
const int N = 1<<20;
int rev2;
long long inv( long long a , long long m)
{
	if (a == 1) return 1;
	return inv(m%a, m) * (m - m/a) % m;
}
void FWT(int a[], int n) {
    for (int d = 1; d < n; d <<= 1)
        for (int m = d<<1, i = 0; i < n; i += m)
            for (int j = 0; j < d; j++)
            {
                int x = a[i+j], y = a[i+j+d];
                a[i+j] = (x+y) % MOD;
                a[i+j+d] = (x-y+MOD) % MOD;
            }
}
void UFWT(int a[], int n) {
    for (int d = 1; d < n; d <<= 1)
        for (int m = d<<1, i = 0; i < n; i += m)
            for (int j = 0; j < d; j++)
            {
                int x = a[i+j], y = a[i+j+d];
                a[i+j] = 1LL*(x+y) * rev2 % MOD;
                a[i+j+d] = (1LL*(x-y)*rev2 % MOD + MOD) % MOD;
            }
}
int a[20][N], b[20][N], c[20][N];
int bits[N];
int m, n;
void init()
{
    rev2 = inv(2, MOD);
    bits[0] = 0;
    for (int i = 1; i < N; i++) bits[i] = bits[i>>1] + (i&1);
}
int main()
{
    init();
    scanf("%d", &m);
    n = 1<<m;
    for (int i = 0; i < n; i++)
    {
        int x; scanf("%d", &x);
        a[bits[i]][i] = 1LL*x * (1<<bits[i]) % MOD;
    }
    for (int i = 0; i < n; i++)
    {
        int x; scanf("%d", &x);
        b[bits[i]][i] = x;
    }
    for (int i = 0; i <= m; i++) FWT(a[i], n);
    for (int i = 0; i <= m; i++) FWT(b[i], n);
    for (int i = 0; i <= m; i++)
        for (int j = i; j <= m; j++)
            for (int k = 0; k < n; k++)
            {
                c[j-i][k] = (c[j-i][k] + 1LL*a[i][k] * b[j][k] % MOD) % MOD;
            }
    for (int i = 0; i <= m; i++) UFWT(c[i], n);
    long long ans = 0, base = 1;
    for (int i = 0; i < n; i++)
    {
        ans = ( ans + c[bits[i]][i] * base % MOD ) % MOD;
        base = base * 1526 % MOD;
    }
    printf("%lld\n", ans);
}

  

posted @ 2017-08-05 20:50  nicetomeetu  阅读(192)  评论(0编辑  收藏  举报