LOJ 6267. 生成随机数 题解

好题吼! 所以它为什么收敛呢

题意

https://loj.ac/p/6267

题解

\(m=\sum a_i\)

首先肯定会想到建一棵二叉树,叶节点数量 \(\ge m\) ,叶节点按 \(a_i\) 等比例分配。为了减少步数,把相同的状态的叶节点放在一起,若一个子树内叶节点相同,显然到子树根就可以停止,先不考虑返回叶节点的贡献

\(\displaystyle\sum_{i=1}^n \sum_{j=0}^{\infty} [a_i\&2^j]2^{j-k}(k-j)\)

解释:把每个 \(a_i\) 二进制下的 \(1\) 放在一个子树中,到了子树根就返回,那么该子树根的深度为 \(k-j\) (根为 \(0\) ) ,走到子树中的概率为 \(2^{j-k}\)

但是可能 \(\sum a_i\) 不是 \(2\) 的次幂,必须建返回节点!此时不知道选哪个 \(k\) 最优!

那么我们考虑能不能用无限层 的二叉树,把 \(\frac{a_i}{m}\) 都表示出来,这样我们就不用建返回节点了

\(\frac{a_i}{m}\) 进行二进制表示,第 \(2^{-i}\) 位的 \(1\) 对应到树上的贡献为 \(2^{-i} \cdot i\) ,而 \(\frac{a_i}{m}\) 一定能表示为二进制下的 循环小数 !

\(\text{why?}\)

考虑大除法,只不过每次 \(\times 10\) 变为 \(\times 2\) ,此时对于确定的被除数,下一步的余数也是定值,因为 \(m\) 有限,所以一定会形成环结构。

\(b_{i,j}\) 表示把 \(\frac{a_i}{m}\) 写成二进制小数后, \(2^{-j}\) 位是否为 \(1\) ,贡献就是 \(\displaystyle E(\frac{a_i}{m})=\sum_{j=0}^{\infty} b_{i,j}\times j \times (\frac{1}{2})^j\)

然后可以证明, \(E(X)\) 是收敛的!具体证明请看王总博客 (说明循环小数仅是为了说明贡献收敛) 。

答案就是 \(\sum_{i=1}^n E(\frac{a_i}{m})\)

考虑找到 \(E(X)\)\(E(\{2X\})\) 之间的关系,其中 \(\{\}\) 表取小数部分
\({}\)

\[E(\{2X\})= \begin{cases} \displaystyle\sum_{j=0}^{\infty} b_{i,j+1}\times j \times (\frac{1}{2})^j=\sum_{j=0}^{\infty} b_{i,j}\times (j-1) \times (\frac{1}{2})^{j-1}=2E(X)-2X\ (X < \frac{1}{2}) \\ \displaystyle\sum_{j=0}^{\infty} b_{i,j+1}\times j \times (\frac{1}{2})^j-1=2E(X)-2\{X\}-1=2E(X)-2X\ (X \ge \frac{1}{2}) \\ \end{cases} \]

所以 \(E(\{2X\})=2E(X)-2X\)\(E(\frac{i}{m})\)\(E(\frac{2i\%m}{m})\) 连边,形成基环树森林,然后在环的地方解方程,求出所有的 \(E(\frac{i}{m})\) 即可。

Code

#include<bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int maxn = 1e6 + 10,mod = 998244353;
template<class T>
inline void rd(T &x){
    x = 0; char ch = getchar();
    while(!isdigit(ch)) ch = getchar();
    while(isdigit(ch)) x = x * 10 + ch - 48,ch = getchar();
}
int vis[maxn*10];
int n,a[maxn],m,rt;
ll f[maxn*10],g[maxn*10],ans[maxn*10],invm;
inline ll qp(ll x,ll k,ll res = 1){
	x %= mod;
	for(;k;k >>= 1,x = x * x % mod) if(k & 1) res = res * x % mod;
	return res;
}
void dfs(int u){
	vis[u] = 1;
	int nxt = (u<<1) % m; ll nf = f[u] * 2 % mod,ng = g[u] * 2 - u * 2;
	ng = (ng % mod + mod) % mod;
	if(vis[nxt] == 1){
		ans[rt] = (g[nxt] - ng + mod) * invm % mod * qp(nf - f[nxt] + mod,mod - 2) % mod;
		return;
	}
	if(vis[nxt] == 2){
		ans[rt] = (ans[nxt] - ng * invm % mod + mod) * qp(nf,mod - 2) % mod;
		return;
	}
	g[nxt] = ng,f[nxt] = nf,dfs(nxt);
}
void calc(int u){
	vis[u] = 2;
	int nxt = (u<<1) % m;
	if(vis[nxt] == 2) return;
	ans[nxt] = ans[u] * 2 - 2 * u * qp(m,mod - 2) % mod;
	ans[nxt] = (ans[nxt] % mod + mod) % mod;
	calc(nxt);
}
int main(){
	rd(n); for(ri i = 1;i <= n;++i) rd(a[i]),m += a[i];
	invm = qp(m,mod - 2);
	for(ri i = 1;i <= n;++i)
		if(!vis[a[i]]) rt = a[i],f[rt] = 1,g[rt] = 0,dfs(rt),calc(rt);
	ll res = 0;
	for(ri i = 1;i <= n;++i){
		res += ans[a[i]];
		if(res >= mod) res -= mod;
	}
	printf("%lld\n",res);
	return 0;
}

posted @ 2022-03-16 19:23  Lumos壹玖贰壹  阅读(126)  评论(0编辑  收藏  举报