LOJ#2719. 「NOI2018」冒泡排序 DP+组合+树状数组
打表发现一个排列满足要求,当且仅当这个排列以被拆成两个 LIS(其中一个可以为空)
考虑状压 DP:$f[S][i][0/1]$ 表示加进去数集 $S$,第二个 LIS 的最大值为 $i$,是否顶上界.
我们拆分 $LIS$ 的方式是贪心地将较大的值分给第一个,剩下的分给第二个.
这个时间复杂度大概是 $O(2^n n^2)$.
先不考虑字典序的限制.
假设已经加进去了数集 $S$,$S$ 中的最大值为 $mx$.
则如果新加入数 $i$ 有两种可能:$i>mx$ 或 $i<mx$ 但 $i$ 必须是未填入数中最小值.
我们发现,如果每一步都遵循上述转移方式,则一定合法,也就不用记录数集 $S$ 了.
直接令 $f[i][j]$ 表示还有 $i$ 个数没有加入,且有 $j$ 个数大于已加入数的最大值.
转移的话有 :
$f[i][j]=f[i-1][j]$ (加入最小值)
$f[i][j]=f[i-1][j-k]$ (加入一个数后会导致原来大于 $mx$ 的数小于新的 $mx$ 了)
总:$f[i][j]=\sum_{k=0}^{j}f[i-1][j-k]$
边界:$f[0][0]$.
再整理一下得 $f[i][j]=f[i-1][j]+f[i][j-1].$ 其中 $j \leqslant i$
预处理这个是 $O(n^2)$ 的.
$f[i][j]$ 的转移可以看作是从 $(0,0)$ 走到 $(i,j)$ 且不经过 $y=x+1$ 的方案数.
可以用总方案数 $\binom{i+j}{i}$ 减掉不合法的,求法和生成字符串那道题一样.
至此,我们就得到了一个 $O(n)$ 预处理,$O(1)$ 求出 $f(i,j)$ 的解法.
考虑枚举 $LCP$,然后再枚举第一个大于的位置,这部分裸做的话是 $O(n^3)$ 的.
即判断当前是否可行,再统计未填入数中大于前面最大值的个数,累加一个 $f(n-i,cn)$
考虑枚举到 $i$,且前 $i-1$ 位合法.
令 $b[i]$ 表示 $i$ 后面大于当前位的个数.
则 $mx(i)=\min_{j=1}^{i}b[j]$.
考虑当前 $i$ 位填入数 $j$ 且 $j>q[i]$.
则 $j$ 必须从 $mx(i)$ 中选取,因为如果 $j$ 不在这里选取的话会导致第二个 $LIS$ 不合法.
所以 $j$ 就会贡献一个占位,这部分的贡献就是 $\sum_{k=0}^{nw-1}f[n-i][k]$
可以直接表示成 $f(n-i+1,nw-1)$.
判断不合法的话就是如果 $q[i]$ 不是前缀最大值且 $i$ 前面小于 $q[i]$ 的个数不等于 $q[i]-1$.
复杂度的瓶颈在于求后面大于 $i$ 的个数和前面小于 $i$ 的个数.
总时间复杂度为 $O(n \log n)$.
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 2000004 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout) using namespace std; int n,ans; int q[N],b[N],c[N],inv[N],fac[N]; inline int ADD(int x,int y) { return (x+y)>=mod?x+y-mod:x+y; } inline int C(int x,int y) { return (ll)fac[x]*inv[y]%mod*inv[x-y]%mod; } inline int GE(int x,int y) { //if(y==0) return 1; //if(y==1) return x; return ADD(C(x+y,x),mod-C(x+y,x+1)); } void init() { fac[0]=1; for(int i=1;i<N;++i) fac[i]=(ll)fac[i-1]*i%mod; inv[1]=1; for(int i=2;i<N;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; inv[0]=1; for(int i=1;i<N;++i) inv[i]=(ll)inv[i-1]*inv[i]%mod; } struct BIT { int sum[N]; inline int lowbit(int x) { return x&(-x); } void upd(int x) { while(x<N) ++sum[x],x+=lowbit(x); } int query(int x) { int re=0; for(;x;x-=lowbit(x)) { re+=sum[x]; } return re; } void clr() { memset(sum,0,sizeof(sum)); } }op; void solve() { scanf("%d",&n); for(int i=1;i<=n;++i) scanf("%d",&q[i]); for(int i=n;i>=1;--i) { b[i]=n-i-op.query(q[i]); op.upd(q[i]); // 前面小于我的=i-1-前面大于我的 c[i]=i-1-(n-q[i]-b[i]); } op.clr(); ans=0; int nw=n; for(int i=1;i<=n;++i) { bool flag=(b[i]<nw); nw=min(nw,b[i]); if(nw<=0) break; ans=ADD(ans,GE(n-i+1,nw-1)); if(!flag&&c[i]!=q[i]-1) break; } printf("%d\n",ans); } int main() { //setIO("inverse"); int T,x,y,z; init(); scanf("%d",&T); while(T--) { solve(); } return 0; }