超简单(super)
超简单(super)
题目描述
有一个n面的骰子,第i面的数是vi,朝上的概率是pi。
教室的最后一排有一个人,不停地抛这个骰子,直到某一面朝上了两次,就停止抛骰子,但他不知道所有朝上的面的数字的和的期望E是多少。
老班一脸嘲讽:“这不是超简单嘛。”
输入
输入的第一行包含一个正整数n。
输入的第二行包含n个正整数,表示vi。
输入的第三行包含n个非负整数,表示模998244353意义下的pi,保证所有pi的和为1。
n,vi,pi的含义见问题描述。
输出
输出一行一个非负整数E表示模998244353意义下的E。
样例输入
<span style="color:#333333"><span style="color:#333333">【样例输入】
2
1 2
332748118 665496236
</span></span>
样例输出
<span style="color:#333333"><span style="color:#333333">【样例输出】
961272344
</span></span>
提示
【样例说明】
骰子共有2个面。
第一面的数为1,朝上的概率为1/3;
第二面的数为2,朝上的概率为2/3。
所有情况列举如下:
第1次朝上的面 |
第2次朝上的面 |
第3次朝上的面 |
朝上的面的和 |
概率 |
1 |
1 |
/ |
2 |
1/9 |
1 |
2 |
1 |
4 |
2/27 |
1 |
2 |
2 |
5 |
4/27 |
2 |
1 |
1 |
4 |
2/27 |
2 |
1 |
2 |
5 |
4/27 |
2 |
2 |
/ |
4 |
4/9 |
所以E=2*1/9+4*2/27+5*4/27+4*2/27+5*4/27+4*4/9=110/27。
【子任务】
测试点 |
n |
vi,pi |
1~4 |
≤8 |
<998244353 |
5~8 |
≤50 |
|
9~12 |
≤100 |
|
13~20 |
≤500 |
solution
期望dp
令f[i][j]表示前i张牌选j张得期望
若f[i][j]=PS,新加入i点,那么新的期望为P*pi*(S+vi)
展开得到PS*pi+P*pi*vi
于是我们还需维护期望的和g[i][j]=P转移式有了
f[i][j]=f[i-1][j]+f[i-1][j-1]*pi+g[i-1][j-1]*vi
我们可以枚举哪一位为出现两次的
效率O(n^3)
jyc神犇有优化
因为这个dp与数的顺序无关(不同顺序丢进去出来的是一个结果)
我们可以把最后一维当成我要禁掉的
倒推出n-1维
效率O(n^2)
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 505
#define ll long long
#define mod 998244353
using namespace std;
int n;
ll v[maxn],p[maxn],f[maxn][maxn],g[maxn][maxn],ans,h[maxn];
int main(){
cin>>n;
for(int i=1;i<=n;i++)scanf("%lld",&v[i]);
for(int i=1;i<=n;i++)scanf("%lld",&p[i]);
h[0]=1;
for(int i=1;i<=n;i++)h[i]=(h[i-1]*i)%mod;
for(int i=0;i<=n;i++)g[i][0]=1;
for(int i=1;i<=n;i++)
for(int j=1;j<=i;j++){
f[i][j]=f[i-1][j]+(f[i-1][j-1]*p[i])%mod+((g[i-1][j-1]*p[i])%mod*v[i])%mod;
f[i][j]%=mod;
g[i][j]=g[i-1][j]%mod+g[i-1][j-1]*p[i]%mod;
}
for(int b=1;b<=n;b++){
int vb=v[b],pb=p[b];
for(int j=1;j<=n;j++){
f[n-1][j]=f[n][j]-(f[n-1][j-1]*pb)%mod-((g[n-1][j-1]*pb)%mod*vb)%mod;
f[n][j]%=mod;
g[n-1][j]=g[n][j]-g[n-1][j-1]*pb%mod;
}
n--;
for(int i=0;i<=n;i++){
ll tmp=(f[n][i]*pb)%mod*pb;tmp%=mod;
tmp=tmp+g[n][i]*pb%mod*pb%mod*2*vb%mod;
ans=ans+(tmp*h[i+1])%mod;ans%=mod;
}
n++;
}
ans=(ans%mod+mod)%mod;
cout<<ans<<endl;
return 0;
}