gym102586C Sum Modulo

题意:

给你 \(n,m,k\) 以及\(p_i(1\le i\le n)\) ,保证 \(\sum p_i=1\)

你有一个数 \(X\),一开始 \(X=0\)

每次你会生成一个随机数 \(A\),有 \(p_i\) 的概率 \(A=i\) ,然后令 \(X= (X+A)\bmod m\)

\(X=k\) 的期望次数。

数据范围:\(2\le m\le 10^{18},1\le n\le \min(500,m-1),1\le k<m\)

solution:

为了方便,我们当成是从 \(k\) 走到 \(0\) ,每次 \(X=X-A\)

首先我们能想到一个基础的 \(dp\) 方程:

\(f(0)=0\)

\(f(i)=1+\sum p_jf((i-j)\bmod m)\)

\(f(k)\)

暴力高斯消元能得暴力分。一个更好的想法:

每个 \(f(x)\) 都能表示成 \(w_x+\sum a_if(i)\) 的形式。

我们尝试干两件事情:

1.对于每一个 \(f(x)\) 都能把它快速地表示成上述形式。

2.解出 \(f(1),f(2),…,f(n-1)\)

显然做出这两件事就能直接求 \(f(k)\) 了。

先处理第 \(2\) 件事:表示出 \(f(n-m+1)\)\(f(m-1)\) ,然后对于 \(f(1)\)\(f(n-1)\) ,显然就都能以两种方式表示出来,我们就可以列出 \(n-1\) 个方程,高斯消元求解。

再看第 \(1\) 件事。

一个 naive 的做法是直接矩阵快速幂暴力转移,单次都要 \(O(n^3\log m)\) ,过不了。

我们尝试直接用 \(f(x),f(y)\) 得出 \(f(x+y) (x+y<m)\),方法是:

\(f(x)=w_x+\sum a_if(i)\) 是吧。

我们再令 \(f(y)=w_y+\sum b_if(i)\)

那么考虑 \(f(i)->f(x)\) 的这个贡献的系数呀,显然只与 \(x-i\) 有关。

于是 \(f(x+y)=w_x+\sum a_if(i+y)=w_x+w_y\sum a_i+\sum f(i+j)a_ib_j (1\le i,j <n)\)

对于 \(f(2n-2)\)\(f(n)\) ,我们用 \(dp\) 式展开,合并到更小的项即可。

这样,我们就用 \(O(n^2)\) 的复杂度合并了 \(f(x),f(y)\)

于是单个 \(f(x)\) 就能用 \(O(n^2\log m)\) 的复杂度表示出来,具体是先算 \(f(2^i)\) ,再二进制拆分。

但还没做完。对于 \(f(m-n)\)\(f(m-1)\) ,最快的做法是:

先算 \(f(m-n)\)

假设算出 \(f(x)\) ,则 \(f(x+1)=w_x+\sum a_if(i+1)\) ,展开 \(f(n)\) 这一项,即可算出 \(f(x+1)\) 。这是单次 \(O(n)\) 的。

于是我们整个题成功的做到了 \(O(n^2(n+\log m))\) 的复杂度。

可能还有一些科技能做这个题,不管了。

code

#include<bits/stdc++.h>
#define ll long long 
using namespace std;
int n,e[510];
ll m,K;
const int mod=998244353;
int qpow(int a,int b){
	int c=1;
	for(;b;b>>=1){
		if(b&1)c=1ll*a*c%mod;
		a=1ll*a*a%mod;
	}
	return c;
}
struct qq{int p[510],val;}base[62],tk[510],ans[510];
int fz[1010];
qq mul(qq a,qq b){
	//f(x+y)=w_x+sum (w_y+sum f(j+i)*cy(j))*cx(i)
	//f(x+y)=w_x+w_y* (sum cx(i)+sum cx(i)cy(j)f(i+j)
	qq c;c.val=0;
	memset(c.p,0,sizeof(c.p));
	int as=0;
	for(int i=0;i<n;i++)(as+=a.p[i])%=mod;
	c.val=(a.val+1ll*as*b.val%mod)%mod;
	memset(fz,0,sizeof(fz));
	for(int i=0;i<n;i++)for(int j=0;j<n;j++)(fz[i+j]+=1ll*a.p[i]*b.p[j]%mod)%=mod;
	for(int i=2*n-2;i>=n;i--){
		(c.val+=fz[i])%=mod;
		for(int j=1;j<=n;j++)(fz[i-j]+=1ll*e[j]*fz[i]%mod)%=mod;
	}
	for(int i=0;i<n;i++)c.p[i]=fz[i];
	return c;
}
qq fik(ll x){
	qq c;bool fl=0;
	for(int i=60;i>=0;i--)if((x>>i)&1){
		if(!fl)fl=1,c=base[i];
		else c=mul(c,base[i]);
	}
	return c;
}
qq qb(){
	qq c;c.val=1;memset(c.p,0,sizeof(c.p));
	for(int i=0;i<n;i++){
		(c.val+=1ll*e[n-i]*tk[i].val%mod)%=mod;
		for(int j=1;j<n;j++)
			(c.p[j]+=1ll*e[n-i]*tk[i].p[j]%mod)%=mod;
	}
	return c;
}
qq ad(qq a){
	qq c;c.val=(a.val+a.p[n-1])%mod;
	for(int i=0;i<n;i++)
		c.p[i]=((i?a.p[i-1]:0)+1ll*a.p[n-1]*e[n-i]%mod)%mod;
	return c;
}
int a[510][510];
int main(){
	cin>>n>>m>>K;
	if(n==1){return printf("%lld",K),0;}
	int S=0;
	for(int i=1;i<=n;i++)scanf("%d",&e[i]),S+=e[i];
	S=qpow(S,mod-2);
	for(int i=1;i<=n;i++)e[i]=1ll*e[i]*S%mod;
	//f[0]=0
	//i>=1 f[i]=1+sum j<=n f[i-j]*e[j]
	//求f[K]
	base[0].p[1]=1;
	for(int i=1;i<=61;i++)base[i]=mul(base[i-1],base[i-1]);
	tk[0]=fik(m-n+1);
	//m-n+1
	//f(x+1)=w_+f_(i+1)*cx_i
	for(int i=1;i+1<n;i++)tk[i]=ad(tk[i-1]);
	for(int i=1;i<n;i++){
		ans[i]=qb();
		for(int j=0;j+1<n;j++)tk[j]=tk[j+1];
		tk[n-1]=ans[i];
	}
	for(int i=1;i<n;i++){
		a[i][n]=-ans[i].val,a[i][i]--;
		for(int j=1;j<n;j++)(a[i][j]+=ans[i].p[j])%=mod;
	}
	for(int i=1;i<n;i++){
		for(int j=i;j<n;j++)if(a[j][i]){swap(a[i],a[j]);break;}
		for(int j=1;j<n;j++)if(i!=j){
			int tp=1ll*a[j][i]*qpow(a[i][i],mod-2)%mod;
			for(int k=i+1;k<=n;k++)(a[j][k]-=1ll*tp*a[i][k]%mod)%=mod;
		}
	}
	qq an=fik(K);
	int ans=an.val;
	for(int i=1;i<n;i++)(ans+=1ll*an.p[i]*a[i][n]%mod*qpow(a[i][i],mod-2)%mod)%=mod;
	return printf("%d",(ans+mod)%mod),0;
}
posted @ 2022-10-18 17:40  grass8woc  阅读(140)  评论(2编辑  收藏  举报