FWT 做题笔记
【UNR #2】黎明前的巧克力
难度:简单
考虑如果一个集合大小为 \(|S|\) 的集合 \(S\) 中的所有数字异或和为 \(0\),那对答案的贡献应该是 \(2^{|S|}\),前提是 \(|S|\not =0\),如果 \(|S|=0\),则贡献应该是 \(0\),我们最后再去掉这个限制的影响。
由此我们可以写出巧克力的结合幂级数:
一个暴力的想法是直接用 \(FWT\),但是这样一共要做 \(n\) 次 \(FWT\),每一次都是 \(O(v\log v)\) 的(\(v\) 表示值域),对于时间来说非常浪费,我们需要考虑利用这个集合幂级数的特殊性质。
根据我们之前博客讲过的 \(FWT(A)_i\) 的式子,容易发现,每一项的数都是 \(-1\) 或者是 \(3\),所以我们可以考虑算出来最后的每个位置上有多少个集合幂级数是 \(-1\),有多少个是 \(3\)。具体来说,把上面的集合幂级数累加,然后只做一遍 \(FWT\),通过解方程,我们可以得到把上面的式子异或卷积之后的到的 \(FWT\) 式子。然后 \(IFWT\) 即可。
代码:
#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define inc(a,b) (((a)+(b))>=mod?(a)+(b)-mod:(a)+(b))
#define sub(a,b) (((a)-(b))<0?(a)-(b)+mod:(a)-(b))
#define mul(a,b) 1ll*(a)*(b)%mod
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 2000100
#define M number
using namespace std;
typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
//#define int long long
typedef pair<int,int> P;
typedef vector<int> vi;
const int INF=0x3f3f3f3f;
const dd eps=1e-9;
const int mod=998244353;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
int n,a[N],b[N],inv2,c[N],inv4;
inline int ksm(int a,int b,int mod){int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline P Calc(int a,int b){
int x=1ll*((a+b)%mod)*inv4%mod;
int y=(a-x)%mod;
// x=(x+mod)%mod;y=(y+mod)%mod;
return mp(x,y);
}
inline void FWT(int *f,int n,int op){
for(int i=2;i<=n;i<<=1)
for(int l=0;l<n;l+=i)
rep(k,l,l+i/2-1){
// printf("k=%d\n",k);
int a=f[k],b=f[k+i/2];
if(op){
f[k]=(a+b)%mod;f[k+i/2]=(a-b)%mod;
}
else{
f[k]=1ll*((a+b)%mod)*inv2%mod;
f[k+i/2]=1ll*((a-b)%mod)*inv2%mod;
}
}
}
int main(){
// assert(freopen("my.in","r",stdin));
// assert(freopen("my.out","w",stdout));
inv2=ksm(2,mod-2,mod);inv4=inv(4);
read(n);rep(i,1,n) read(a[i]);
int nn=n;
rep(i,1,n) b[a[i]]+=2;b[0]+=n;
int m=1000000;n=1;while(n<m) n<<=1;
FWT(b,n,1);
// rep(i,0,n-1) printf("%d ",b[i]);puts("");
rep(i,0,n-1){
P now=Calc(nn,b[i]);
// printf("now.fi=%d now.se=%d\n",now.fi,now.se);
c[i]=1ll*ksm(3,now.fi,mod)*sgn(now.se)%mod;
}
FWT(c,n,0);
c[0]--;
printf("%d\n",(c[0]+mod)%mod);
return 0;
}
CF1119H
难度:较难
可以转化成上面那道题。
第一个转化:容易发现这个题的集合幂级数是这个形式:
如果直接考虑的话,情况个数是 \(2^3=8\),我们可以用下面的集合幂级数代替上面的,以把情况数缩小 \(\frac{1}{2}\):
这样做对答案有什么影响,不难发现,我们只需要让答案下标异或上 \(\oplus_{i=1}^{n}c_i=xsum\),就可以修正这样做带来的影响。这种技巧在没有常数项的时候适用,如果有常数项的存在将无法缩小情况。
接下来我们考虑转化成上面那道题解方程式的模型。不考虑 \(w\),一共有 \(4\) 种情况:
设 \(4\) 种情况在第 \(i\) 位的出现次数分别为 \(A_i,B_i,C_i,D_i\)。
那么这一位的 \(FWT\) 值应该是 \((u+v)^{A_i}(u-v)^{B_i}(-u+v)^{C_i}(-u-v)^{D_i}\),考虑通过解方程得到 \(A_i,B_i,C_i,D_i\) 是多少。
显然 \(A_i+B_i+C_i+D_i=n\),通过计算 \(FWT(\sum x^{a_i\oplus c_i})\) 可以得到 \(A_i+B_i\) 是多少(通过解一个二元方程)。同理通过计算 \(FWT(\sum x^{b_i\oplus c_i})\) 可以得到 \(A_i+C_i\) 是多少。
非常人类智慧的一点是我们可以通过计算 \(FWT(\sum x^{(a_i\oplus c_i)\oplus (b_i \oplus c_i)})\) 来得到 \(A_i+D_i\),这是因为以下式子成立:
由
可知
所以只有当两边同号的时候 \(x^k\) 的系数才能是 \(1\),由此通过解方程我们可以得到 \(A_i+D_i\)。
代码:
#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define inc(a,b) (((a)+(b))>=mod?(a)+(b)-mod:(a)+(b))
#define sub(a,b) (((a)-(b))<0?(a)-(b)+mod:(a)-(b))
#define mul(a,b) 1ll*(a)*(b)%mod
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define mr(a) ((a)=((a)+mod)%mod)
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 400010
#define M number
using namespace std;
typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
#define int long long
typedef pair<int,int> P;
typedef vector<int> vi;
const int INF=0x3f3f3f3f;
const dd eps=1e-9;
const int mod=998244353;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
int n,k,u,v,w,a[N],b[N],c[N],xsum,inv2;
int f[N],g[N],h[N],ans[N],Ans[N];
inline int ksm(int a,int b,int mod){mr(a);int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline void FWT(int *f,int n,int op){
for(int i=2;i<=n;i<<=1)
for(int l=0;l<n;l+=i)
for(int k=l;k<l+(i/2);k++){
ll A=f[k],B=f[k+(i/2)];
if(op==1) f[k]=(A+B),f[k+(i/2)]=(A-B);
else f[k]=1ll*(A+B)*inv2%mod,f[k+(i/2)]=1ll*(A-B)*inv2%mod;
}
}
inline void Calc(ll &a,ll &b,ll &c,ll &d,ll a1,ll a2,ll a3,ll a4){
ll sum=(a2+a3+a4)-a1;sum%=mod;a=1ll*sum*inv2%mod;
b=a2-a;c=a3-a;d=a4-a;mr(a);mr(b);mr(c);mr(d);
}
inline void Calc(ll &a,ll &b,ll a1,ll a2){
ll sum=(a1+a2);sum%=mod;a=1ll*sum*inv2%mod;
b=a1-a;mr(a);mr(b);
// printf("a1=%lld a2=%lld a=%lld b=%lld\n",a1,a2,a,b);
}
signed main(){
// assert(freopen("my.in","r",stdin));
// assert(freopen("my.out","w",stdout));
read(n);read(k);read(u);read(v);read(w);rep(i,1,n) read(a[i]),read(b[i]),read(c[i]);
u%=mod;v%=mod;w%=mod;
rep(i,1,n) a[i]^=c[i],b[i]^=c[i];rep(i,1,n) xsum^=c[i];
rep(i,1,n) f[a[i]]++;rep(i,1,n) g[b[i]]++;rep(i,1,n) h[a[i]^b[i]]++;
int ln=n;
int nn=(1<<k)-1;n=1;while(n<=nn) n<<=1;inv2=inv(2);FWT(f,n,1);FWT(h,n,1);FWT(g,n,1);
// printf("f: ");rep(i,0,n-1) printf("%d ",f[i]);puts("");
// printf("g: ");rep(i,0,n-1) printf("%d ",g[i]);puts("");
// printf("h: ");rep(i,0,n-1) printf("%d ",h[i]);puts("");
rep(i,0,n-1){
// printf("i=%d\n",i);
ll A,B,C,D,a1,a2,a3,a4;a1=ln;
ll aa,bb,aa1,aa2;
aa1=ln;aa2=f[i];Calc(aa,bb,aa1,aa2);
a2=aa;aa2=g[i];Calc(aa,bb,aa1,aa2);
a3=aa;aa2=h[i];Calc(aa,bb,aa1,aa2);
a4=aa;
Calc(A,B,C,D,a1,a2,a3,a4);
ans[i]=1ll*ksm(u+v+w,A,mod)*ksm(u-v+w,B,mod)%mod*ksm(-u+v+w,C,mod)%mod*ksm(-u-v+w,D,mod)%mod;
}
FWT(ans,n,0);
rep(i,0,n-1) Ans[i]=ans[i^xsum];
rep(i,0,n-1) printf("%lld ",mr(Ans[i]));
return 0;
}
出现的问题:取模混乱,如果取模太过于复杂一定要记得 ll;记得处理负数。
石家庄的工人阶级队伍比较坚强
因为不会卡常,所以只得了 \(85pts\),是一道 k 进制 FWT 模板题,只需要一定的转化。不难发现 \(u\) 是 \(x-y\) 三进制下 \(1\) 的个数,而 \(v\) 三进制下 \(2\) 的个数,当然这里的 \(-\) 是在三进制下的。所以整个 \(b\) 就和 \(x-y\) 相关了,卷积即可。
代码:
#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define mai(a) ((a)<0?((a)+mod):(((a)>mod)?(a)-mod:(a)))
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 631441
#define M 15
using namespace std;
typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
// #define int long long
typedef pair<int,int> P;
typedef vector<int> vi;
const int INF=0x3f3f3f3f;
const dd eps=1e-9;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
int mod,inv3;
inline ll exgcd(ll a,ll b,ll &x,ll &y){
if(b==0){x=1;y=0;return a;}ll g=exgcd(b,a%b,x,y);ll tmp=x;x=y;y=tmp-a/b*y;return g;
}
inline int inv(int a){
ll x,y;int g=exgcd(a,mod,x,y);assert(g==1);
// printf("a=%lld mod=%lld x=%lld y=%lld\n",a,mod,x,y);
assert(mai(1ll*x*a%mod)==1);
return (x%mod+mod)%mod;
}
inline int ksm(int a,int b,int mod){
int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;
}
struct cp{
int x,y;
inline cp() {}
inline cp(int x,int y) : x(x),y(y) {}
inline cp operator + (const cp &b)const{return cp((x+b.x)%mod,(y+b.y)%mod);}
inline cp operator - (const cp &b)const{return cp(mai(x-b.x),mai(y-b.y));}
inline cp operator * (const cp &b)const{
return cp(mai((1ll*x*b.y%mod+1ll*b.x*y%mod)%mod-1ll*x*b.x%mod),mai(1ll*y*b.y%mod-1ll*x*b.x%mod)%mod);
}
inline cp operator * (const int b)const{return cp(1ll*x*b%mod,1ll*y*b%mod);}
inline void Print(){
printf("x=%d y=%d\n",x,y);
}
};
inline cp ksm(cp a,int b){
cp res;res.y=1;res.x=0;
while(b){if(b&1)res=res*a;a=a*a;b>>=1;}return res;
}
int m,t,n,B[M][M];
cp f[N],b[N];
cp w2,w1;
inline int bit1(int x){
int res=0;while(x){if((x%3)==1) res++;x/=3;}return res;
}
inline int bit2(int x){
int res=0;while(x){if((x%3)==2) res++;x/=3;}return res;
}
// inline int add(int x,int y){
// int ans=0;
// int p=1;while(x/p!=0||y/p!=0){
// int nowx=(x/p)%3,nowy=(y/p)%3;
// int now=(nowx+nowy)%3;ans+=p*now;p*=3;
// }
// return ans;
// }
// inline int del(int x,int y){
// int ans=0;
// int p=1;while(x/p!=0||y/p!=0){
// int nowx=(x/p)%3,nowy=(y/p)%3;
// int now=(nowx-nowy+3)%3;ans+=p*now;p*=3;
// }
// return ans;
// }
inline void FWT(cp *f,int n){
for(int i=1;i<n;i*=3)
for(int j=0;j<n;j+=i*3)
rep(k,0,i-1){
int k1=j+k,k2=j+k+i,k3=j+k+i+i;
cp a=f[k1],b=f[k2],c=f[k3];
f[k1]=a+b+c;
f[k2]=a+b*w1+c*w2;
f[k3]=a+b*w2+c*w1;
}
}
inline void IFWT(cp *f,int n){
for(int i=1;i<n;i*=3)
for(int j=0;j<n;j+=i*3)
rep(k,0,i-1){
int k1=j+k,k2=j+k+i,k3=j+k+i*2;
cp a=f[k1],b=f[k2],c=f[k3];
f[k1]=a+b+c;
f[k2]=a+b*w2+c*w1;
f[k3]=a+b*w1+c*w2;
// f[k1]=f[k1]*inv3;
// f[k2]=f[k2]*inv3;
// f[k3]=f[k3]*inv3;
}
int ninv=ksm(inv3,m,mod);
rep(i,0,n-1) f[i]=f[i]*ninv;
}
signed main(){
// assert(freopen("my.in","r",stdin));
// assert(freopen("my.out","w",stdout));
read(m);read(t);read(mod);inv3=inv(3);
n=1;rep(i,1,m) n=n*3;rep(i,0,n-1) read(f[i].y);
rep(i,1,m+1) rep(j,1,m+2-i) read(B[i-1][j-1]);
rep(i,0,n-1) b[i].y=B[bit1(i)][bit2(i)];w1.x=1;w2=(cp){mod-1,mod-1};
FWT(b,n);FWT(f,n);
rep(i,0,n-1) b[i]=ksm(b[i],t);
rep(i,0,n-1) f[i]=f[i]*b[i];
// printf("inv3=%d\n",inv3);
IFWT(f,n);
rep(i,0,n-1) printf("%lld\n",f[i].y);
return 0;
}