LOJ #575. 不等关系 做题笔记
LOJ #575. 不等关系 做题笔记
我们把这个式子看作是一堆小于号中间被几个大于号分割成了好几段,如果我们不考虑两端的交界部分的大于关系,我们就可以得到答案:
\[Ans=\dfrac{(n+1)!}{\prod len_i!}
\]
然后我们考虑容斥: \(边界大于=不考虑边界大于小于-边界小于\)
于是我们设 \(f_i=满足i-1个限制的i排列个数\)
我们假设第 \(0\) 个位置是 \(>\) ( \(n-1\) 个限制由 \(1\) 开始),令 \(a_i\) 表示从 \(0\) 到 \(i-1\) 中有多少个 \(>\) ,然后我们可以得到递推式:
\[f_i=\sum_{i=0}^{i-1}\left[s_i='>'\right]\times(-1)^{a_i-a_j-1}\times f_j\times \dbinom{i}{j}
\]
我们考虑这三个式子,第一个表示必须由 \(s_i='>'\) 的地方转移而来,第二个是容斥系数,表示中间我们把多少个 \(>\) 当成了 \(<\) ,减一是容斥系数的细节,第三个是表示 \(j\) 的答案,第四个表示我们从 \(i\) 个数中选择 \(j\) 个数字组成 \(f_j\) 的方案(显然我们没有用到排列的性质,只要大小关系满足就好),剩下的 \(i-j\) 个数字按升序放在 \(j+1\) 到 \(i\) 的位置上。
然后我们就可以开心的推柿子啦!
\[f_i=\sum_{j=0}^{i-1}\left[s_i='>'\right]\times(-1)^{a_i-a_j-1}\times f_j\times \dbinom{i}{j}\\f_i=-\sum_{j=0}^{i-1}\left[s_i='>'\right]\times(-1)^{a_j-a_i}\times f_j\times \dfrac{i!}{j!(i-j)!}\\\dfrac{(-1)^{a_i}\times f_i}{i!}=-\sum_{j=0}^{i-1}[s_j='>']\times\dfrac{(-1)^{a_j}\times f_j}{j!}\times\dfrac{1}{(i-j)!}
\]
然后你就发现这是个分治 \(FFT\) 的板子了!
时间复杂度:\(\Theta(n\log^2n)\)
\(code\) :
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <cstdlib>
#define FUP(i,x,y) for(int i=(x);i<=(y);i++)
#define FDW(i,x,y) for(int i=(x);i>=(y);i--)
#define FED(i,x) for(int i=head[x];i;i=ed[i].nxt)
#define pr pair<int,int>
#define mkp(a,b) make_pair(a,b)
#define fi first
#define se second
#define MAXN 400010
#define INF 0x3f3f3f3f
#define LLINF 0x3f3f3f3f3f3f3f3f
#define eps 1e-9
#define MOD 998244353
#define ll long long
#define db double
using namespace std;
int read()
{
int w=0,flg=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-'){flg=-1;}ch=getchar();}
while(ch<='9'&&ch>='0'){w=w*10+ch-'0',ch=getchar();}
return w*flg;
}
ll poww(ll a,int b)
{
ll ans=1,base=a;
while(b)
{
if(b&1) ans=ans*base%MOD;
base=base*base%MOD;
b>>=1;
}
return ans;
}
const ll G=3,Gi=332748118;
int limit,L,rev[MAXN];
void NTT_init(int len)
{
limit=1,L=0;
while(limit<=len) limit<<=1,L++;
FUP(i,0,limit-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
}
void NTT(ll *a,ll typ)
{
FUP(i,0,limit-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
ll Wn=poww(typ,(MOD-1)/(mid<<1));
for(int j=0;j<limit;j+=(mid<<1))
{
ll w=1;
for(int k=0;k<mid;k++,w=w*Wn%MOD)
{
ll x=a[j+k],y=a[j+k+mid]*w%MOD;
a[j+k]=(x+y)%MOD;
a[j+k+mid]=(x-y+MOD)%MOD;
}
}
}
}
int n,bl[MAXN];
char str[MAXN];
ll fac[MAXN],invfac[MAXN],ans[MAXN],a[MAXN],b[MAXN];
#define mid ((l+r)>>1)
void solve(int l,int r)
{
if(l==r)
{
if(!l) return;
ans[l]=MOD-ans[l];
//printf("l=%d ans=%lld\n",l,ans[l]);
return;
}
solve(l,mid);
NTT_init(r-l+mid-l);
FUP(i,0,r-l) a[i]=invfac[i];
FUP(i,l,mid) b[i-l]=(str[i]!='<')*ans[i]%MOD;
NTT(a,G),NTT(b,G);
FUP(i,0,limit-1) a[i]=a[i]*b[i]%MOD;
NTT(a,Gi);
ll inv=poww(limit,MOD-2);
FUP(i,0,limit-1) a[i]=a[i]*inv%MOD;
FUP(i,mid+1,r) ans[i]=(ans[i]+a[i-l])%MOD;
/*printf("%d~%d :",mid+1,r);
FUP(i,0,limit-1) printf("%lld ",ans[i]);
puts("");*/
FUP(i,0,limit-1) a[i]=b[i]=0;
solve(mid+1,r);
}
int main(){
scanf("%s",str+1);
n=strlen(str+1)+1;
fac[0]=1;
FUP(i,1,n) fac[i]=fac[i-1]*i%MOD;
invfac[n]=poww(fac[n],MOD-2);
FDW(i,n-1,0) invfac[i]=invfac[i+1]*(i+1)%MOD;
str[0]='>';
int cur=0;
FUP(i,0,n-1) if(str[i]!='<') cur^=1;
//FUP(i,0,n) printf("%d ",bl[i]);
//puts("");
ans[0]=1;
solve(0,n);
if(cur) ans[n]=MOD-ans[n];
printf("%lld\n",ans[n]*fac[n]%MOD);
return 0;
}