ARC113F - Social Distance

\(n+1\)个坐标\(x_0=0,x_1,\dots,x_n\)\(a_i\)\((x_{i-1},x_i)\)内等概率随机。

\(\min a_i-a_{i-1}\)的期望。

\(n\le 20\)


算出结果大于等于\(z\)的概率\(f(z)\),然后\(\int_{0}^{+\infty} f(z)dz\)就是答案。

假设固定\(z\),把区间\((x_{i-1},x_i)\)转化为\((x_{i-1}-(i-1)z,x_{i}-(i-1)z)\)。然后要求为\(a_i\)单调递增的概率。

我的粗暴做法:设\(g_i(x)\)表示\(a_i=x\)时,考虑了前\(i\)个,单调递增的概率是多少。显然有\(g_i(x)=\frac{1}{r_{i-1}-l_{i-1}}\int_{-\infty}^x g_{i-1}(y)dy,x\in(l_i,r_i)\)。然后强行维护,即枚举关键点的大小关系(有\(O(n^2)\)种)并计算出对应\(z\)的取值范围,每个\(g_i(x)\)是个\(O(n)\)段函数,每段函数是个\(O(n)\)次多项式,多项式中的每个系数是个关于\(z\)\(O(n)\)的多项式。于是时间是\(O(n^6)\)

dyp的神仙做法:设\(dp_{i,j}\)表示到了第\(i\)个关键点,已经把前\(j\)个数放进来的概率。转移的时候枚举\(k\),把第\(j+1\dots k\)个数丢到关键点\(i\)\(i+1\)之间。时间也是\(O(n^6)\)。(似乎比我的做法好写多了)

实现细节:搞关键点的大小关系的时候,可以\(O(n^2)\)枚举两个点大小关系改变的时间,然后排序。注意可能存在时间相同而产生的诡异情况,所以在交换之前判断一下大小关系是否改变。


