CF1553I Stairs题解--zhengjun
题目传送门
谈点其他的
这个题目翻译好像有亿点问题,我重新发一波。
给定一个长度为 \(n\) 的排列 \(p\)。
令其中第 \(i\) 个位置的权值为最长的包含 \(i\) 的单调区间的长度(不仅要权值单调,而且要连续)。例如,\(p=[4,1,2,3,7,6,5]\) 中,第 \(6\) 个位置的权值为 \(3\)(\([5,7]\) 这个区间单调递减而且连续),第 \(2\) 个位置的权值为 \(3\)(\([2,4]\) 这个区间单调递增而且连续)。
将这些权值依次拼在一起,就得到了 \(p\) 的 \(\lceil\) 阶梯序列 \(\rfloor\)。
给定 \(a\),你需要求出存在多少个 \(p\),使得 \(a\) 为 \(p\) 的 \(\lceil\) 阶梯序列 \(\rfloor\)。
思路
首先可以看出来,就是这个排列的这些单调、权值连续的区间一定是独立的(没有相交的情况),这一步应该很好理解。
然后就可以根据 \(a\) 推出这些区间,比如样例1
:\(a=\{3,3,3,1,1,1\}\),那么这些区间就是 \(\{[1,3],[4,4],[5,5],[6,6]\}\)。
然后我们重新定义一下 \(a\),就是每个区间的长度依次组成的序列。例如样例1
,\(a=\{3,1,1,1\}\)。
注意接下来所有说的 \(a\) 都是新定义的 \(a\),\(len\) 就是新定义的 \(a\) 的长度
接下来就可以将 \(1\sim n\) 分配给这些区间,那么其实就是 \(len!\) 种,就考虑让这些区间一个一个地取数,例如样例1
,如果取数的顺序为 \({3,1,2,4}\),那么区间 \([1,3]\) 就分配到了 \(2,3,4\) 这三个数(因为 \(1\) 被排在第一个的区间 \(3\) 取走了),所以这个排列 \(p\) 就是 \(\{2,3,4,5,1,6\}\) 或 \(\{4,3,2,5,1,6\}\)。
这里看出了一个问题,就是所有长度大于等于 \(2\) 的区间都有递增和递减两种情况,所以还要乘上一个 \(2^x\),这个 \(x\) 就是长度大于等于 \(2\) 的区间的个数。
但是还会有一个问题,就是分配出来的东西,可能相邻的两个区间可以一起组成一个大的单调、连续的区间,那么要把这种情况排除掉,就可以容斥。
答案就是
至少有 \(0\) 对相邻的区间合并 \(-\) 至少有 \(1\) 对相邻的区间合并 \(+\) 至少有2对相邻的区间合并\(\dots+(-1)^{len}\)至少有 \((len-1)\) 对相邻的区间合并
Solution1
复杂度:\(O(n^2)\)
考虑 dp。
设 \(f_{i,j}\) 表示区间编号为 \(1\sim i\),合并成 \(j\) 段的值是多少。
转移方程式:
初始化:
然后答案就是:
代码就不贴了,因为我比较勤快懒惰,很不想写代码 QWQ。
Solution
复杂度:\(O(n\log^2n)\)
可以注意到在新的 \(a\) 序列中合并两个区间时,左边是 \(x\) 段,右边是 \(y\) 段,合并起来就是 \(x+y\) 段,合并起来的值就是乘积,想到什么???
卷积!!!
可以分治 + NTT,这里模数正好是 \(98244353\),原根就是 \(3\)。
但是还有一些问题,就比如说 \([l,mid]\) 的最右边的那个数是大于 \(1\) 的,\([mid+1,r]\) 的最左边的那个数也是大于 \(1\) 的,合并起来就要除以 \(2\)。
两个都是 \(1\),就要乘以 \(2\)。
然后,细节多得要死,常数大得要死。
最后,注意如果在 \((mid,mid+1)\) 这里分一刀,那么直接卷就好了,如果不分,那么段数要 \(-1\)。
代码
注释看看最后一份代码好了。
然后顺利地 TLE 了……
然后,@275307894a 大佬,给了我一些建议:
每一次乘法时都要 NTT 3 遍,其实可以先 NTT 一遍,乘的时候直接对位相乘就好了。
但是,你发现又 TLE 了。
然后继续优化,发现 NTT 里面有多次调用 qpow,所以可以预处理。
然后发现 div 里面每次都只是除以二,所以不用每次算逆元,每次都直接乘以 2 的逆元就好了。
最后的代码(有注释)
#pragma GCC optimize(3)
#pragma GCC target("avx")
#pragma GCC optimize("Ofast")
#pragma GCC optimize(2)
#include<cstdio>
#include<vector>
#include<algorithm>
#include<ctime>
#define int long long
#define ll long long
using namespace std;
const ll mod=998244353,g=3,ginv=332748118,N=4e5+10,inv2=499122177;
ll qpow(ll x,ll y){
ll ans=1;
while(y){
if(y&1)(ans*=x)%=mod;
(x*=x)%=mod;
y>>=1;
}
return ans;
}
int rr[N];
ll gg[N],gginv[N],lim[N];
ll cnt;
vector<ll> NTT(vector<ll> a,int limit,int type){//用 vector 好写一些吧
for(int i=0;i<limit;i++)if(rr[i]<i)swap(a[rr[i]],a[i]);
for(int mid=1;mid<limit;mid<<=1){
ll wn=type==1?gg[mid<<1]:gginv[mid<<1];//main 里面处理
for(int j=0;j<limit;j+=(mid<<1)){
ll w=1;
for(int k=0;k<mid;k++,(w*=wn)%=mod){
cnt++;ll x=a[j+k],y=w*a[j+k+mid]%mod;//NTT 基本操作,蝴蝶效应
a[j+k]=(x+y)%mod;a[j+k+mid]=(x-y+mod)%mod;
}
}
}if(~type)return a;for(int i=0;i<limit;i++) (a[i]*=lim[limit])%=mod;return a;//别忘了如果是逆回去,要除以 limit,这里也预处理了
}
vector<ll> add(vector<ll> a,vector<ll> b){//相加
int n=a.size(),m=b.size();
while(n<m)a.push_back(0),n++;//要把 0 补起来,不然要 RE
while(n>m)b.push_back(0),m++;
for(int i=0;i<n;i++)(a[i]+=b[i])%=mod;
return a;
}
vector<ll> add_(vector<ll> a,vector<ll> b){//错位加
int n=a.size(),m=b.size();
while(n<m)a.push_back(0),n++;
while(n>m)b.push_back(0),m++;
for(int i=0;i<n-1;i++)(a[i]+=b[i+1])%=mod;
return a;
}
vector<ll> div(vector<ll> a){//除以二
int n=a.size();
for(int i=0;i<n;i++)(a[i]*=inv2)%=mod;
return a;
}
vector<ll> mul(vector<ll> a,ll x){//乘以x
int n=a.size();
for(int i=0;i<n;i++)(a[i]*=x)%=mod;
return a;
}
struct zj{
vector<ll>a[2][2];//分别表示左右端点是否大于 1
zj(){a[0][0].push_back(0),a[0][1].push_back(0),a[1][0].push_back(0),a[1][1].push_back(0);}//分成 0 段的方案数就是 0
};
int n,a[N],b[N],m;
vector<ll> mul(vector<ll> a,vector<ll> b){//卷积
int limit=1,l=0,n=a.size(),m=b.size();
while(limit<=n+m)limit<<=1,l++;
for(int i=0;i<limit;i++)rr[i]=(rr[i>>1]>>1)|((i&1)<<(l-1));
while(n<limit)a.push_back(0),n++;
while(m<limit)b.push_back(0),m++;
a=NTT(a,limit,1);b=NTT(b,limit,1);
for(int i=0;i<limit;i++)(a[i]*=b[i])%=mod;
a=NTT(a,limit,-1);
while(a.size()>1&&a[a.size()-1]==0)a.pop_back();//注意要把无用的 0 去掉,不然会很耗内存和时间
return a;
}
vector<ll> mul_(vector<ll>a,vector<ll>b){//对位乘
int n=a.size(),m=b.size();
while(n<m)a.push_back(0),n++;
while(n>m)b.push_back(0),m++;
for(int i=0;i<n;i++)(a[i]*=b[i])%=mod;
return a;
}
zj merge(int l,int r){
if(l==r){//边界
zj x;
if(a[l]==1)x.a[0][0].push_back(1);
else x.a[1][1].push_back(2);
return x;
}
int mid=(l+r)>>1;
zj x=merge(l,mid),y=merge(mid+1,r);
zj tmp,tmpp;
if(l+1==r){
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul(x.a[i][ii],y.a[jj][j]));
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmp.a[1][1]=add_(tmp.a[1][1],mul(x.a[i][ii],y.a[jj][j]));else if(!ii)tmp.a[1][1]=add_(tmp.a[1][1],mul(mul(x.a[i][ii],y.a[jj][j]),2));else tmp.a[1][1]=add_(tmp.a[1][1],div(mul(x.a[i][ii],y.a[jj][j])));
return tmp;
}
else if(l+2==r){
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul(x.a[i][ii],y.a[jj][j]));
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmp.a[i][1]=add_(tmp.a[i][1],mul(x.a[i][ii],y.a[jj][j]));else if(!ii)tmp.a[i][1]=add_(tmp.a[i][1],mul(mul(x.a[i][ii],y.a[jj][j]),2));else tmp.a[i][1]=add_(tmp.a[i][1],div(mul(x.a[i][ii],y.a[jj][j])));
return tmp;
}
//这两个特盘是因为长度在 3 以内时,合并会改变左右端点的 0/1
int limit=1,llll=0,lena=0,lenb=0;//tmp 是 (mid,mid+1) 切一刀,tmpp 是不切,一定要分开,因为不切要错位加,点值是不能错位加的,要 NTT 回去再错位加起来
for(int i=0;i<2;i++)for(int j=0;j<2;j++)lena=max(lena,(int)x.a[i][j].size());
for(int i=0;i<2;i++)for(int j=0;j<2;j++)lenb=max(lenb,(int)y.a[i][j].size());
while(limit<=lena+lenb)limit<<=1,llll++;
for(int i=0;i<limit;i++)rr[i]=(rr[i>>1]>>1)|((i&1)<<(llll-1));//NTT 基本操作
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(x.a[i][j].size()<limit)x.a[i][j].push_back(0);//注意补 0
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(y.a[i][j].size()<limit)y.a[i][j].push_back(0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmp.a[i][j].size()<limit)tmp.a[i][j].push_back(0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmpp.a[i][j].size()<limit)tmpp.a[i][j].push_back(0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)x.a[i][j]=NTT(x.a[i][j],limit,1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)y.a[i][j]=NTT(y.a[i][j],limit,1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul_(x.a[i][ii],y.a[jj][j]));//NTT 之后要对位乘
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmpp.a[i][j]=add(tmpp.a[i][j],mul_(x.a[i][ii],y.a[jj][j]));else if(!ii)tmpp.a[i][j]=add(tmpp.a[i][j],mul(mul_(x.a[i][ii],y.a[jj][j]),2));else tmpp.a[i][j]=add(tmpp.a[i][j],div(mul_(x.a[i][ii],y.a[jj][j])));//不能再这里写错位加,我 TM 调了半天!!!一定要 NTT 回去再错位加
for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmp.a[i][j]=NTT(tmp.a[i][j],limit,-1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmpp.a[i][j]=NTT(tmpp.a[i][j],limit,-1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmp.a[i][j]=add_(tmp.a[i][j],tmpp.a[i][j]);//错位加
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmp.a[i][j].size()>1&&tmp.a[i][j][tmp.a[i][j].size()-1]==0)tmp.a[i][j].pop_back();//删去无用 0 ,但是要留一个
return tmp;
}
signed main(){
scanf("%lld",&m);
//上文讲的预处理
for(int mid=1;mid<N;mid<<=1)gg[mid]=qpow(g,(mod-1)/mid),gginv[mid]=qpow(ginv,(mod-1)/mid),lim[mid]=qpow(mid,mod-2);
for(int i=1;i<=m;i++)scanf("%lld",&b[i]);
for(int i=1;i<=m;){
a[++n]=b[i];
for(int j=0;j<b[i];j++){
if(i+j>m)return printf("0\n"),0;
if(b[i+j]!=b[i])return printf("0\n"),0;
}
i=i+b[i];
}//变幻成新定义的 a 数组
zj ans=merge(1,n);
for(int i=0;i<2;i++)for(int j=0;j<2;j++){
while(ans.a[i][j].size()<=n+1)ans.a[i][j].push_back(0);
}
for(int i=1;i<=n;i++)(ans.a[0][0][i]+=ans.a[0][1][i]+ans.a[1][0][i]+ans.a[1][1][i])%=mod;
ll now=1,ANS=0;
for(int i=1;i<=n;i++){//容斥,和上面的公式一样的
(now*=i)%=mod;
(ANS+=((n-i+1)&1?1:-1)*now*ans.a[0][0][i]%mod)%=mod;
(ANS+=mod)%=mod;
}
return printf("%lld",ANS),0;
}
最后说点什么
第一次发黑题题解,第一次独立做出黑题。
需要卡常,不知道为什么我跑了 \(2.4\) min,别人都是 \(10\) s。