【CF1286D】LCC
题目
题目链接:https://codeforces.com/problemset/problem/1286/D
一条无限长的管道中有 \(n\) 个粒子,第 \(i\) 个粒子的位置为 \(x_i\),保证对于 \(1\leq i <n\),有 \(x_i<x_{i+1}\)。
一次实验开始时,第 \(i\) 个粒子会获得 \(v_i\) 的初速度,并有 \(\frac{p_i}{100}\) 的概率向右移动,\(1-\frac{p_i}{100}\) 的概率向左移动。
当任意两个粒子移动到相同位置的时候,我们称之为一次碰撞。一次实验耗费的时间为所有碰撞发生时间的最小值。特别地,如果没有发生任何碰撞,我们认为这次实验耗费的时间为 \(0\)。
求一次实验期望耗费的时间,对 \(998244353\) 取模后输出。
\(n\leq 10^5\)。
思路
第一次碰撞显然发生在相邻的两个粒子之间。并且至多有 \(2(n-1)\) 种不同的情况。两种情况不同当且仅当碰撞的粒子不同或者两个碰撞粒子中存在一个方向不同。
考虑分别计算这 \(2(n-1)\) 种情况作为第一次碰撞的期望。把所有情况按照碰撞时间排序,那么对于第 \(i\) 种情况,我们需要计算的是前 \(i-1\) 种情况都不出现,且第 \(i\) 中情况恰好出现的期望。
记 \(g[i][0/1][0/1]\) 表示粒子 \(i-1\) 往左/右,粒子 \(i\) 往左/右这种情况能不能出现。那么如果我们钦定第 \(i\) 种情况是第一次出现的碰撞,那么前 \(i-1\) 中情况分别对应了一个 \(g\) 为 \(0\);且第 \(i\) 种情况所对应的两个粒子之间另外三种情况都为 \(0\)。
设 \(f[i][0/1]\) 表示选到了第 \(i\) 个粒子的方向为左/右,在前 \(i\) 个粒子方向满足所有 \(g\) 的条件的前提下的期望。
转移为
直接 dp 需要对至多 \(2(n-1)\) 种情况都跑一次,复杂度是 \(O(n^2)\) 的。观察到从第 \(i\) 种情况到第 \(i+1\) 种情况,\(g\) 改变的数量是 \(O(1)\) 的,所以用动态 dp 搞一下就可以了。
时间复杂度 \(O(k^3n\log n)\),其中 \(k=2\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100010,MOD=998244353,INV=828542813;
int n,m,a[N],v[N],p[N];
ll ans;
struct node
{
int x,y,z,s,t;
}b[N*2];
bool cmp(node x,node y)
{
return 1.0*x.s/x.t<1.0*y.s/y.t;
}
ll fpow(ll x,ll k)
{
ll ans=1;
for (;k;k>>=1,x=x*x%MOD)
if (k&1) ans=ans*x%MOD;
return ans;
}
struct Matrix
{
ll a[3][3];
Matrix() { memset(a,0,sizeof(a)); }
friend Matrix operator *(Matrix a,Matrix b)
{
Matrix c;
for (int i=1;i<=2;i++)
for (int j=1;j<=2;j++)
for (int k=1;k<=2;k++)
c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j])%MOD;
return c;
}
}g;
struct SegTree
{
Matrix f[N*4];
void build(int x,int l,int r)
{
if (l==r)
{
f[x].a[1][1]=1LL*(100-p[l])*INV%MOD; f[x].a[1][2]=1LL*p[l]*INV%MOD;
if (l>1) f[x].a[2][1]=1LL*(100-p[l])*INV%MOD,f[x].a[2][2]=1LL*p[l]*INV%MOD;
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
f[x]=f[x*2]*f[x*2+1];
}
void update1(int x,int l,int r,int k,int y,int z)
{
if (l==r) { f[x].a[y][z]=0; return; }
int mid=(l+r)>>1;
if (k<=mid) update1(x*2,l,mid,k,y,z);
if (k>mid) update1(x*2+1,mid+1,r,k,y,z);
f[x]=f[x*2]*f[x*2+1];
}
void update2(int x,int l,int r,int k)
{
if (l==r) { swap(f[x],g); return; }
int mid=(l+r)>>1;
if (k<=mid) update2(x*2,l,mid,k);
if (k>mid) update2(x*2+1,mid+1,r,k);
f[x]=f[x*2]*f[x*2+1];
}
}seg;
int main()
{
scanf("%d%d%d%d",&n,&a[1],&v[1],&p[1]);
for (int i=2;i<=n;i++)
{
scanf("%d%d%d",&a[i],&v[i],&p[i]);
b[++m]=(node){i,2,1,a[i]-a[i-1],v[i]+v[i-1]};
if (v[i]>v[i-1]) b[++m]=(node){i,1,1,a[i]-a[i-1],v[i]-v[i-1]};
if (v[i]<v[i-1]) b[++m]=(node){i,2,2,a[i]-a[i-1],v[i-1]-v[i]};
}
sort(b+1,b+1+m,cmp);
seg.build(1,1,n);
for (int i=1;i<=m;i++)
{
int x=b[i].x,y=b[i].y,z=b[i].z,s=b[i].s,t=b[i].t;
memset(g.a,0,sizeof(g.a));
g.a[y][z]=1LL*INV*((z==1)?(100-p[x]):p[x])%MOD;
seg.update2(1,1,n,x);
ans=(ans+(seg.f[1].a[1][1]+seg.f[1].a[1][2])*s%MOD*fpow(t,MOD-2))%MOD;
seg.update2(1,1,n,x);
seg.update1(1,1,n,x,y,z);
}
cout<<ans;
return 0;
}