LOJ#565. 「LibreOJ Round #10」mathematican 的二进制 分治,FFT,概率期望

原文链接www.cnblogs.com/zhouzhendong/p/LOJ565.html

前言

标算真是优美可惜这题直接暴力FFT算一算就solved了。

题解

首先,假装没有进位,考虑解决这个问题。

对于每一位,考虑作用在其之上的概率为 \(p\) 的操作,构建多项式 \(((1-p) + px )\),那么将一个位置上所有这样的多项式乘起来之后, \(x^k\) 项系数就代表这个位置被操作 \(k\) 次的概率。对答案的贡献就是 \(k\times\) \(x^k\) 项系数。

考虑进位。

从低位向高位推,进位就相当于将多项式系数两个两个合并。

从低位向高位,考虑将每一位的多项式两个两个合并之后乘到高一位的多项式上,就可以得出每一位被变换任意次的真的概率。

考虑这个过程的复杂度:

一个多项式的长度对其高位长度的贡献依次是 $len, len / 2, len / 4, len / 2 ^ 3 ,\cdots $ ,所以总贡献是 \(O(len)\) 的。由于多项式总长度为 \(O(n+m)\) ,又由于乘法在FFT时有个log,所以这部分的总时间复杂度为 \(O((n+m)\log m)\)

而前一半需要分治FFT来算多项式,复杂度为 \(O(m\log ^ 2m)\)

总时间复杂度为 \(O(n\log m + m\log ^ 2 m)\)

代码

#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof x)
#define For(i,a,b) for (int i=(a);i<=(b);i++)
#define Fod(i,b,a) for (int i=(b);i>=(a);i--)
#define fi first
#define se second
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define outval(x) cerr<<#x" = "<<x<<endl
#define outtag(x) cerr<<"---------------"#x"---------------"<<endl
#define outarr(a,L,R) cerr<<#a"["<<L<<".."<<R<<"] = ";\
						For(_x,L,R)cerr<<a[_x]<<" ";cerr<<endl;
using namespace std;
typedef long long LL;
typedef vector <int> vi;
LL read(){
	LL x=0,f=0;
	char ch=getchar();
	while (!isdigit(ch))
		f|=ch=='-',ch=getchar();
	while (isdigit(ch))
		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return f?-x:x;
}
const int N=1<<19,mod=998244353;
int Pow(int x,int y){
	int ans=1;
	for (;y;y>>=1,x=(LL)x*x%mod)
		if (y&1)
			ans=(LL)ans*x%mod;
	return ans;
}
void Add(int &x,int y){
	if ((x+=y)>=mod)
		x-=mod;
}
void Del(int &x,int y){
	if ((x-=y)<0)
		x+=mod;
}
int Add(int x){
	return x>=mod?x-mod:x;
}
int Del(int x){
	return x<0?x+mod:x;
}
namespace fft{
	int w[N],R[N];
	void init(int n){
		int d=0;
		while ((1<<d)<n)
			d++;
		For(i,0,n-1)
			R[i]=(R[i>>1]>>1)|((i&1)<<(d-1));
		w[0]=1,w[1]=Pow(3,(mod-1)/n);
		For(i,2,n-1)
			w[i]=(LL)w[i-1]*w[1]%mod;
	}
	void FFT(int *a,int n,int flag){
		if (flag<0)
			reverse(w+1,w+n);
		For(i,0,n-1)
			if (i<R[i])
				swap(a[i],a[R[i]]);
		for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
			for (int i=0;i<n;i+=d<<1)
				for (int j=0;j<d;j++){
					int tmp=(LL)w[t*j]*a[i+j+d]%mod;
					a[i+j+d]=Del(a[i+j]-tmp);
					Add(a[i+j],tmp);
				}
		if (flag<0){
			reverse(w+1,w+n);
			int inv=Pow(n,mod-2);
			For(i,0,n-1)
				a[i]=(LL)a[i]*inv%mod;
		}
	}
}
using fft::FFT;
int n,m;
vi a[N],p[N];
vi operator * (vi A,vi B){
	static int a[N],b[N];
	int n=1;
	while (n<A.size()+B.size())
		n<<=1;
	For(i,0,n-1)
		a[i]=b[i]=0;
	For(i,0,(int)A.size()-1)
		a[i]=A[i];
	For(i,0,(int)B.size()-1)
		b[i]=B[i];
	fft::init(n);
	FFT(a,n,1),FFT(b,n,1);
	For(i,0,n-1)
		a[i]=(LL)a[i]*b[i]%mod;
	FFT(a,n,-1);
	vi ans;
	For(i,0,n-1)
		ans.pb(a[i]);
	while (ans.size()>1&&!ans.back())
		ans.pop_back();
	return ans;
}
vi build(int *a,int L,int R){
	if (L==R)
		return (vi){Del(1-a[L]),a[L]};
	int mid=(L+R)>>1;
	vi lp=build(a,L,mid);
	vi rp=build(a,mid+1,R);
	return lp*rp;
}
void Getp(int x){
	int s=a[x].size();
	if (!s)
		p[x]=(vi){1};
	else
		p[x]=build(&a[x][0],0,s-1);
}
int main(){
	n=read()+23,m=read();
	For(i,1,m){
		int p=read(),x=read(),y=read();
		x=(LL)x*Pow(y,mod-2)%mod;
		a[p].pb(x);
	}
	int ans=0;
	For(i,0,n){
		Getp(i);
		if (i>0)
			p[i]=p[i]*p[i-1];
		For(j,0,(int)p[i].size()-1)
			Add(ans,(LL)p[i][j]*j%mod);
		if (p[i].size()&1)
			p[i].pb(0);
		int s=p[i].size()/2;
		For(j,0,s-1)
			p[i][j]=Add(p[i][j*2]+p[i][j*2+1]);
		while (p[i].size()>s)
			p[i].pop_back();
	}
	cout<<ans<<endl;
	return 0;
}
posted @ 2019-06-04 16:43  zzd233  阅读(319)  评论(0编辑  收藏  举报