LOJ 6267. 生成随机数 题解
好题吼! 所以它为什么收敛呢
题意
题解
令 \(m=\sum a_i\) 。
首先肯定会想到建一棵二叉树,叶节点数量 \(\ge m\) ,叶节点按 \(a_i\) 等比例分配。为了减少步数,把相同的状态的叶节点放在一起,若一个子树内叶节点相同,显然到子树根就可以停止,先不考虑返回叶节点的贡献:
\(\displaystyle\sum_{i=1}^n \sum_{j=0}^{\infty} [a_i\&2^j]2^{j-k}(k-j)\)
解释:把每个 \(a_i\) 二进制下的 \(1\) 放在一个子树中,到了子树根就返回,那么该子树根的深度为 \(k-j\) (根为 \(0\) ) ,走到子树中的概率为 \(2^{j-k}\)
但是可能 \(\sum a_i\) 不是 \(2\) 的次幂,必须建返回节点!此时不知道选哪个 \(k\) 最优!
那么我们考虑能不能用无限层 的二叉树,把 \(\frac{a_i}{m}\) 都表示出来,这样我们就不用建返回节点了!
把 \(\frac{a_i}{m}\) 进行二进制表示,第 \(2^{-i}\) 位的 \(1\) 对应到树上的贡献为 \(2^{-i} \cdot i\) ,而 \(\frac{a_i}{m}\) 一定能表示为二进制下的 循环小数 !
\(\text{why?}\)
考虑大除法,只不过每次 \(\times 10\) 变为 \(\times 2\) ,此时对于确定的被除数,下一步的余数也是定值,因为 \(m\) 有限,所以一定会形成环结构。
\(b_{i,j}\) 表示把 \(\frac{a_i}{m}\) 写成二进制小数后, \(2^{-j}\) 位是否为 \(1\) ,贡献就是 \(\displaystyle E(\frac{a_i}{m})=\sum_{j=0}^{\infty} b_{i,j}\times j \times (\frac{1}{2})^j\) 。
然后可以证明, \(E(X)\) 是收敛的!具体证明请看王总博客 (说明循环小数仅是为了说明贡献收敛) 。
答案就是 \(\sum_{i=1}^n E(\frac{a_i}{m})\)
考虑找到 \(E(X)\) 与 \(E(\{2X\})\) 之间的关系,其中 \(\{\}\) 表取小数部分
\({}\)
所以 \(E(\{2X\})=2E(X)-2X\) 。 \(E(\frac{i}{m})\) 向 \(E(\frac{2i\%m}{m})\) 连边,形成基环树森林,然后在环的地方解方程,求出所有的 \(E(\frac{i}{m})\) 即可。
Code
#include<bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int maxn = 1e6 + 10,mod = 998244353;
template<class T>
inline void rd(T &x){
x = 0; char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = x * 10 + ch - 48,ch = getchar();
}
int vis[maxn*10];
int n,a[maxn],m,rt;
ll f[maxn*10],g[maxn*10],ans[maxn*10],invm;
inline ll qp(ll x,ll k,ll res = 1){
x %= mod;
for(;k;k >>= 1,x = x * x % mod) if(k & 1) res = res * x % mod;
return res;
}
void dfs(int u){
vis[u] = 1;
int nxt = (u<<1) % m; ll nf = f[u] * 2 % mod,ng = g[u] * 2 - u * 2;
ng = (ng % mod + mod) % mod;
if(vis[nxt] == 1){
ans[rt] = (g[nxt] - ng + mod) * invm % mod * qp(nf - f[nxt] + mod,mod - 2) % mod;
return;
}
if(vis[nxt] == 2){
ans[rt] = (ans[nxt] - ng * invm % mod + mod) * qp(nf,mod - 2) % mod;
return;
}
g[nxt] = ng,f[nxt] = nf,dfs(nxt);
}
void calc(int u){
vis[u] = 2;
int nxt = (u<<1) % m;
if(vis[nxt] == 2) return;
ans[nxt] = ans[u] * 2 - 2 * u * qp(m,mod - 2) % mod;
ans[nxt] = (ans[nxt] % mod + mod) % mod;
calc(nxt);
}
int main(){
rd(n); for(ri i = 1;i <= n;++i) rd(a[i]),m += a[i];
invm = qp(m,mod - 2);
for(ri i = 1;i <= n;++i)
if(!vis[a[i]]) rt = a[i],f[rt] = 1,g[rt] = 0,dfs(rt),calc(rt);
ll res = 0;
for(ri i = 1;i <= n;++i){
res += ans[a[i]];
if(res >= mod) res -= mod;
}
printf("%lld\n",res);
return 0;
}