CF1553I Stairs题解--zhengjun

更好的阅读体验

题目传送门

LuoguCodeforces

谈点其他的

这个题目翻译好像有亿点问题,我重新发一波。

给定一个长度为 \(n\) 的排列 \(p\)

令其中第 \(i\) 个位置的权值为最长的包含 \(i\) 的单调区间的长度(不仅要权值单调,而且要连续)。例如,\(p=[4,1,2,3,7,6,5]\) 中,第 \(6\) 个位置的权值为 \(3\)\([5,7]\) 这个区间单调递减而且连续),第 \(2\) 个位置的权值为 \(3\)\([2,4]\) 这个区间单调递增而且连续)。

将这些权值依次拼在一起,就得到了 \(p\)\(\lceil\) 阶梯序列 \(\rfloor\)

给定 \(a\),你需要求出存在多少个 \(p\),使得 \(a\)\(p\)\(\lceil\) 阶梯序列 \(\rfloor\)

思路

首先可以看出来,就是这个排列的这些单调、权值连续的区间一定是独立的(没有相交的情况),这一步应该很好理解。

然后就可以根据 \(a\) 推出这些区间,比如样例1\(a=\{3,3,3,1,1,1\}\),那么这些区间就是 \(\{[1,3],[4,4],[5,5],[6,6]\}\)

然后我们重新定义一下 \(a\),就是每个区间的长度依次组成的序列。例如样例1\(a=\{3,1,1,1\}\)

注意接下来所有说的 \(a\) 都是新定义的 \(a\)\(len\) 就是新定义的 \(a\) 的长度

接下来就可以将 \(1\sim n\) 分配给这些区间,那么其实就是 \(len!\) 种,就考虑让这些区间一个一个地取数,例如样例1,如果取数的顺序为 \({3,1,2,4}\),那么区间 \([1,3]\) 就分配到了 \(2,3,4\) 这三个数(因为 \(1\) 被排在第一个的区间 \(3\) 取走了),所以这个排列 \(p\) 就是 \(\{2,3,4,5,1,6\}\)\(\{4,3,2,5,1,6\}\)

这里看出了一个问题,就是所有长度大于等于 \(2\) 的区间都有递增和递减两种情况,所以还要乘上一个 \(2^x\),这个 \(x\) 就是长度大于等于 \(2\) 的区间的个数。

但是还会有一个问题,就是分配出来的东西,可能相邻的两个区间可以一起组成一个大的单调、连续的区间,那么要把这种情况排除掉,就可以容斥。

答案就是

至少有 \(0\) 对相邻的区间合并 \(-\) 至少有 \(1\) 对相邻的区间合并 \(+\) 至少有2对相邻的区间合并\(\dots+(-1)^{len}\)至少有 \((len-1)\) 对相邻的区间合并

Solution1

复杂度:\(O(n^2)\)

考虑 dp。

\(f_{i,j}\) 表示区间编号为 \(1\sim i\),合并成 \(j\) 段的值是多少。

转移方程式:

\[ f_{i,j}=\left\{ \begin{array}{rcl} j\times2\times\sum\limits_{k=1}^{j-1}f_{i-1,k} & {2\le j\le i\le len,a_j>1}\\ j\times(\sum\limits_{k=1}^{j-1}2f_{i-1,k}-f_{i-1,j-1}) & {2\le j\le i\le len,a_j=1} \end{array} \right. \]

初始化:

\[ f_{1,1}=\left\{ \begin{array}{rcl} 2 & {a_1>1}\\ 1 & {a_1=1}\\ \end{array} \right. \]

\[f_{i,1}=2,i\ge2 \]

然后答案就是:

\[\sum\limits_{i=1}^{len}(-1)^{len-i}\times f_{n,i}\times i! \]

代码就不贴了,因为我比较勤快懒惰,很想写代码 QWQ。

Solution

复杂度:\(O(n\log^2n)\)

可以注意到在新的 \(a\) 序列中合并两个区间时,左边是 \(x\) 段,右边是 \(y\) 段,合并起来就是 \(x+y\) 段,合并起来的值就是乘积,想到什么???

卷积!!!

可以分治 + NTT,这里模数正好是 \(98244353\),原根就是 \(3\)

但是还有一些问题,就比如说 \([l,mid]\) 的最右边的那个数是大于 \(1\) 的,\([mid+1,r]\) 的最左边的那个数也是大于 \(1\) 的,合并起来就要除以 \(2\)

两个都是 \(1\),就要乘以 \(2\)

然后,细节多得要死,常数大得要死。

最后,注意如果在 \((mid,mid+1)\) 这里分一刀,那么直接卷就好了,如果不分,那么段数要 \(-1\)

代码

在这里

注释看看最后一份代码好了。

然后顺利地 TLE 了……

然后,@275307894a 大佬,给了我一些建议:

每一次乘法时都要 NTT 3 遍,其实可以先 NTT 一遍,乘的时候直接对位相乘就好了。

但是,你发现又 TLE 了。

然后继续优化,发现 NTT 里面有多次调用 qpow,所以可以预处理。

然后发现 div 里面每次都只是除以二,所以不用每次算逆元,每次都直接乘以 2 的逆元就好了。

最后的代码(有注释)

#pragma GCC optimize(3)
#pragma GCC target("avx")
#pragma GCC optimize("Ofast")
#pragma GCC optimize(2)
#include<cstdio>
#include<vector>
#include<algorithm>
#include<ctime>
#define int long long
#define ll long long
using namespace std;
const ll mod=998244353,g=3,ginv=332748118,N=4e5+10,inv2=499122177;
ll qpow(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)(ans*=x)%=mod;
        (x*=x)%=mod;
        y>>=1;
    }
    return ans;
}
int rr[N];
ll gg[N],gginv[N],lim[N];
ll cnt;
vector<ll> NTT(vector<ll> a,int limit,int type){//用 vector 好写一些吧
    for(int i=0;i<limit;i++)if(rr[i]<i)swap(a[rr[i]],a[i]);
    for(int mid=1;mid<limit;mid<<=1){
        ll wn=type==1?gg[mid<<1]:gginv[mid<<1];//main 里面处理
        for(int j=0;j<limit;j+=(mid<<1)){
            ll w=1;
            for(int k=0;k<mid;k++,(w*=wn)%=mod){
                cnt++;ll x=a[j+k],y=w*a[j+k+mid]%mod;//NTT 基本操作,蝴蝶效应
                a[j+k]=(x+y)%mod;a[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }if(~type)return a;for(int i=0;i<limit;i++) (a[i]*=lim[limit])%=mod;return a;//别忘了如果是逆回去,要除以 limit,这里也预处理了
}
vector<ll> add(vector<ll> a,vector<ll> b){//相加
    int n=a.size(),m=b.size();
    while(n<m)a.push_back(0),n++;//要把 0 补起来,不然要 RE
    while(n>m)b.push_back(0),m++;
    for(int i=0;i<n;i++)(a[i]+=b[i])%=mod;
    return a;
}
vector<ll> add_(vector<ll> a,vector<ll> b){//错位加
    int n=a.size(),m=b.size();
    while(n<m)a.push_back(0),n++;
    while(n>m)b.push_back(0),m++;
    for(int i=0;i<n-1;i++)(a[i]+=b[i+1])%=mod;
    return a;
}
vector<ll> div(vector<ll> a){//除以二
    int n=a.size();
    for(int i=0;i<n;i++)(a[i]*=inv2)%=mod;
    return a;
}
vector<ll> mul(vector<ll> a,ll x){//乘以x
    int n=a.size();
    for(int i=0;i<n;i++)(a[i]*=x)%=mod;
    return a;
}
struct zj{
    vector<ll>a[2][2];//分别表示左右端点是否大于 1
    zj(){a[0][0].push_back(0),a[0][1].push_back(0),a[1][0].push_back(0),a[1][1].push_back(0);}//分成 0 段的方案数就是 0
};
int n,a[N],b[N],m;
vector<ll> mul(vector<ll> a,vector<ll> b){//卷积
    int limit=1,l=0,n=a.size(),m=b.size();
    while(limit<=n+m)limit<<=1,l++;
    for(int i=0;i<limit;i++)rr[i]=(rr[i>>1]>>1)|((i&1)<<(l-1));
    while(n<limit)a.push_back(0),n++;
    while(m<limit)b.push_back(0),m++;
    a=NTT(a,limit,1);b=NTT(b,limit,1);
    for(int i=0;i<limit;i++)(a[i]*=b[i])%=mod;
    a=NTT(a,limit,-1);
    while(a.size()>1&&a[a.size()-1]==0)a.pop_back();//注意要把无用的 0 去掉,不然会很耗内存和时间
    return a;
}
vector<ll> mul_(vector<ll>a,vector<ll>b){//对位乘
    int n=a.size(),m=b.size();
    while(n<m)a.push_back(0),n++;
    while(n>m)b.push_back(0),m++;
    for(int i=0;i<n;i++)(a[i]*=b[i])%=mod;
    return a;
}
zj merge(int l,int r){
    if(l==r){//边界
        zj x;
        if(a[l]==1)x.a[0][0].push_back(1);
        else x.a[1][1].push_back(2);
        return x;
    }
    int mid=(l+r)>>1;
    zj x=merge(l,mid),y=merge(mid+1,r);
    zj tmp,tmpp;
    if(l+1==r){
        for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul(x.a[i][ii],y.a[jj][j]));
        for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmp.a[1][1]=add_(tmp.a[1][1],mul(x.a[i][ii],y.a[jj][j]));else if(!ii)tmp.a[1][1]=add_(tmp.a[1][1],mul(mul(x.a[i][ii],y.a[jj][j]),2));else tmp.a[1][1]=add_(tmp.a[1][1],div(mul(x.a[i][ii],y.a[jj][j])));
        return tmp;
    }
    else if(l+2==r){
        for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul(x.a[i][ii],y.a[jj][j]));
        for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmp.a[i][1]=add_(tmp.a[i][1],mul(x.a[i][ii],y.a[jj][j]));else if(!ii)tmp.a[i][1]=add_(tmp.a[i][1],mul(mul(x.a[i][ii],y.a[jj][j]),2));else tmp.a[i][1]=add_(tmp.a[i][1],div(mul(x.a[i][ii],y.a[jj][j])));
        return tmp;
    }
    //这两个特盘是因为长度在 3 以内时,合并会改变左右端点的 0/1
    int limit=1,llll=0,lena=0,lenb=0;//tmp 是 (mid,mid+1) 切一刀,tmpp 是不切,一定要分开,因为不切要错位加,点值是不能错位加的,要 NTT 回去再错位加起来
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)lena=max(lena,(int)x.a[i][j].size());
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)lenb=max(lenb,(int)y.a[i][j].size());
    while(limit<=lena+lenb)limit<<=1,llll++;
    for(int i=0;i<limit;i++)rr[i]=(rr[i>>1]>>1)|((i&1)<<(llll-1));//NTT 基本操作
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(x.a[i][j].size()<limit)x.a[i][j].push_back(0);//注意补 0
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(y.a[i][j].size()<limit)y.a[i][j].push_back(0);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmp.a[i][j].size()<limit)tmp.a[i][j].push_back(0);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmpp.a[i][j].size()<limit)tmpp.a[i][j].push_back(0);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)x.a[i][j]=NTT(x.a[i][j],limit,1);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)y.a[i][j]=NTT(y.a[i][j],limit,1);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul_(x.a[i][ii],y.a[jj][j]));//NTT 之后要对位乘
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmpp.a[i][j]=add(tmpp.a[i][j],mul_(x.a[i][ii],y.a[jj][j]));else if(!ii)tmpp.a[i][j]=add(tmpp.a[i][j],mul(mul_(x.a[i][ii],y.a[jj][j]),2));else tmpp.a[i][j]=add(tmpp.a[i][j],div(mul_(x.a[i][ii],y.a[jj][j])));//不能再这里写错位加,我 TM 调了半天!!!一定要 NTT 回去再错位加
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmp.a[i][j]=NTT(tmp.a[i][j],limit,-1);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmpp.a[i][j]=NTT(tmpp.a[i][j],limit,-1);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmp.a[i][j]=add_(tmp.a[i][j],tmpp.a[i][j]);//错位加
    for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmp.a[i][j].size()>1&&tmp.a[i][j][tmp.a[i][j].size()-1]==0)tmp.a[i][j].pop_back();//删去无用 0 ,但是要留一个
    return tmp;
}
signed main(){
    scanf("%lld",&m);
    //上文讲的预处理
    for(int mid=1;mid<N;mid<<=1)gg[mid]=qpow(g,(mod-1)/mid),gginv[mid]=qpow(ginv,(mod-1)/mid),lim[mid]=qpow(mid,mod-2);
    for(int i=1;i<=m;i++)scanf("%lld",&b[i]);
    for(int i=1;i<=m;){
        a[++n]=b[i];
        for(int j=0;j<b[i];j++){
            if(i+j>m)return printf("0\n"),0;
            if(b[i+j]!=b[i])return printf("0\n"),0;
        }
        i=i+b[i];
    }//变幻成新定义的 a 数组
    zj ans=merge(1,n);
    for(int i=0;i<2;i++)for(int j=0;j<2;j++){
        while(ans.a[i][j].size()<=n+1)ans.a[i][j].push_back(0);
    }
    for(int i=1;i<=n;i++)(ans.a[0][0][i]+=ans.a[0][1][i]+ans.a[1][0][i]+ans.a[1][1][i])%=mod;
    ll now=1,ANS=0;
    for(int i=1;i<=n;i++){//容斥,和上面的公式一样的
        (now*=i)%=mod;
        (ANS+=((n-i+1)&1?1:-1)*now*ans.a[0][0][i]%mod)%=mod;
        (ANS+=mod)%=mod;
    }
    return printf("%lld",ANS),0;
}

最后说点什么

第一次发黑题题解,第一次独立做出黑题。

需要卡常,不知道为什么我跑了 \(2.4\) min,别人都是 \(10\) s。

谢谢--zhengjun

posted @ 2022-06-11 15:29  A_zjzj  阅读(35)  评论(0编辑  收藏  举报