题解 A. Equipment Upgrade 2022“杭电杯”中国大学生算法设计超级联赛(3)

传送门

cdq 分治 FFT 优化概率 dp 的好题!


【大意】

一个等级初始为 \(0\) 的武器。当其等级为 \(i\) 时,升级到 \(i+1\) 的一次开销为 \(c_i\) ;但成功概率为 \(p_i\) 。若失败,武器从等级 \(i\) 掉入等级 \(i-j\) 的概率为 \(\displaystyle {w_j\over \sum_{k=1}^i w_k}\) 。求升级到 \(n\) 的期望开销。


【分析】

非常显然的概率 dp ,但如果和其他概率 dp 一样,考虑当前状态到目标状态的期望开销,则在状态 \(g_i\) 中会含有 \(g_j(j<i)\) 项,不利于求解。

考虑到前一场杭电多校的题目,我们设 \(g_n\) 表示 \(0\) 第一次升到 \(n\) 级的期望开销。

则可以列出转移方程:\(\displaystyle g_{i+1}=g_i+c_i+(1-p_i)\sum_{j=1}^i(g_{i+1}-g_{i-j})\cdot {w_j\over \sum_{k=1}^i w_k}\)

方便起见,我们后面用 \(\displaystyle sumw_k=\sum_{i=1}^k w_i\) 来表示分母

由于 \(\displaystyle \sum_{j=1}^i g_{i+1}\cdot {w_j\over sumw_i}=g_{i+1}\) ,整理上式得到:

\(\displaystyle g_{i+1}={g_i+c_i\over p_i}-{1-p_i\over p_i\cdot sumw_i}\sum_{j=1}^i w_j\cdot g_{i-j}\)

后面这个形式是个非常显然的加法卷积形式,因此可以用 cdq 分治 FFT 在 \(O(n\log^2 n)\) 的时间内求解,仅需要在区间长度为 \(1\) 时,特判乘上 \(-{1-p_i\over p_i\cdot sumw_i}\) ,并加上 \({g_i+c_i\over p_i}\) 作为结果

然而,左侧的系数为 \(g_{i+1}\) ,并不是 \(g_i\) ,不能直接计算

可以考虑令 \(\displaystyle h_i=g_{i+1}={g_i+c_i\over p_i}-{1-p_i\over p_i\cdot sumw_i}\sum_{j=1}^i w_j\cdot g_{i-j}\)

那么,递归到区间长度为 \(1\) 时, \(\displaystyle \sum_{j=1}^i w_j\cdot g_{i-j}\) 的答案记录在 \(g_i\)

于是,我们可以先用 \(g_i\) 更新 \(h_i\)\(\displaystyle h_i\leftrightarrow {h_{i-1}+c_i\over p_i}-{1-p_i\over p_i\cdot sumw_i}\cdot g_i\)

随后,\(g_i\) 的值需要更新为其正确的结果,即 \(\displaystyle g_i\leftarrow h_{i-1}\)

由于每次是先计算 \(g_i\) 表示的卷积结果,再用卷积结果更新 \(h_i\) ,最后再用 \(h_{i-1}\) 更新 \(g_i\) ;一直是用已知的值更新未知的值,故 cdq 分治 FFT 的正确性可以保证


【代码】

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define sz(a) (int)a.size()
#define de(x) cout << #x <<" = "<<(x)<<endl
#define dd(x) cout << #x <<" = "<<(x)<<" "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define lc(x) ((x)<<1)
#define rc(x) ((x)<<1|1)
#define rsz(a, x) a.resize(x)
const int P=998244353, MAXN=1e5+10;
inline int kpow(int a, int x, int p=P) { int ans=1; for(;x;x>>=1, a=(ll)a*a%p) if(x&1) ans=(ll)ans*a%p; return ans; }
inline int exgcd(int a, int b, int &x, int &y) {
	static int g;
	return b?(exgcd(b, a%b, y, x), y-=a/b*x, g):(x=1, y=0, g=a);
}
inline int inv(int a, int p=P) {
	static int x, y;
	return exgcd(a, p, x, y)==1?(x<0?x+p:x):(-1);
}

const int LimBit=18, M=1<<LimBit<<1;
namespace Poly {
	const int G=3;
	struct vir {
		int v;
		vir(int v_=0):v(v_>=P?v_-P:v_) {}
		inline vir operator + (const vir &x) const { return vir(v+x.v); }
		inline vir operator - (const vir &x) const { return vir(v+P-x.v); }
		inline vir operator * (const vir &x) const { return vir((ll)v*x.v%P); }
		