using namespace std;
#include <bits/stdc++.h>
#define N 25
#define ll long long
#define mo 998244353
#define fi first
#define se second
#define mp(x,y) make_pair(x,y)
ll qpow(ll x,ll y=mo-2){
	ll r=1;
	for (;y;y>>=1,x=x*x%mo)
		if (y&1)
			r=r*x%mo;
	return r;
}
int n;
int a[N];
ll inv[N];
ll ans;
struct Line{
	int b,k;
} li[N*2];
int re[N*2];
struct Num{int v[N];void clear(){memset(v,0,sizeof v);}};
Num operator+(const Num &a,const Num &b){static Num c;for (int i=0;i<=n;++i) c.v[i]=(a.v[i]+b.v[i])%mo;return c;}
Num operator-(const Num &a,const Num &b){static Num c;for (int i=0;i<=n;++i) c.v[i]=(a.v[i]-b.v[i]+mo)%mo;return c;}
Num operator*(const Num &a,ll b){static Num c;for (int i=0;i<=n;++i) c.v[i]=a.v[i]*b%mo;return c;}
Num operator*(const Num &a,Line l){static Num c;c.v[0]=(ll)a.v[0]*l.b%mo;for (int i=1;i<=n;++i) c.v[i]=((ll)a.v[i]*l.b+(ll)a.v[i-1]*l.k)%mo;return c;}
struct poly{
	Num w[N];
	void clear(){memset(w,0,sizeof w);}
	Num y(Line x){
		Num sum;sum.clear();
		x.k=(x.k+mo)%mo;
		for (int i=n;i>=0;--i){
			sum=sum*x;
			sum=sum+w[i];
		}
		return sum;
	}	
};
poly operator+(const poly &a,const poly &b){static poly c;for (int i=0;i<=n;++i) c.w[i]=a.w[i]+b.w[i];return c;}
poly operator-(const poly &a,const poly &b){static poly c;for (int i=0;i<=n;++i) c.w[i]=a.w[i]-b.w[i];return c;}
poly operator*(const poly &a,ll b){static poly c;for (int i=0;i<=n;++i) c.w[i]=a.w[i]*b;return c;}
void getint(poly &p,poly &q){
	p.w[0].clear();
	for (int i=0;i<n;++i)
		p.w[i+1]=q.w[i]*inv[i+1];
}
struct Func{poly f[N*2];void clear(){memset(f,0,sizeof f);}};
void getint(Func &p,Func &q,int l,int r){
	Num tmp;tmp.clear();
	for (int i=1;i<=r;++i){
		getint(p.f[i],q.f[i]);
		p.f[i].w[0]=p.f[i-1].y(li[i-1])-p.f[i].y(li[i-1]);
	}
	for (int i=1;i<=l;++i) p.f[i].clear();
	for (int i=r+1;i<n*2;++i) p.f[i].clear();
}
Func g[N];
void calc(ll L,ll R){
	if (L==R)
		return;
	g[1].clear();
	for (int i=re[1*2-1]+1;i<=re[1*2];++i)
		g[1].f[i].w[0].v[0]=1;
	for (int i=2;i<=n;++i){
		getint(g[i],g[i-1],re[i*2-1],re[i*2]);
		ll tmp=qpow(a[i-1]-a[i-2]);
		for (int j=1;j<n*2;++j)
			g[i].f[j]=g[i].f[j]*tmp;
	}
	getint(g[n+1],g[n],0,n*2-1);
	Num tmp=g[n+1].f[n*2-1].y(li[n*2-1])*qpow(a[n]-a[n-1]);
	ll pL=1,pR=1;
	for (int i=0;i<=n;++i){
		pL=pL*L%mo;
		pR=pR*R%mo;
		(ans+=tmp.v[i]*inv[i+1]%mo*(pR-pL+mo))%=mo;
	}
	ans=(ans+mo)%mo;
}
bool cmpo(pair<int,int> x,pair<int,int> y){
	x={re[x.fi],re[x.se]};
	y={re[y.fi],re[y.se]};
	return -(ll)(li[x.fi].b-li[x.se].b)*(li[y.fi].k-li[y.se].k)<-(ll)(li[y.fi].b-li[y.se].b)*(li[x.fi].k-li[x.se].k);
}
void doit(){
	li[0]={a[0],0},re[1*2-1]=0;
	for (int i=1;i<n;++i){
		li[i*2-1]={a[i],-i},re[(i+1)*2-1]=i*2-1;
		li[i*2]={a[i],-(i-1)},re[i*2]=i*2;
	}
	li[n*2-1]={a[n],-(n-1)},re[n*2]=n*2-1;
	static pair<int,int> o[N*N*4];
	int k=0;
	for (int i=1;i<=n*2;++i)
		for (int j=1;j<=n*2;++j)
			if (li[re[i]].k<li[re[j]].k && li[re[i]].b>li[re[j]].b)
				o[++k]=mp(j,i);
	sort(o+1,o+k+1,cmpo);
	ll lst=0;
	for (int i=1;i<=k;++i){
		int x=re[o[i].fi],y=re[o[i].se];
		if (x>y) continue;
		ll tmp=(-(li[x].b-li[y].b)*qpow(li[x].k-li[y].k)%mo+mo)%mo;
		calc(lst,tmp);
		swap(li[x],li[y]);
		swap(re[o[i].fi],re[o[i].se]);
		lst=tmp;
	}
}
int main(){
//	freopen("in.txt","r",stdin);
	scanf("%d",&n);
	for (int i=0;i<=n;++i)
		scanf("%d",&a[i]);
	inv[1]=1;
	for (int i=2;i<=n+1;++i)
		inv[i]=(mo-mo/i)*inv[mo%i]%mo;
	doit();
	ans=(ans+mo)%mo;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2021-02-25 19:12  jz_597  阅读(213)  评论(0编辑  收藏  举报