题解 A. Equipment Upgrade 2022“杭电杯”中国大学生算法设计超级联赛(3)
cdq 分治 FFT 优化概率 dp 的好题!
【大意】
一个等级初始为 \(0\) 的武器。当其等级为 \(i\) 时,升级到 \(i+1\) 的一次开销为 \(c_i\) ;但成功概率为 \(p_i\) 。若失败,武器从等级 \(i\) 掉入等级 \(i-j\) 的概率为 \(\displaystyle {w_j\over \sum_{k=1}^i w_k}\) 。求升级到 \(n\) 的期望开销。
【分析】
非常显然的概率 dp ,但如果和其他概率 dp 一样,考虑当前状态到目标状态的期望开销,则在状态 \(g_i\) 中会含有 \(g_j(j<i)\) 项,不利于求解。
考虑到前一场杭电多校的题目,我们设 \(g_n\) 表示 \(0\) 第一次升到 \(n\) 级的期望开销。
则可以列出转移方程:\(\displaystyle g_{i+1}=g_i+c_i+(1-p_i)\sum_{j=1}^i(g_{i+1}-g_{i-j})\cdot {w_j\over \sum_{k=1}^i w_k}\)
方便起见,我们后面用 \(\displaystyle sumw_k=\sum_{i=1}^k w_i\) 来表示分母
由于 \(\displaystyle \sum_{j=1}^i g_{i+1}\cdot {w_j\over sumw_i}=g_{i+1}\) ,整理上式得到:
\(\displaystyle g_{i+1}={g_i+c_i\over p_i}-{1-p_i\over p_i\cdot sumw_i}\sum_{j=1}^i w_j\cdot g_{i-j}\)
后面这个形式是个非常显然的加法卷积形式,因此可以用 cdq 分治 FFT 在 \(O(n\log^2 n)\) 的时间内求解,仅需要在区间长度为 \(1\) 时,特判乘上 \(-{1-p_i\over p_i\cdot sumw_i}\) ,并加上 \({g_i+c_i\over p_i}\) 作为结果
然而,左侧的系数为 \(g_{i+1}\) ,并不是 \(g_i\) ,不能直接计算
可以考虑令 \(\displaystyle h_i=g_{i+1}={g_i+c_i\over p_i}-{1-p_i\over p_i\cdot sumw_i}\sum_{j=1}^i w_j\cdot g_{i-j}\)
那么,递归到区间长度为 \(1\) 时, \(\displaystyle \sum_{j=1}^i w_j\cdot g_{i-j}\) 的答案记录在 \(g_i\) 上
于是,我们可以先用 \(g_i\) 更新 \(h_i\) : \(\displaystyle h_i\leftrightarrow {h_{i-1}+c_i\over p_i}-{1-p_i\over p_i\cdot sumw_i}\cdot g_i\)
随后,\(g_i\) 的值需要更新为其正确的结果,即 \(\displaystyle g_i\leftarrow h_{i-1}\)
由于每次是先计算 \(g_i\) 表示的卷积结果,再用卷积结果更新 \(h_i\) ,最后再用 \(h_{i-1}\) 更新 \(g_i\) ;一直是用已知的值更新未知的值,故 cdq 分治 FFT 的正确性可以保证
【代码】
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define sz(a) (int)a.size()
#define de(x) cout << #x <<" = "<<(x)<<endl
#define dd(x) cout << #x <<" = "<<(x)<<" "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define lc(x) ((x)<<1)
#define rc(x) ((x)<<1|1)
#define rsz(a, x) a.resize(x)
const int P=998244353, MAXN=1e5+10;
inline int kpow(int a, int x, int p=P) { int ans=1; for(;x;x>>=1, a=(ll)a*a%p) if(x&1) ans=(ll)ans*a%p; return ans; }
inline int exgcd(int a, int b, int &x, int &y) {
static int g;
return b?(exgcd(b, a%b, y, x), y-=a/b*x, g):(x=1, y=0, g=a);
}
inline int inv(int a, int p=P) {
static int x, y;
return exgcd(a, p, x, y)==1?(x<0?x+p:x):(-1);
}
const int LimBit=18, M=1<<LimBit<<1;
namespace Poly {
const int G=3;
struct vir {
int v;
vir(int v_=0):v(v_>=P?v_-P:v_) {}
inline vir operator + (const vir &x) const { return vir(v+x.v); }
inline vir operator - (const vir &x) const { return vir(v+P-x.v); }
inline vir operator * (const vir &x) const { return vir((ll)v*x.v%P); }
inline vir operator - () const { return vir(P-v); }
inline vir operator ! () const { return vir(inv(v)); }
inline operator int() const { return v; }
};
struct poly : public vector<vir> {
inline friend ostream& operator << (ostream& out, const poly &p) {
if(!p.empty()) out<<(int)p[0];
for(int i=1; i<sz(p); ++i) out<<" "<<(int)p[i];
return out;
}
};
int N, N_, Stk[M], curStk, rev[M];
vir invN, Inv[M], w[2][M];
inline void init() {
N_=-1;
curStk=0;
Inv[1]=1;
for(int i=2; i<M; ++i)
Inv[i]=-vir(P/i)*Inv[P%i];
}
void work() {
if(N_==N) return ;
N_=N;
int d=__builtin_ctz(N);
vir x(kpow(G, (P-1)/N)), y=!x;
w[0][0]=w[1][0]=1;
for(int i=1; i<N; ++i) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<(d-1));
w[0][i]=x*w[0][i-1], w[1][i]=y*w[1][i-1];
}
invN=!vir(N);
}
inline void FFT(vir a[M], int f) {
static auto make = [=](vir w, vir &a, vir &b) { w=w*a; a=b-w; b=b+w; };
for(int i=0; i<N; ++i) if(i<rev[i]) swap(a[i], a[rev[i]]);
for(int i=1; i<N; i<<=1)
for(int j=0, t=N/(i<<1); j<N; j+=i<<1)
for(int k=0, l=0; k<i; ++k, l+=t)
make(w[f][l], a[j+k+i], a[j+k]);
if(f) for(int i=0; i<N; ++i) a[i]=a[i]*invN;
}
vir p1[M], p0[M];
inline void get_mul(poly &a, poly &b, int na, int nb) {
for(N=1; N<na+nb-1; N<<=1);
for(int i=0; i<na; ++i) p1[i]=(int)a[i]; for(int i=na; i<N; ++i) p1[i]=0;
for(int i=0; i<nb; ++i) p0[i]=(int)b[i]; for(int i=nb; i<N; ++i) p0[i]=0;
work(); FFT(p1, 0); FFT(p0, 0);
for(int i=0; i<N; ++i) p1[i]=p1[i]*p0[i];
FFT(p1, 1);
rsz(a, na+nb-1); for(int i=0; i<sz(a); ++i) a[i]=p1[i];
}
}
using Poly::poly;
using Poly::vir;
poly a, f, w, tmp;
int n;
vir p[MAXN], c[MAXN], sumw;
void work_cdq(poly &f, poly &g, poly &h, int l, int r) {
if(l>r) return ;
if(l==r) {
if(!l) return ;
sumw=sumw+w[l];
vir fi=h[l-1], ci=c[l], pi=p[l];
h[l]=(fi+ci)*(!pi)-(vir(1)-pi)*(!(pi*sumw))*g[l];
g[l]=fi;
return ;
}
int mid=l+r>>1;
work_cdq(f, g, h, l, mid);
rsz(a, mid-l+1);
for(int i=0, j=l; j<=mid; ++i, ++j) a[i]=g[j];
get_mul(a, f, sz(a), r-l+1);
for(int i=mid+1, j=i-l; i<=r&&j<=sz(a); ++i, ++j)
g[i]=g[i]+a[j];
work_cdq(f, g, h, mid+1, r);
}
inline void get_cdq(poly &f, poly &g, poly &h, int n, int g0=1) {
sumw=0;
g.clear(); rsz(g, n); g[0]=g0;
h.clear(); rsz(h, n); h[0]=c[0];
work_cdq(f, g, h, 0, n-1);
}
inline int ans() {
static vir inv100=!vir(100);
cin>>n;
for(int i=0, v; i<n; ++i) {
cin>>v;
p[i]=vir(v)*inv100;
cin>>v;
c[i]=v;
}
rsz(w, n+1);
w[0]=w[n]=0;
for(int i=1, v; i<n; ++i) {
cin>>v;
w[i]=v;
}
get_cdq(w, f, tmp, n+1, 0);
return f[n];
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int t; cin>>t;
while(t--)
cout<<ans()<<"\n";
cout.flush();
return 0;
}