题解 CF1654H Three Minimums
题意简述
给出长 \(m\) 的不等号序列 \(s\),求长度为 \(n\) 的排列 \(p\) 个数,使得:
-
对于所有 \(1 \leq l < r \leq n,r-l>1\) 且最小值与次小值在区间两端的区间,区间内第三小值为 \(p_{l+1}\) 或 \(p_{r-1}\)。
-
对于所有 \(t \in [1,m]\),若 \(s_t=<\),则 \(p_t<p_{t+1}\);否则 \(p_t>p_{t+1}\)。
答案对 \(998244353\) 取模。
\(n \leq 2 \times 10^5,m \leq 100\)。
题目分析
标记记号 \(check[i,<]=[i>m \vee s_i='<']\),\(check[i,<]=[i>m \vee s_i='>']\)。
考虑一个序列是否是好的:
-
若 \(1\) 不在两端,设 \(p_i=1\),则 \(p[1:n]\) 合法当且仅当 \(p[1:i]\) 与 \(p[i:n]\) 合法;
-
若 \(1\) 在两端(不妨设 \(p_1=1\)):
-
若 \(2\) 不在另一端,设 \(p_j=2\),则 \(p[1:n]\) 合法当且仅当 \(p[1:j]\) 与 \(p[j:n]\) 合法。
-
否则 \(p_n=2\),则 \(p[1:n]\) 合法当且仅当 「\(p_2=3\),\(p[2:n]\) 是好的」 或者 「\(p_{n-1}=3\),\(p[1:n-1]\) 是好的」。
-
于是设计一个区间 DP。设 \(a_{**}(l,r),a_{1*}(l,r),a_{*1}(l,r),a_{12}(l,r),a_{21}(l,r)\) 分别表示对于区间 \((l,r)\),满足给定条件的「区间两端无限制」,「区间左端点为最小值」,「区间右端点为最小值」,「区间左端点为最小值,右端点为次小值」,「区间左端点为次小值,右端点为最小值」的方案数。于是有:
暴力做时间复杂度 \(O(n^3)\),无法接受。
答案为 \(a_{**}(1,n)\),故要对所有 \(k \in [k,n]\) 计算出 \(a_{*1}(1,k)\) 及 \(a_{1*}(k,n)\)。
- Part 1:\(a_{*1}(1,k)\)
DP 式为 \(a_{*1}(1,1)=1\),\(a_{*1}(1,k)=\sum\limits_{i=1}^{k-1} a_{*1}(1,i) a_{21}(i,k) \dbinom{k-2}{i-1}\)。
观察到给左右两边的 \(a_{*1}\) 同乘阶乘会使形式简化许多。设 \(a_{*1}(1,k)=(k-1)!x_{k-1},x_0=1\),则有:
这个形式很好看,但是 \(a_{21}(i+1,k+1)\) 有两个变量而难以处理。
注意到限制只在前 \(m\) 位处,故对于 \(l>m\) 的状态,其 DP 值只与长度 \(r-l+1\) 有关。设 \(b_{??}(k)=a_{??}(m+1,m+k)\)。上式就可以改写为:
记 \(u_i=\dfrac{b_{21}(i+2)}{i!}\),\(v_{k-1}=\sum\limits_{i=0}^{\min(k-1,m-1)} x_i \dfrac{a_{21}(i+1,k+1)-b_{21}(k-i+1)}{(k-i-1)!}\)。则:
写成生成函数就是 \(X'=V+UX\),解这个微分方程得到 \(X=\exp(\int U)\Big[1+\int \Big( V \exp(-\int U) \Big) \Big]\)。
注意到 \(b_{21}(k)=2^{k-2}+[k=1]\),对所有 \(l \in [1,m+1]\),\(r \in (l,n]\) 暴力 DP 出 \(a_{21}(l,r)\) 即可得到 \(b_{21}(k)\),再对所有 \(t \in [1,m]\) 暴力 DP 得到 \(x_t\),即可计算 \(U\) 与 \(V\)。
- Part 2:\(a_{1*}(k,n)\)
注意到 \(k>m\) 时 \(a_{1*}(k,n)=b_{1*}(n-k+1)\),考虑先求出 \(b_{1*}(k)\),求出后可对 \(k \in [1,m]\) 暴力 DP 出 \(a_{1*}(k,n)\)。
DP 式为 \(a_{1*}(n,n)=1,a_{1*}(k,n)=\sum\limits_{i=k+1}^n a_{12}(k,i) a_{1*}(i,n) \dbinom{n-k-1}{i-k-1}\),即:
令 \(b_{1*}(k)=(k-1)!y_{k-1}\),代入:
于是有 \(Y'=UY\),解得 \(Y=\exp(\int U)\)。
总时间复杂度 \(O(n(\log n+m))\)。
代码
...
long long n,m,ans;
char S[N];
bool ck(long long pos,bool id){
if(pos>m) return 1;
return id==(S[pos]=='>');
}
long long a12[M][O],a21[M][O],a_1[N],a1_[N],b12[N],b21[N],b1_[N];
long long x[N],U[N],V[N],P[N],Q[N],F[N];
int main(){
init(N-5);
cin>>n>>m;
scanf("\n%s",S+1);
for(long long l=1;l<=n;l++){
for(long long i=1;i+l-1<=n && i<=m+1;i++){
long long j=i+l-1;
if(j==i+1) a12[i][i+1]=ck(i,0);
else if(j==i+2) a12[i][j]=ck(i,0)*ck(j-1,1);
else a12[i][j]=(ck(i,0)*((i==m+1)?pw[j-(i+1)-2]:a21[i+1][j])%mod+ck(j-1,1)*a12[i][j-1]%mod)%mod;
if(j==i+1) a21[i][i+1]=ck(i,1);
else if(j==i+2) a21[i][j]=ck(i,0)*ck(j-1,1);
else a21[i][j]=(ck(i,0)*((i==m+1)?pw[j-(i+1)-2]:a21[i+1][j])%mod+ck(j-1,1)*a12[i][j-1]%mod)%mod;
}
}
//a12/a21[1,m+1][1,n]
for(long long i=1;i<=n-m;i++) b12[i]=a12[m+1][m+i],b21[i]=a21[m+1][m+i];
x[0]=1;
for(long long i=1;i<=m;i++){
for(long long j=0;j<i;j++) x[i]=(x[i]+x[j]*a21[j+1][i+1]%mod%mod*inv[i-j-1]%mod)%mod;
x[i]=x[i]*invv[i]%mod;
}
//x[1,m]
for(long long i=0;i<=n+1;i++){
U[i]=b21[i+2]*inv[i]%mod;
if(i) for(long long j=0;j<=min(i-1,m-1);j++)
V[i-1]=(V[i-1]+x[j]*(a21[j+1][i+1]+mod-b21[i+1-j])%mod*inv[i-j-1]%mod)%mod;
}
//U/V[0,n]
poly::Limit(U,U,n);
poly::Exp(P,U,n);
INV::solve(Q,P,n);
NTT::solve(V,V,Q,n,n);
poly::Limit(V,V,n);
V[0]=(V[0]+1)%mod;
NTT::solve(F,P,V,n,n);
//F/P
for(int i=1;i<=n;i++) a_1[i]=F[i-1]*mul[i-1]%mod,b1_[i]=P[i-1]*mul[i-1]%mod;
for(int i=m;i>=1;i--)
for(long long t=i+1;t<=n;t++)
a1_[i]=(a1_[i]+a12[i][t]*((t>m)?b1_[n-t+1]:a1_[t])%mod*C(n-i-1,t-i-1)%mod)%mod;
//a1_[1,m][n]
for(int i=1;i<=n;i++) ans=(ans+a_1[i]*((i>m)?b1_[n-i+1]:a1_[i])%mod*C(n-1,i-1)%mod)%mod;
cout<<ans;
}