The 2021 ICPC Asia Shanghai Regional Programming Contest B. Strange Permutations
题目链接
破防了,上周很套路的ds题调半天还以为只是码力下降了,现在发现原本比较自信的计数能力也不行,不明白为什么连简单的容斥都想不清楚...
把题意抽象一下,就是有一个长度为n的排列对应的有向图,需要依次经过每一个点,并且不经过给定的边。
那么显然考虑容斥,枚举至少经过几条给定的边。然而如果对于序列考虑,这个经过关键边的顺序又要受到环长的限制。原本就觉得每一段长度不同的关键边对应的限制都不同,就以为不可做。
但是只要从选取哪些关键边考虑,就会发现每选一条边,其实就是把一个点合并到另一个点。因为每选取一段连续的关键边,就要求从这一段的开头按照顺序走到结尾,那就是把这一段缩成一个点。所以对于每一个环,直接考虑它选取了多少条关键边,就能确定缩成几个点,直接写成一个[环长]次的多项式,然后分治fft即可。
对于这种有比较具体限制的问题(输入不只是几个数),容斥的时候应该从容斥的对象考虑,思考容斥的那个集合所带来的影响是什么。而不是直接从合法答案的角度,去构造满足对应至少不合法个数的解,这样容易陷入混乱。
代码
#include<bits/stdc++.h>
using namespace std;
const int N=(1<<17),P=998244353,G[2]={3,(P+1)/3};
int rv[N],pw2[N];
int fpw(int a,int x){
int s=1;
for(;x;x>>=1,a=1ll*a*a%P) if(x&1) s=1ll*s*a%P;
return s;
}
void dft(int* a,int n,int p){
for(int i=0;i<n;i++) if(i<rv[i]) swap(a[i],a[rv[i]]);
for(int i=1;i<n;i<<=1){
int wn=fpw(G[p],(P-1)/(i*2));
for(int j=0;j<n;j+=(i<<1)){
int w=1;
for(int k=0;k<i;k++,w=1ll*wn*w%P){
int x=a[j+k],y=1ll*a[i+j+k]*w%P;
a[j+k]=(x+y)%P; a[i+j+k]=(x-y+P)%P;
}
}
}
}
int poly_mul(int* A,int* B,int* C,int m){
//puts("poly_mul");
//cout<<m<<endl;
int p=0,n=1;
while(n<=m) n<<=1,p++;
//for(int i=0;i<n;i++) cout<<A[i]<<" "<<B[i]<<endl;
for(int i=0;i<n;i++) rv[i]=(rv[i>>1]>>1)|((i&1)<<(p-1));
dft(A,n,0); dft(B,n,0);
for(int i=0;i<n;i++) C[i]=1ll*A[i]*B[i]%P;
dft(C,n,1); int iv=fpw(n,P-2);
for(int i=0;i<n;i++) C[i]=1ll*C[i]*iv%P;
//for(int i=0;i<n;i++) cout<<C[i]<<endl;
return n;
}
struct Cmb{
int fc[N],iv[N];
void init(int n){
fc[0]=1;
for(int i=1;i<=n;i++) fc[i]=1ll*fc[i-1]*i%P;
iv[n]=fpw(fc[n],P-2);
for(int i=n-1;~i;i--) iv[i]=1ll*iv[i+1]*(i+1)%P;
}
int C(int n,int m){
return 1ll*fc[n]*iv[m]%P*iv[n-m]%P;
}
}cmb;
int n,p[N],m,a[N],s[N],f[N];
int A[N],B[N],C[N];
void work(int l,int r){
int mid=(l+r)>>1;
if(l==r){
for(int i=s[l-1]+1;i<=s[l];i++) f[i]=cmb.C(a[l],i-s[l-1]);
return;
}
work(l,mid); work(mid+1,r);
//cout<<"l="<<l<<" r="<<r<<endl;
for(int i=s[l-1]+1;i<=s[mid];i++) A[i-s[l-1]]=f[i];
for(int i=s[mid]+1;i<=s[r];i++) B[i-s[mid]]=f[i];
int t=poly_mul(A,B,C,s[r]-s[l-1]);
for(int i=s[l-1]+1;i<=s[r];i++) f[i]=C[i-s[l-1]];//cout<<i<<" "<<f[i]<<endl;
for(int i=0;i<t;i++) A[i]=B[i]=C[i]=0;
}
int main()
{
//srand(time(0));
//freopen("1.in","r",stdin);
//freopen("1.out","w",stdout);
//int T; cin>>T; while(T--) work();
cin>>n;
cmb.init(n);
for(int i=1;i<=n;i++) scanf("%d",&p[i]);
for(int i=1;i<=n;i++) if(p[i]){
m++; a[m]=1; //cout<<i<<" "<<m<<endl;
int j=p[i]; p[i]=0;
while(j!=i){
a[m]++;
int t=j;
//cout<<"j="<<j<<endl;
j=p[j];
p[t]=0;
}
}
for(int i=1;i<=m;i++) s[i]=s[i-1]+a[i];
//return 0;
work(1,m);
int ans=0;
for(int i=n,p=1;i;i--,p=-p) (ans+=(P+1ll*f[i]*cmb.fc[i]%P*p)%P)%=P;
cout<<ans<<endl;
return 0;
}