CF1608F MEX counting
为什么一种dp状态不能优化另一种dp状态就可以了啊。
这题首先我想的是设\(f_{p,i,j}\)表示到了第\(p\)个位置,当前mex为\(j\),还有\(i\)个点大于\(j\)的方案数,每次转移分类讨论:
1.若当前不改变mex,且当前点大于mex,则\(f_{p,i,j}\to f_{p+1,i+1,j}\)
2.若当前不改变mex,且当前点小于mex,则\(f_{p,i,j}\times j\to f_{p+1,i,j}\)
3.若当前点改变mex,则枚举mex到的点\(x\)与在这区间内的点的数量\(y\),有\(f_{p,i,j}\times C_{i}^{y}\times g_{x-i-1,y}\to f_{p+1,i-y,x}\),其中\(g_{i,j}(j\leq i)\)表示在一段长度为\(j\)的序列上放\(j\)种颜色的球,要求每个位置只能放一个且每种颜色至少一个,\(g_{i,j}\)可以用容斥求出:枚举没有颜色的球的个数,有\(g_{i,j}=\sum\limits_{k=0}^{j}{(-1)^kC_{j}^{k}(j-k)^i}\)。
然后你发现这个东西复杂度至少是\(O(n^5)\)的而且很难拆系数。如果往这个方向硬推不知道能不能推出来。不过大概率像我一样死定了。
在这个过程中我们独立了\(f\)与\(g\)的计算,我们考虑能不能让它看上去和谐一点。
回到\(g\)的定义,根据这个定义我们可以用一个dp求\(g\),设\(dp_{i,j}\)表示到了序列上第\(i\)个位置,已经有了\(j\)种颜色,分类讨论当前是否新开一种颜色即可。
如果我们这样计算相当于强行拆出了一个中间变量序列长度,我们考虑将两个状态融合一下。
重设\(f_{p,i,j}\)表示到了第\(p\)个位置,mex为\(j\),当前有\(i\)种数的方案数,同样分三类讨论:
1.若当前不新开一个数,则\(f_{p,i,j}\times i\to f_{p+1,i,j}\)。
2.若当前新开一个数,但是没有开到mex的位置,则\(f_{p,i,j}\to f_{p,i+1,j}\)。
3.若当前新开了一个数,且等于mex,则枚举mex要到的位置\(x\),有\(f_{p,i,j}\times A_{i-j}^{x-j}\to f_{p+1,i+1,x}\)
这样看上去是\(O(n^4)\)的且非常简洁。
首先我们发现mex这一维是\(O(k)\)的,因此复杂度其实是\(O(n^2k^2)\)的。
而后把那个排列数拆了,发现分别与\(j\)和\(x\)独立,因此可以做出\(x\)的前缀和,就可以将转移优化成\(O(1)\)的。
时间复杂度\(O(n^2k)\)
code:
#include<bits/stdc++.h>
#define Gc() getchar()
#define Me(x,y) memset(x,y,sizeof(x))
#define Mc(x,y) memcpy(x,y,sizeof(x))
#define R(n) (rnd()%(n))
#define Pc(x) putchar(x)
#define LB lower_bound
#define UB upper_bound
#define PB push_back
using ll=long long;using db=double;using lb=long db;using ui=unsigned;using ull=unsigned ll;using u128=__int128;
using namespace std;const int N=2e3+5,M=4e3+5,K=2e6+5,mod=998244353,Mod=mod-1;ll INF=1e18+7;const db eps=1e-5;mt19937 rnd(time(0));
int n,k,B[N],L[N],R[N];ll f[N][N],g[N][N],C[N][N],frc[N],Inv[N],Q[N],Ans,Ts;
ll mpow(ll x,int y=mod-2){ll Ans=1;while(y) y&1&&(Ans=Ans*x%mod),y>>=1,x=x*x%mod;return Ans;}
ll A(int x,int y){return frc[x]*Inv[x-y]%mod;}
int main(){
freopen("1.in","r",stdin);
int i,j,h,x,y;scanf("%d%d",&n,&k);for(i=1;i<=n;i++) scanf("%d",&B[i]),L[i]=max(0,B[i]-k),R[i]=min(n,B[i]+k);
for(i=0;i<=n;i++) for(C[i][0]=j=1;j<=i;j++) C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;for(frc[0]=Inv[0]=i=1;i<=n;i++) frc[i]=frc[i-1]*i%mod,Inv[i]=mpow(frc[i]);
f[0][0]=1;for(i=1;i<=n;i++){
for(j=0;j<=i;j++) for(h=L[i-1];h<=R[i-1];h++) g[j][h]=f[j][h],f[j][h]=0;
for(j=1;j<=i;j++){
L[i-1]&&(Q[L[i-1]-1]=0);for(h=L[i-1];h<=min(R[i-1],j-1);h++) Q[h]=((h?Q[h-1]:0)+g[j-1][h]*frc[j-h-1])%mod;
for(h=L[i];h<=min(R[i],j);h++){
f[j][h]=(g[j][h]*j+g[j-1][h])%mod;Ts=0;
//for(x=L[i-1];x<=min(R[i-1],h-1);x++) Ts=(Ts+g[j-1][x]*frc[j-x-1])%mod;
//if(Ts^Q[min(R[i-1],h-1)]) printf("%d %d\n",i,j);assert(Ts==Q[min(R[i-1],h-1)]);
f[j][h]=(f[j][h]+(min(R[i-1],h-1)<L[i-1]?0:Q[min(R[i-1],h-1)])*Inv[j-h])%mod;
}
}
for(j=0;j<=i;j++) for(h=L[i-1];h<=R[i-1];h++) g[j][h]=0;
}for(i=L[n];i<=R[n];i++) for(j=i;j<=n;j++) Ans=(Ans+f[j][i]*A(n-i,j-i))%mod;printf("%lld\n",Ans%mod);
}