题解 矩阵补全
- 关于 FWT:
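FWT 是线性变换,并且最外层循环的每一轮只作用于一个二进制位,所以各轮的先后顺序无关紧要。
最外层循环每一轮固定一个二进制位 i,对所有只在该位上不同的下标对 (k, k+2^i) 做一次蝶形变换,其余的位保持不变。
不同位上的变换相互独立、可以交换,所以每一位都可以单独选用变换方式:例如某一位跑 fwt_or,另一位跑 fwt_and,再一位跑 fwt_xor;如果某一位要求所有数在该位上相等,该位直接不做变换(恒等变换)即可。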
根据上面分析写代码即可
点击查看代码
// ubsan: undefined
// accoders
#include <bits/stdc++.h>
using namespace std;
// "Infinity" sentinel; nothing in this file references it.
#define INF 0x3f3f3f3f
// Capacity of the mask-indexed arrays: 1048600 > 2^20, so masks of up to 20 bits fit,
// and buf still has room for 2^20 characters plus the terminating '\0'.
#define N 1048600
// Contest shorthands used throughout the file.
#define pb push_back
#define ll long long
// Disabled switch to 64-bit int; main is declared `signed main`, so it would still compile.
//#define int long long
// n: number of set elements combined (the exponent in qpow(f[i], n) in task::solve);
// m: number of bits; q: number of queries; lim = 1 << m, the number of masks.
int n, m, q, lim;
// s[mask] = buf[mask] ^ '0': 1 for '1', 0 for '0' (assumes the input string is 0/1 only).
ll s[N];
// b[k] selects how bit k is combined (per force::solve):
// 0 = the bits must be equal, 1 = OR, 2 = AND, 3 = XOR. Only indices 0..m-1 are read.
int b[N];
// Raw 0/1 membership string of length 2^m, read in main.
char buf[N];
// Prime modulus and the modular inverse of 2 (2 * inv2 == mod + 1), used by the inverse XOR layer.
const ll mod=1e9+7, inv2=(mod+1)>>1;
// Modular exponentiation by squaring: returns a^b modulo `mod`; b is expected to be >= 0
// (b <= 0 returns 1).
// The base is reduced first so that a*a can never overflow long long, whatever the input
// magnitude. A negative base is accepted; the result then lies in (-mod, mod), because C++ `%`
// keeps the sign of the dividend, and callers normalise it before printing.
inline ll qpow(ll a, ll b) {
    ll ans=1;
    for (a%=mod; b>0; b>>=1) {
        if (b&1) ans=ans*a%mod;
        a=a*a%mod;
    }
    return ans;
}
namespace force{
int now;
ll f[2][1<<10];
void solve() {
for (int i=0; i<lim; ++i) f[now][i]=s[i];
for (int t=1; t<m; ++t,now^=1) {
memset(f[now^1], 0, sizeof(f[now^1]));
for (int i=0; i<lim; ++i) if (s[i]) {
for (int j=0; j<lim; ++j) {
int tem=0;
for (int k=0; k<m; ++k) {
if (b[k]==0) {
if ((i&(1<<k))!=(j&(1<<k))) goto jump;
tem|=(i&(1<<k));
}
else if (b[k]==1) {
tem|=(i&(1<<k))|(j&(1<<k));
}
else if (b[k]==2) {
tem|=(i&(1<<k))&(j&(1<<k));
}
else {
tem|=(i&(1<<k))^(j&(1<<k));
}
}
f[now^1][tem]=(f[now^1][tem]+f[now][j])%mod;
jump: ;
}
}
}
for (int i=1,c; i<=q; ++i) {
scanf("%d", &c);
printf("%lld\n", f[now][c]);
}
}
}
namespace task1{
// Solver that treats every b[k] as either 0 ("bits must be equal") or non-zero (OR), so it is
// only correct when all b[k] are 0 or 1. Bits with b == 0 split the set by that bit's value and
// fix it in the result; the remaining OR bits are packed into a dense index on which an
// OR-convolution power is computed.
ll ans[1<<20], f[1<<20];

// In-place OR (subset-sum) transform over `len` entries; op = 1 forward, op = -1 inverse.
void fwt_or(ll* a, int len, int op) {
    for (int half=1; half<len; half*=2)
        for (int base=0; base<len; base+=2*half)
            for (int lo=base, hi=base+half; lo<base+half; ++lo, ++hi)
                a[hi]=(a[hi]+op*a[lo])%mod;
}

// tem: the subset's elements packed onto the OR bits listed in id (bit j <-> bit id[j]).
// Raises their OR-convolution to the n-th power and adds each result into ans at the full
// mask, whose fixed (b == 0) bits come from `mask`.
void calc(vector<int>& tem, vector<int>& id, int mask) {
    const int width=(int)id.size();
    const int len=1<<width;
    fill(f, f+len, 0LL);
    for (int packed:tem) f[packed]+=1;
    fwt_or(f, len, 1);
    for (int i=0; i<len; ++i) f[i]=qpow(f[i], n);
    fwt_or(f, len, -1);
    for (int packed=0; packed<len; ++packed) {
        if (!f[packed]) continue;
        int full=mask;
        for (int j=0; j<width; ++j)
            if (packed>>j&1) full|=1<<id[j];
        ans[full]=(ans[full]+f[packed])%mod;
    }
}

// Walks bits u..m-1. A b == 0 bit partitions sta by that bit and records it in mask;
// any other bit is appended to id (a per-call copy). At the end, the subset is packed onto
// the id bits and handed to calc.
void solve(int u, vector<int> id, vector<int>& sta, int mask) {
    if (u<m) {
        if (b[u]) {
            id.pb(u);
            solve(u+1, id, sta, mask);
        } else {
            vector<int> part[2];
            for (int x:sta) part[x>>u&1].pb(x);
            solve(u+1, id, part[0], mask);
            solve(u+1, id, part[1], mask|(1<<u));
        }
        return;
    }
    vector<int> packed;
    packed.reserve(sta.size());
    for (int x:sta) {
        int code=0;
        for (int k=0; k<(int)id.size(); ++k)
            code|=(x>>id[k]&1)<<k;
        packed.pb(code);
    }
    calc(packed, id, mask);
}

// Collects the set {i : s[i] != 0}, accumulates all answers, then serves q queries from
// stdin, printing each value normalised into [0, mod).
void solve() {
    vector<int> members;
    for (int x=0; x<lim; ++x)
        if (s[x]) members.pb(x);
    solve(0, vector<int>(), members, 0);
    for (int rest=q; rest>0; --rest) {
        int c;
        scanf("%d", &c);
        printf("%lld\n", (ans[c]%mod+mod)%mod);
    }
}
}
namespace task{
ll f[N];

// Generalised FWT over len = 2^m masks. Layer `bit` applies, to every pair (k, k + 2^bit)
// that differs only in that bit, the butterfly selected by b[bit]:
//   1 = OR (subset sum), 2 = AND (superset sum), 3 = XOR (Hadamard, halved on the inverse);
//   any other value (including 0, "bits must be equal") is the identity, because that
//   combination is a pointwise product on the bit.
// Layers touch different bits, so they commute and may be mixed freely.
// op = 1 is the forward transform, op = -1 the inverse. Negative entries are allowed; if the
// inputs lie in (-mod, mod), every entry stays in that range.
void fwt(ll* a, int len, int op) {
    for (int half=1, bit=0; half<len; half<<=1, ++bit) {
        const int kind=b[bit];
        // Identity layer: nothing to do, so skip the whole pass.
        if (kind!=1 && kind!=2 && kind!=3) continue;
        for (int j=0; j<len; j+=half<<1) {
            ll* lo=a+j;       // entries with this bit clear
            ll* hi=a+j+half;  // their partners with this bit set
            // Dispatch on the layer's kind once per block rather than once per element.
            switch (kind) {
            case 1:
                for (int k=0; k<half; ++k) hi[k]=(hi[k]+op*lo[k])%mod;
                break;
            case 2:
                for (int k=0; k<half; ++k) lo[k]=(lo[k]+op*hi[k])%mod;
                break;
            default:  // kind == 3
                for (int k=0; k<half; ++k) {
                    const ll x=lo[k], y=hi[k];
                    lo[k]=(x+y)%mod;
                    hi[k]=(x-y)%mod;
                    if (op==-1) lo[k]=lo[k]*inv2%mod, hi[k]=hi[k]*inv2%mod;
                }
                break;
            }
        }
    }
}

// Computes the n-fold b-convolution power of the indicator s: forward transform, raise each
// coefficient to the n-th power, inverse transform. Then answers q queries from stdin,
// printing each value normalised into [0, mod).
void solve() {
    for (int i=0; i<lim; ++i) f[i]=s[i];
    fwt(f, lim, 1);
    for (int i=0; i<lim; ++i) f[i]=qpow(f[i], n);
    fwt(f, lim, -1);
    for (int i=1,c; i<=q; ++i) {
        scanf("%d", &c);
        printf("%lld\n", (f[c]%mod+mod)%mod);
    }
}
}
signed main()
{
    // The judge exchanges data through these fixed file names.
    freopen("completion.in", "r", stdin);
    freopen("completion.out", "w", stdout);

    // Header: n, m, then the 0/1 membership string with one character per mask (2^m of them).
    scanf("%d%d%s", &n, &m, buf);
    lim = 1 << m;
    for (int mask = 0; mask < lim; ++mask)
        s[mask] = buf[mask] ^ '0';  // '0' -> 0, '1' -> 1

    // One operation selector per bit, then the number of queries.
    for (int bit = 0; bit < m; ++bit)
        scanf("%d", b + bit);
    scanf("%d", &q);

    // An earlier version dispatched to force::solve() when some b[k] > 1 and to
    // task1::solve() otherwise; task::solve() is now used for every input.
    task::solve();
    return 0;
}