		inline vir operator - () const { return vir(P-v); }
		inline vir operator ! () const { return vir(inv(v)); }
		inline operator int() const { return v; }
	};
	struct poly : public vector<vir> {
		inline friend ostream& operator << (ostream& out, const poly &p) {
			if(!p.empty()) out<<(int)p[0];
			for(int i=1; i<sz(p); ++i) out<<" "<<(int)p[i];
			return out;
		}
	};
	
	int N, N_, Stk[M], curStk, rev[M];
	vir invN, Inv[M], w[2][M];
	inline void init() {
		N_=-1;
		curStk=0;
		Inv[1]=1;
		for(int i=2; i<M; ++i)
			Inv[i]=-vir(P/i)*Inv[P%i];
	}
	void work() {
		if(N_==N) return ;
		N_=N;
		int d=__builtin_ctz(N);
		vir x(kpow(G, (P-1)/N)), y=!x;
		w[0][0]=w[1][0]=1;
		for(int i=1; i<N; ++i) {
			rev[i]=(rev[i>>1]>>1)|((i&1)<<(d-1));
			w[0][i]=x*w[0][i-1], w[1][i]=y*w[1][i-1];
		}
		invN=!vir(N);
	}
	inline void FFT(vir a[M], int f) {
		static auto make = [=](vir w, vir &a, vir &b) { w=w*a; a=b-w; b=b+w; };
		for(int i=0; i<N; ++i) if(i<rev[i]) swap(a[i], a[rev[i]]);
		for(int i=1; i<N; i<<=1)
			for(int j=0, t=N/(i<<1); j<N; j+=i<<1)
				for(int k=0, l=0; k<i; ++k, l+=t)
					make(w[f][l], a[j+k+i], a[j+k]);
		if(f) for(int i=0; i<N; ++i) a[i]=a[i]*invN;
	}
	
	vir p1[M], p0[M];
	inline void get_mul(poly &a, poly &b, int na, int nb) {
		for(N=1; N<na+nb-1; N<<=1);
		for(int i=0; i<na; ++i) p1[i]=(int)a[i]; for(int i=na; i<N; ++i) p1[i]=0;
		for(int i=0; i<nb; ++i) p0[i]=(int)b[i]; for(int i=nb; i<N; ++i) p0[i]=0;
		work(); FFT(p1, 0); FFT(p0, 0);
		for(int i=0; i<N; ++i) p1[i]=p1[i]*p0[i];
		FFT(p1, 1);
		rsz(a, na+nb-1); for(int i=0; i<sz(a); ++i) a[i]=p1[i];
	}
}
using Poly::poly;
using Poly::vir;
poly a, f, w, tmp;
int n;
vir p[MAXN], c[MAXN], sumw;

void work_cdq(poly &f, poly &g, poly &h, int l, int r) {
	if(l>r) return ;
	if(l==r) {
		if(!l) return ;
		sumw=sumw+w[l];
		vir fi=h[l-1], ci=c[l], pi=p[l];
		h[l]=(fi+ci)*(!pi)-(vir(1)-pi)*(!(pi*sumw))*g[l];
		g[l]=fi;
		return ;
	}
	int mid=l+r>>1;
	work_cdq(f, g, h, l, mid);
	rsz(a, mid-l+1);
	for(int i=0, j=l; j<=mid; ++i, ++j) a[i]=g[j];
	get_mul(a, f, sz(a), r-l+1);
	for(int i=mid+1, j=i-l; i<=r&&j<=sz(a); ++i, ++j)
		g[i]=g[i]+a[j];
	work_cdq(f, g, h, mid+1, r);
}
inline void get_cdq(poly &f, poly &g, poly &h, int n, int g0=1) {
	sumw=0;
	g.clear(); rsz(g, n); g[0]=g0;
	h.clear(); rsz(h, n); h[0]=c[0];
	work_cdq(f, g, h, 0, n-1);
}
inline int ans() {
	static vir inv100=!vir(100);
	cin>>n;
	for(int i=0, v; i<n; ++i) {
		cin>>v;
		p[i]=vir(v)*inv100;
		cin>>v;
		c[i]=v;
	}
	rsz(w, n+1);
	w[0]=w[n]=0;
	for(int i=1, v; i<n; ++i) {
		cin>>v;
		w[i]=v;
	}
	
	get_cdq(w, f, tmp, n+1, 0);
	return f[n];
}
int main() {
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	int t; cin>>t;
	while(t--)
		cout<<ans()<<"\n";
	cout.flush();
	return 0;
}
posted @ 2022-07-26 19:15  JustinRochester  阅读(116)  评论(0编辑  收藏  举报