题解 工业题
感谢这题题面提醒我我工业的高精还不会打,先咕着(逃
看完题解之后发现的确不难但考场上就是想不出正解
总觉得是个矩阵快速幂,虽然明知道光输入的数组放矩阵里跑\(n^3\)就炸上天了
有想过计算每个\(f_i,0\)对结果的贡献,但觉得应该没这么麻烦,还觉得那样像计数dp,就没往那边想
其实因为每次都是加和,每个\(f_{i,0}\)和\(f_{0,i}\)的贡献其实互不影响
那就可以分开考虑到达 \(f_{n,m}\) 的路径数和a,b的次幂,则贡献t为
\[t=f_{i,0}*\frac{(n-i+m-1)!}{(n-i)!(m-1)!}*a^{m}*b^{n-i}
\]
\[t=f_{0,i}*\frac{(n-1+m-i)!}{(n-1)!(m-i)!}*a^{m-i}*b^{n}
\]
特别注意这里分子上的减1,被卡了巨久
是因为从\(f_{i,0}\)出发,第一步必须向右走,所以这一步的方案是固定的,\(f_{0,i}\)同理
- 还有记得读入取模一定要检查取模取全了没有! 只要有忘取模的就炸锅了
卡细节
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 600010
#define ll long long
#define ld long double
#define usd unsigned
#define ull unsigned long long
//#define int long long
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
char buf[1<<21], *p1=buf, *p2=buf;
inline ll read() {
ll ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, m;
ll a, b, ans;
const ll p=998244353;
namespace force{
ll f[2010][2010];
void solve() {
for (int i=1; i<=n; ++i) f[i][0]=read()%p;
for (int i=1; i<=m; ++i) f[0][i]=read()%p;
for (int i=1; i<=n; ++i)
for (int j=1; j<=m; ++j)
f[i][j] = (f[i][j-1]*a%p+f[i-1][j]*b%p)%p;
printf("%lld\n", f[n][m]);
//for (int i=0; i<=n; ++i) {for (int j=0; j<=m; ++j) cout<<setw(4)<<f[i][j]<<' '; cout<<endl;}
exit(0);
}
}
namespace task1{
ll f[3][N];
void solve() {
for (int i=1; i<=n; ++i) f[i][0]=read()%p;
for (int i=1; i<=m; ++i) f[0][i]=read()%p;
for (int i=1; i<=n; ++i)
for (int j=1; j<=m; ++j)
f[i][j] = (f[i][j-1]*a%p+f[i-1][j]*b%p)%p;
printf("%lld\n", f[n][m]);
//for (int i=0; i<=n; ++i) {for (int j=0; j<=m; ++j) cout<<setw(4)<<f[i][j]<<' '; cout<<endl;}
exit(0);
}
}
namespace task{
ll fac[N], inv[N], pa[N], pb[N], t;
inline ll C(ll n, ll k) {return fac[n]*inv[n-k]%p*inv[k]%p;}
inline ll calc(ll n, ll m) {return fac[n+m]*inv[n]%p*inv[m]%p;}
void solve() {
int lim=2*max(n, m);
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=2; i<=lim; ++i) fac[i]=fac[i-1]*i%p;
for (int i=2; i<=lim; ++i) inv[i]=(p-p/i)*inv[p%i]%p;
for (int i=2; i<=lim; ++i) inv[i]=inv[i]*inv[i-1]%p;
pa[0]=pb[0]=1;
for (int i=1; i<=lim; ++i) pa[i]=pa[i-1]*a%p;
for (int i=1; i<=lim; ++i) pb[i]=pb[i-1]*b%p;
for (int i=1; i<=n; ++i) {t=read()%p; ans=(ans+t*calc(n-i, m-1)%p*pb[n-i]%p*pa[m]%p)%p;}
for (int i=1; i<=m; ++i) {t=read()%p; ans=(ans+t*calc(n-1, m-i)%p*pb[n]%p*pa[m-i]%p)%p;}
printf("%lld\n", ans%p);
exit(0);
}
}
signed main()
{
#ifdef DEBUG
freopen("1.in", "r", stdin);
#endif
n=read(); m=read(); a=read()%p; b=read()%p;
//if (n>1) force::solve();
//else task1::solve();
task::solve();
return 0;
}