UNR#2 黎明前的巧克力(FWT)

题目链接

题意:给定1个长度为n的数组,你要找出两个不均为空的集合,使得它们异或和相等,求方案数,对998244353取模

一个经典套路题?(也许)

一下|S|表示S二进制下1的个数

因为每个集合S,贡献为$2^{|S|}$,所以考虑设第i个元素的生成函数$F_i(x)$为$1 + 2x^{a_i}$,$ans = \Pi F_i(x)$[$x^0$]  这里的多项式乘法定义为异或卷积,显然我们应该先求出FWT的点值然后IFWT回去

考虑$fwt(i)$ = $\Sigma$ $(-1)^{|i & j|}$ $* f_j$,所以$FWT(F_k(x))$在i处的点值= $1$ + $(-1) ^{|i & a_k|}$ * $2$,即对于每个位置,点值为-1或3

考虑设对于第i个位置,有$x_i$个-1,$n - x_i$个3

考虑因为FWT是线性变换,所以

有以下Lemma

$\Sigma FWT(F_i(x))$ = $FWT(\Sigma F_i(x))$(I)

考虑解方程令$Sum = \Sigma F_i$,对Sum做FWT,我们就可以得到$Sum_i$,又因为Lemma(I),我们有$x_i * (-1) + (n - x_i) * 3$ = $Sum_i$,

于是可以得到所有的$x_i$,然后该点FWT后的点值显然是,$(-1)^{x_i} * 3^{n-x_i}$,然后再做一遍IFWT,$x^0$处的系数即为答案

代码如下:

/*[UNR #2]黎明前的巧克力*/
#include<bits/stdc++.h>
using namespace std;
#define ll long long
int read(){
    char c = getchar();
    int x = 0;
    while(c < '0' || c > '9')        c = getchar();
    while(c >= '0' && c <= '9')        x = x * 10 + c - 48,c = getchar();
    return x;
}
const int N = (1 << 20) | 1;
#define mod 998244353
int inv2,inv4;
int qpow(int x,int y){
    int ans = 1;
    while(y){
        if(y & 1)    ans = 1ll * ans * x % mod;
        x = 1ll * x * x % mod;
        y >>= 1;
    }
    return ans;
}
void fwt(int *f,int n,int opt){
    for(int j = 1; j < n; j <<= 1){
        for(int i = 0,len = j << 1; i < n; i += len){
            for(int k = 0; k < j; ++k){
                int x = f[i+k],y = f[i+k+j];
                f[i+k] = (x + y) % mod;
                f[i+k+j] = (x - y + mod) % mod;
                if(opt == -1)    f[i+k] = 1ll * f[i+k] * inv2 % mod,f[i+k+j] = 1ll * f[i+k+j] * inv2 % mod;
            }
        }
    } 
}
int s[N],f[N],a[N];
int main(){
    int n = read();
    for(int i = 1; i <= n; ++i)        a[i] = read(),s[0]++,s[a[i]] += 2;
    const int Mx = (1 << 20);
    inv2 = qpow(2,mod-2);inv4 = qpow(4,mod-2);
    fwt(s,Mx,1);
    for(int i = 0; i < Mx; ++i){
        int x = 1ll * (3ll * n - s[i] + mod) * inv4 % mod;
        f[i] = qpow(3,n-x);
        if(x & 1)    f[i] = (mod - f[i]) % mod;
    }
    fwt(f,Mx,-1);
    int ans = (f[0] - 1 + mod) % mod;
    cout<<ans<<endl;
    return 0;
}
View Code

 

posted @ 2021-02-28 00:41  y_dove  阅读(111)  评论(0编辑  收藏  举报