FWT 做题笔记

【UNR #2】黎明前的巧克力

题目链接

难度:简单

考虑如果一个集合大小为 \(|S|\) 的集合 \(S\) 中的所有数字异或和为 \(0\),那对答案的贡献应该是 \(2^{|S|}\),前提是 \(|S|\not =0\),如果 \(|S|=0\),则贡献应该是 \(0\),我们最后再去掉这个限制的影响。

由此我们可以写出巧克力的结合幂级数:

\[1+2x^{a_i} \]

一个暴力的想法是直接用 \(FWT\),但是这样一共要做 \(n\)\(FWT\),每一次都是 \(O(v\log v)\) 的(\(v\) 表示值域),对于时间来说非常浪费,我们需要考虑利用这个集合幂级数的特殊性质。

根据我们之前博客讲过的 \(FWT(A)_i\) 的式子,容易发现,每一项的数都是 \(-1\) 或者是 \(3\),所以我们可以考虑算出来最后的每个位置上有多少个集合幂级数是 \(-1\),有多少个是 \(3\)。具体来说,把上面的集合幂级数累加,然后只做一遍 \(FWT\),通过解方程,我们可以得到把上面的式子异或卷积之后的到的 \(FWT\) 式子。然后 \(IFWT\) 即可。

代码:

#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define inc(a,b) (((a)+(b))>=mod?(a)+(b)-mod:(a)+(b))
#define sub(a,b) (((a)-(b))<0?(a)-(b)+mod:(a)-(b))
#define mul(a,b) 1ll*(a)*(b)%mod
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 2000100
#define M number
using namespace std;

typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
//#define int long long
typedef pair<int,int> P;
typedef vector<int> vi;

const int INF=0x3f3f3f3f;
const dd eps=1e-9;
const int mod=998244353;

template<typename T> inline void read(T &x) {
    x=0; int f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    x*=f;
}

int n,a[N],b[N],inv2,c[N],inv4;

