牛客 11257 D Gambling Monster 题解
【大意】
初始时,"土块"有一个数字 \(0\) 。每一轮,他有 \(p_i(0\leq i<n=2^k)\) 的概率抽到数字 \(i\) 。若当前他的数字异或上抽中的数字,将会变得更大,那他会异或上这个数字。问他得到 \((n-1)\) 的期望步数。
【分析】
HL:以后遇到概率 dp 通通倒着跑
所以我们设 \(E(x)\) 表示 \(x\) 到 \((n-1)\) 的期望步数,则有:
\(\displaystyle E(x)=\sum_{x\oplus y=z\\x<z}p_y[E(z)+1]+\sum_{x\oplus y=z\\x\geq z}p_y[E(x)+1]\)
即当转到的数字会使得结果更大,就异或上,所以贡献直接由新的结果转移而来;当不会时,就不异或上了,所以贡献由自己转移
为了方便,我们记 \(\displaystyle \sum_{x\oplus y=z\\x<z}p_y=S_x\),则:
\(\displaystyle E(x)=\sum_{x\oplus y=z\\x<z}p_y[E(z)+1]+(1-S_x)[E(x)+1]\)
\(\displaystyle E(x)-(1-S_x)E(x)=\sum_{x\oplus y=z\\x<z}p_yE(z)+S_x+(1-S_x)\)
\(\displaystyle E(x)={1\over S_x}[\sum_{x\oplus y=z\\x<z}p_yE(z)+1]\)
由于卷积是 \(x\oplus y=z\) 的形式,所以考虑使用 FWT
由于只考虑高维对低微的贡献,所以考虑使用 cdq 分治 FWT
而对于 \(S_x\) ,我们考虑 \(\displaystyle S_x=\sum_{x\oplus y=z\\x<z}p_y\)。当且仅当 \(y\) 的最高位在 \(x\) 中为 \(0\) 时,会使得 \(x\oplus y>x\)
我们对所有概率按最高位归纳,再对每个 \(x\) 按位枚举即可
复杂度为 \(O(n\log^2 n)+O(3^{\log_2 n})=O(n\log^2 n)\)
【代码】
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pii;
typedef double db;
#define fi first
#define se second
#define lowbit(x) ((x)&(-(x)))
const int MOD=1e9+7, MAXN=1<<16, inv2=MOD+1>>1;
inline int add(int a, int b) { return (a+=b)>=MOD?a-MOD:a; }
inline int dis(int a, int b) { return (a-=b)<0?a+MOD:a; }
inline ll fpow(ll a,ll x) { ll ans=1; for(;x;x>>=1,a=a*a%MOD) if(x&1) ans=ans*a%MOD; return ans; }
inline void FWT(int *a, int len, int o=1){
for(int k=0; 1<<k<len; ++k) for(int i=0; i<len; ++i) if(~i>>k&1) {
int j=i^(1<<k), x, y;
x=add(a[i], a[j]), y=dis(a[i], a[j]);
if(o==-1) x=(ll)x*inv2%MOD, y=(ll)y*inv2%MOD;
a[i]=x, a[j]=y;
}
}
inline void doit(int *a, int *b,int len) {
FWT(a, len, 1); FWT(b, len, 1);
for(int i=0;i<len;++i) a[i]=(ll)a[i]*b[i]%MOD;
FWT(a, len, -1);
}
int a[MAXN], b[MAXN];
int n, p[MAXN], hb[MAXN], e[MAXN], sumit[16];
void solve(int l, int r) {
if(l==r){
int s=0;
for(int i=0, x=~l; 1<<i<n; ++i)
if( x>>i &1 )
s=add(s, sumit[i]);
e[l]=add(e[l], 1)*fpow(s, MOD-2)%MOD;
return ;
}
int m=l+r>>1, len=r-l+1;
solve(m+1, r);
memcpy(a, e+l, len*sizeof(e[0]));
memcpy(b, p, len*sizeof(p[0]));
memset(a, 0, len*sizeof(a[0])>>1);
doit(a, b, len);
for(int i=l, j=0;i<=m; ++i, ++j)
e[i]=add(e[i], a[j]);
solve(l, m);
}
inline int ans(){
cin>>n;
int tot=0;
memset(sumit, 0, sizeof(sumit));
for(int i=0;i<n;++i) cin>>p[i], tot=add(tot, p[i]), e[i]=0;
p[0]=p[0]*fpow(tot, MOD-2)%MOD;
for(int i=1, x=fpow(tot, MOD-2); i<n; ++i){
p[i]=(ll)p[i]*x%MOD;
sumit[ hb[i] ]=add(sumit[ hb[i] ], p[i]);
}
solve(0, n-1);
return e[0];
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
for(int i=2;i<1<<16;++i) hb[i]=hb[i-1]+(i==lowbit(i));
int T; cin>>T;
while(T--) cout<<ans()<<"\n";
cout.flush();
return 0;
}