LOJ 3395 集训队作业 Yet Another Permutation Problem 题解 (生成函数技巧)
原排列进行k次操作后,一定会剩下一个长度至少为\(n-k\)的上升区间。观察发现,一个排列能用\(\le k\)次操作得到的充要条件是:其中的最长上升区间长度\(\ge n-k\)。
我们枚举k,对每个k计算答案。最长上升区间长度\(\ge n-k\)的排列个数不太好算,我们可以算出最长上升区间长度\(\le n-k-1\)的排列个数,然后用总数减。为了方便,接下来的\(k\)表示上面的\(n-k-1\)。
我们先把序列划分成一些段,要求每一段都是单调递增,且每段长度\(\le k\),求总方案数(我知道会重复计数,但你先别急)。这是一个经典的生成函数计数题。令\(f_i\)表示i个不同元素排成上升序列的方案数,显然\(f_i=1(i\le k)\),为了方便使用生成函数,我们令\(f_0=0\)。令\(F(x)\)为f的指数型生成函数,\(G(x)\)为f的普通生成函数。\(F(x)=\sum_{i=1}^k \frac{x^i}{i!}\),\(G(x)=\sum_{i=1}^k x^i\)。根据EGF的组合意义,这个问题的答案为:\(([x^n]\sum_{i=0}^\infin F^i(x))\cdot n!\),其中\([x^n]\)表示这个多项式的n次项系数。
再来看看这个问题与原问题的关系。原问题中每个排列在这个问题中都被计数了多次。具体来说,对于一个排列,我们先把他划分成若干极长的上升区间,令它们的长度从左到右为\(l_1\cdots l_m\)。这个排列在这个问题中其实被计数了这么多次:\(\prod_{i=1}^m ([x^{l_i}]\sum_{j=0}^\infin G^j(x))\),从组合意义的角度这个式子也非常好理解。那么我们希望这个排列被计数多少次呢?其实是\(\prod_{i=1}^m [l_i\le k]\)次。如果我们能适当地修改f数组,使得对于任意\(l\),\([x^{l}]\sum_{i=0}^\infin G^i(x)=[l\le k]\),那么只要按照这个问题的流程跑一遍,就得到原问题的答案了。这其实是可以做到的。
如果能求出\(G(x)\),也就确定了f,也就能求出\(F(x)\)从而求出答案了。这里似乎可以用多项式求逆,但是模数不是998244353所以会有问题。写任意模NTT的话常数又似乎太大了。我们来观察一下暴力多项式求逆的过程,看看能不能投机取巧:
为了避免混淆,重申一下"多项式求逆"的含义:对于一个多项式\(A(x)\),求出一个多项式\(B(x)\),使得\(A,B\)乘积的常数项为1,且1次项到\(n\)次项的系数都为0,更高次项的系数不管。
观察上面的暴力,发现其复杂度与\(a\)中不为0的项的数量有关(求那个sigma的时候可以只枚举a中不为0的项)。在求\(G(x)\)的过程中我们要给\(1-x^{k+1}\)求逆,打表(划掉)手画(划掉)观察一下发现\(1-x^{k+1}\)的逆中只有大约\(\frac nk\)项不为0,所以\(G(x),F(x)\)中不为0的项数也差不多是这个数量级。这部分的复杂度是\(O(n^2)\)(算上之前枚举k)。
然后是求出\(F(x)\)之后计算答案。我们需要计算\(\sum_{i=0}^\infin F^i(x)=\frac 1{1-F(x)}\)这个多项式。对\(1-F(x)\)求逆时,由于其中非0项的个数是\(\frac nk\)级别,求逆的复杂度是\(O(\frac {n^2}k)\)的。算上枚举\(k\),总复杂度是\(O(n^2logn)\)的。
点击查看代码
#include <bits/stdc++.h>
#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define pb push_back
#define mpr make_pair
using namespace std;
void fileio()
{
#ifdef LGS
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
}
void termin()
{
#ifdef LGS
std::cout<<"\n\nEXECUTION TERMINATED";
#endif
exit(0);
}
LL n,MOD,fac[1010],inv[1010];
LL qpow(LL x,LL a)
{
LL res=x,ret=1;
while(a>0)
{
if(a&1) (ret*=res)%=MOD;
a>>=1;
(res*=res)%=MOD;
}
return ret;
}
vector <LL> getInv(vector <LL> A)
{
vector <pii> v;
rep(i,A.size()) if(A[i]!=0) v.pb(mpr(i,A[i]));
vector <LL> ret;
ret.pb(qpow(A[0],MOD-2));
repn(i,A.size()-1)
{
LL hv=0;
rep(j,v.size())
{
if(v[j].fi>i) break;
if(i-v[j].fi<ret.size()) (hv+=v[j].se*ret[i-v[j].fi])%=MOD;
}
hv=(MOD-hv)%MOD;
ret.pb(hv*ret[0]%MOD);
}
return ret;
}
vector <LL> polyMul(vector <LL> A,LL mns)
{
vector <LL> ret;ret.pb(0);rep(i,A.size()-1) ret.pb(A[i]);
LL mul=MOD-1;
rep(i,A.size()) if(i+mns<ret.size()) (ret[i+mns]+=A[i]*mul)%=MOD;
return ret;
}
LL calc(LL k)
{
if(k==0) return 0;
vector <LL> A;
A.pb(1);repn(i,k) A.pb(0);A.pb(MOD-1);while(A.size()<n+1) A.pb(0);
vector <LL> B=getInv(A);
vector <LL> F=polyMul(B,k+1);
rep(i,F.size()) (F[i]*=inv[i])%=MOD;
rep(i,F.size()) F[i]=(MOD-F[i])%MOD;
(++F[0])%=MOD;
F=getInv(F);
LL ret=F[n]*fac[n]%MOD;
return ret;
}
int main()
{
fileio();
cin>>n>>MOD;
fac[0]=1;repn(i,1005) fac[i]=fac[i-1]*i%MOD;
rep(i,1003) inv[i]=qpow(fac[i],MOD-2);
rep(nk,n)
{
LL k=n-nk-1;
//答案=总数-所有上升子段的长度都<=k的排列数量
LL ans=(fac[n]-calc(k)+MOD)%MOD;
printf("%lld\n",ans);
}
termin();
}