inline int ksm(int a,int b,int mod){int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline P Calc(int a,int b){
    int x=1ll*((a+b)%mod)*inv4%mod;
    int y=(a-x)%mod;
    // x=(x+mod)%mod;y=(y+mod)%mod;
    return mp(x,y);
}
inline void FWT(int *f,int n,int op){
    for(int i=2;i<=n;i<<=1)
        for(int l=0;l<n;l+=i)
            rep(k,l,l+i/2-1){
                // printf("k=%d\n",k);
                int a=f[k],b=f[k+i/2];
                if(op){
                    f[k]=(a+b)%mod;f[k+i/2]=(a-b)%mod;
                }
                else{
                    f[k]=1ll*((a+b)%mod)*inv2%mod;
                    f[k+i/2]=1ll*((a-b)%mod)*inv2%mod;
                }
            }
}

int main(){
    // assert(freopen("my.in","r",stdin));
    // assert(freopen("my.out","w",stdout));
    inv2=ksm(2,mod-2,mod);inv4=inv(4);
    read(n);rep(i,1,n) read(a[i]);
    int nn=n;
    rep(i,1,n) b[a[i]]+=2;b[0]+=n;
    int m=1000000;n=1;while(n<m) n<<=1;
    FWT(b,n,1);
    // rep(i,0,n-1) printf("%d ",b[i]);puts("");
    rep(i,0,n-1){
        P now=Calc(nn,b[i]);
        // printf("now.fi=%d now.se=%d\n",now.fi,now.se);
        c[i]=1ll*ksm(3,now.fi,mod)*sgn(now.se)%mod;
    }
    FWT(c,n,0);
    c[0]--;
    printf("%d\n",(c[0]+mod)%mod);
    return 0;
}

CF1119H

题目链接

难度:较难

可以转化成上面那道题。

第一个转化:容易发现这个题的集合幂级数是这个形式:

\[ux^{a_i}+vx^{b_i}+wx^{c_i} \]

如果直接考虑的话,情况个数是 \(2^3=8\),我们可以用下面的集合幂级数代替上面的,以把情况数缩小 \(\frac{1}{2}\)

\[ux^{a_i\oplus c_i}+vx^{b_i\oplus c_i}+w \]

这样做对答案有什么影响,不难发现,我们只需要让答案下标异或上 \(\oplus_{i=1}^{n}c_i=xsum\),就可以修正这样做带来的影响。这种技巧在没有常数项的时候适用,如果有常数项的存在将无法缩小情况。

接下来我们考虑转化成上面那道题解方程式的模型。不考虑 \(w\),一共有 \(4\) 种情况:

\[\begin{cases} u+v\\ u-v\\ -u+v\\ -u-v\\ \end{cases} \]

\(4\) 种情况在第 \(i\) 位的出现次数分别为 \(A_i,B_i,C_i,D_i\)
那么这一位的 \(FWT\) 值应该是 \((u+v)^{A_i}(u-v)^{B_i}(-u+v)^{C_i}(-u-v)^{D_i}\),考虑通过解方程得到 \(A_i,B_i,C_i,D_i\) 是多少。

显然 \(A_i+B_i+C_i+D_i=n\),通过计算 \(FWT(\sum x^{a_i\oplus c_i})\) 可以得到 \(A_i+B_i\) 是多少(通过解一个二元方程)。同理通过计算 \(FWT(\sum x^{b_i\oplus c_i})\) 可以得到 \(A_i+C_i\) 是多少。

非常人类智慧的一点是我们可以通过计算 \(FWT(\sum x^{(a_i\oplus c_i)\oplus (b_i \oplus c_i)})\) 来得到 \(A_i+D_i\),这是因为以下式子成立:

\[FWT(x^{(a_i\oplus c_i)\oplus (b_i \oplus c_i)})=FWT(x^{a_i\oplus b_i}) \]

可知

\[[x^k]=(-1)^{|k\&(a_i\oplus b_i)|}=(-1)^{(k\&a_i)\oplus (k\&b_i)}=(-1)^{k\&a_i}(-1)^{k\&b_i} \]

所以只有当两边同号的时候 \(x^k\) 的系数才能是 \(1\),由此通过解方程我们可以得到 \(A_i+D_i\)

代码:

#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define inc(a,b) (((a)+(b))>=mod?(a)+(b)-mod:(a)+(b))
#define sub(a,b) (((a)-(b))<0?(a)-(b)+mod:(a)-(b))
#define mul(a,b) 1ll*(a)*(b)%mod
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define mr(a) ((a)=((a)+mod)%mod)
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 400010
#define M number
using namespace std;

typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
#define int long long
typedef pair<int,int> P;
typedef vector<int> vi;

const int INF=0x3f3f3f3f;
const dd eps=1e-9;
const int mod=998244353;

template<typename T> inline void read(T &x) {
    x=0; int f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    x*=f;
}

int n,k,u,v,w,a[N],b[N],c[N],xsum,inv2;
int f[N],g[N],h[N],ans[N],Ans[N];

inline int ksm(int a,int b,int mod){mr(a);int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline void FWT(int *f,int n,int op){
    for(int i=2;i<=n;i<<=1)
        for(int l=0;l<n;l+=i)
            for(int k=l;k<l+(i/2);k++){
                ll A=f[k],B=f[k+(i/2)];
                if(op==1) f[k]=(A+B),f[k+(i/2)]=(A-B);
                else f[k]=1ll*(A+B)*inv2%mod,f[k+(i/2)]=1ll*(A-B)*inv2%mod;
            }
}
inline void Calc(ll &a,ll &b,ll &c,ll &d,ll a1,ll a2,ll a3,ll a4){
    ll sum=(a2+a3+a4)-a1;sum%=mod;a=1ll*sum*inv2%mod;
    b=a2-a;c=a3-a;d=a4-a;mr(a);mr(b);mr(c);mr(d);
}
inline void Calc(ll &a,ll &b,ll a1,ll a2){
    ll sum=(a1+a2);sum%=mod;a=1ll*sum*inv2%mod;
    b=a1-a;mr(a);mr(b);
    // printf("a1=%lld a2=%lld a=%lld b=%lld\n",a1,a2,a,b);
}

signed main(){
    // assert(freopen("my.in","r",stdin));
    // assert(freopen("my.out","w",stdout));
    read(n);read(k);read(u);read(v);read(w);rep(i,1,n) read(a[i]),read(b[i]),read(c[i]);
    u%=mod;v%=mod;w%=mod;
    rep(i,1,n) a[i]^=c[i],b[i]^=c[i];rep(i,1,n) xsum^=c[i];
    rep(i,1,n) f[a[i]]++;rep(i,1,n) g[b[i]]++;rep(i,1,n) h[a[i]^b[i]]++;
    int ln=n;
    int nn=(1<<k)-1;n=1;while(n<=nn) n<<=1;inv2=inv(2);FWT(f,n,1);FWT(h,n,1);FWT(g,n,1);
    // printf("f: ");rep(i,0,n-1) printf("%d ",f[i]);puts("");
    // printf("g: ");rep(i,0,n-1) printf("%d ",g[i]);puts("");
    // printf("h: ");rep(i,0,n-1) printf("%d ",h[i]);puts("");
    rep(i,0,n-1){
        // printf("i=%d\n",i);
        ll A,B,C,D,a1,a2,a3,a4;a1=ln;
        ll aa,bb,aa1,aa2;
        aa1=ln;aa2=f[i];Calc(aa,bb,aa1,aa2);
        a2=aa;aa2=g[i];Calc(aa,bb,aa1,aa2);
        a3=aa;aa2=h[i];Calc(aa,bb,aa1,aa2);
        a4=aa;
        Calc(A,B,C,D,a1,a2,a3,a4);
        ans[i]=1ll*ksm(u+v+w,A,mod)*ksm(u-v+w,B,mod)%mod*ksm(-u+v+w,C,mod)%mod*ksm(-u-v+w,D,mod)%mod;
    }
    FWT(ans,n,0);
    rep(i,0,n-1) Ans[i]=ans[i^xsum];
    rep(i,0,n-1) printf("%lld ",mr(Ans[i]));
    return 0; 
}

出现的问题:取模混乱,如果取模太过于复杂一定要记得 ll;记得处理负数。

石家庄的工人阶级队伍比较坚强

题目链接

因为不会卡常,所以只得了 \(85pts\),是一道 k 进制 FWT 模板题,只需要一定的转化。不难发现 \(u\)\(x-y\) 三进制下 \(1\) 的个数,而 \(v\) 三进制下 \(2\) 的个数,当然这里的 \(-\) 是在三进制下的。所以整个 \(b\) 就和 \(x-y\) 相关了,卷积即可。

代码:

#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define mai(a) ((a)<0?((a)+mod):(((a)>mod)?(a)-mod:(a)))
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 631441
#define M 15
using namespace std;

typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
// #define int long long
typedef pair<int,int> P;
typedef vector<int> vi;

const int INF=0x3f3f3f3f;
const dd eps=1e-9;

template<typename T> inline void read(T &x) {
    x=0; int f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    x*=f;
}
    
int mod,inv3;
inline ll exgcd(ll a,ll b,ll &x,ll &y){
    if(b==0){x=1;y=0;return a;}ll g=exgcd(b,a%b,x,y);ll tmp=x;x=y;y=tmp-a/b*y;return g;
}
inline int inv(int a){
    ll x,y;int g=exgcd(a,mod,x,y);assert(g==1);
    // printf("a=%lld mod=%lld x=%lld y=%lld\n",a,mod,x,y);
    assert(mai(1ll*x*a%mod)==1);
    return (x%mod+mod)%mod;
}
inline int ksm(int a,int b,int mod){
    int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;
}

struct cp{
    int x,y;
    inline cp() {}
    inline cp(int x,int y) : x(x),y(y) {}
    inline cp operator + (const cp &b)const{return cp((x+b.x)%mod,(y+b.y)%mod);}
    inline cp operator - (const cp &b)const{return cp(mai(x-b.x),mai(y-b.y));}
    inline cp operator * (const cp &b)const{
        return cp(mai((1ll*x*b.y%mod+1ll*b.x*y%mod)%mod-1ll*x*b.x%mod),mai(1ll*y*b.y%mod-1ll*x*b.x%mod)%mod);
    }
    inline cp operator * (const int b)const{return cp(1ll*x*b%mod,1ll*y*b%mod);}
    inline void Print(){
        printf("x=%d y=%d\n",x,y);
    }
};
inline cp ksm(cp a,int b){
    cp res;res.y=1;res.x=0;
    while(b){if(b&1)res=res*a;a=a*a;b>>=1;}return res;
}
int m,t,n,B[M][M];
cp f[N],b[N];
cp w2,w1;

inline int bit1(int x){
    int res=0;while(x){if((x%3)==1) res++;x/=3;}return res;
}
inline int bit2(int x){
    int res=0;while(x){if((x%3)==2) res++;x/=3;}return res;
}
// inline int add(int x,int y){
//     int ans=0;
//     int p=1;while(x/p!=0||y/p!=0){
//         int nowx=(x/p)%3,nowy=(y/p)%3;
//         int now=(nowx+nowy)%3;ans+=p*now;p*=3;
//     }
//     return ans;
// }
// inline int del(int x,int y){
//     int ans=0;
//     int p=1;while(x/p!=0||y/p!=0){
//         int nowx=(x/p)%3,nowy=(y/p)%3;
//         int now=(nowx-nowy+3)%3;ans+=p*now;p*=3;
//     }
//     return ans;
// }
inline void FWT(cp *f,int n){
    for(int i=1;i<n;i*=3)
        for(int j=0;j<n;j+=i*3)
            rep(k,0,i-1){
                int k1=j+k,k2=j+k+i,k3=j+k+i+i;
                cp a=f[k1],b=f[k2],c=f[k3];
                f[k1]=a+b+c;
                f[k2]=a+b*w1+c*w2;
                f[k3]=a+b*w2+c*w1;
            }
}

inline void IFWT(cp *f,int n){
    for(int i=1;i<n;i*=3)
        for(int j=0;j<n;j+=i*3)
            rep(k,0,i-1){
                int k1=j+k,k2=j+k+i,k3=j+k+i*2;
                cp a=f[k1],b=f[k2],c=f[k3];
                f[k1]=a+b+c;
                f[k2]=a+b*w2+c*w1;
                f[k3]=a+b*w1+c*w2;
                // f[k1]=f[k1]*inv3;
                // f[k2]=f[k2]*inv3;
                // f[k3]=f[k3]*inv3;
            }
    int ninv=ksm(inv3,m,mod);
    rep(i,0,n-1) f[i]=f[i]*ninv;
}

signed main(){
    // assert(freopen("my.in","r",stdin));
    // assert(freopen("my.out","w",stdout));
    read(m);read(t);read(mod);inv3=inv(3);
    n=1;rep(i,1,m) n=n*3;rep(i,0,n-1) read(f[i].y);
    rep(i,1,m+1) rep(j,1,m+2-i) read(B[i-1][j-1]);
    rep(i,0,n-1) b[i].y=B[bit1(i)][bit2(i)];w1.x=1;w2=(cp){mod-1,mod-1};
    FWT(b,n);FWT(f,n);
    rep(i,0,n-1) b[i]=ksm(b[i],t);
    rep(i,0,n-1) f[i]=f[i]*b[i];
    // printf("inv3=%d\n",inv3);
    IFWT(f,n);
    rep(i,0,n-1) printf("%lld\n",f[i].y);
    return 0;
}
posted @ 2023-07-13 21:28  NuclearReactor  阅读(19)  评论(0编辑  收藏  举报