【XSY2744】信仰圣光 分治FFT 多项式exp 容斥原理

题目描述

  有一个\(n\)个元素的置换,你要选择\(k\)个元素,问有多少种方案满足:对于每个轮换,你都选择了其中的一个元素。

  对\(998244353\)取模。

  \(k\leq n\leq 152501\)

题解

吐槽

  为什么一道FFT题要把\(n\)设为\(150000\)

解法一

  先把轮换拆出来。

  直接DP。

  设\(f_{i,j}\)为前\(i\)个轮换选择了\(j\)个元素,且每个轮换都选择了至少一个元素的方案数。

\[f_{i,j}=\sum_{k=1}^{a_i}f_{i-1,j-k}\binom{a_i}{k} \]

  时间复杂度为\(O(n^2)\),因为枚举的是第\(i\)组和前\(i-1\)组的配对,而任意两个元素之间最多被配对一次。

  可以分治FFT做到\(O(n\log^2 n)\)

解法二

  考虑容斥。

  设\(m\)为轮换个数。

  枚举有哪些轮换\(S\)中可能有被选中的元素,容斥系数就是\({(-1)}^{m-|S|}\)\(sum\)为这些轮换的大小总和):

  或者枚举哪些轮换\(S\)中没有被选中的元素,容斥系数就是\({(-1)}^{|S|}\)

\[\begin{align} s&=\sum_{S}{(-1)}^{m-|S|}\binom{sum}{k}\\ s&=\sum_{S}{(-1)}^{|S|}\binom{n-sum}{k}\\ \end{align} \]

  现在我们要对于每一个\(i\),计算\(f_i=\sum_{S,sum=i}{(-1)}^{|S|}\)

  构造生成函数\(A_i(x)=1-x^{a_i}\),那么\(F(x)=\prod_{i=1}^mA_i(x)\)

  直接做还是\(O(n\log^2n)\)的。我们需要一些优化。

\[\begin{align} F(x)&=\prod_{i=1}^m1-x^{a_i}\\ \ln(F(x))&=\sum_{i=1}^n\ln(1-x^{a_i})\\ \ln(F(x))&=\sum_{i=1}^n\sum_{j=a_i}-\frac{x^{ja_i}}{j} \end{align} \]

  那么可以在\(O(n\log n)\)内算出\(\ln(F(x))\),然后\(\exp\)一下。

  时间复杂度:\(O(n\log n)\)

  由于常数过大,所以要用下面那条式子(因为只用计算到\(x^{n-k}\))。

解法一

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	s=c-'0';
	while((c=getchar())>='0'&&c<='9')
		s=s*10+c-'0';
	return s;
}
const int p=998244353;
const int g=3;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll inv[200010];
ll fac[200010];
ll ifac[200010];
int a[200010];
int n,m,k;
int c[200010];
int b[200010];
ll getc(int x,int y)
{
	return fac[x]*ifac[y]%p*ifac[x-y]%p;
}
ll *f[500010];
int len[500010];
int cnt;
int a1[600010];
int a2[600010];
int rev[600010];
void ntt(int *a,int n,int t)
{
	for(int i=1;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
		if(i>rev[i])
			swap(a[i],a[rev[i]]);
	}
	for(int i=2;i<=n;i<<=1)
	{
		int wn=fp(g,(p-1)/i*(t==1?1:i-1));
		for(int j=0;j<n;j+=i)
		{
			int w=1;
			for(int k=j;k<j+i/2;k++)
			{
				int u=a[k];
				int v=(ll)a[k+i/2]*w%p;
				a[k]=(u+v)%p;
				a[k+i/2]=(u-v)%p;
				w=(ll)w*wn%p;
			}
		}
	}
	if(t==-1)
	{
		int inv=fp(n,p-2);
		for(int i=0;i<n;i++)
			a[i]=(ll)a[i]*inv%p;
	}
}
void solve(int &now,int l,int r)
{
	now=++cnt;
	if(l==r)
	{
		len[now]=min(a[l],k);
		f[now]=new ll[len[now]+1];
		f[now][0]=0;
		for(int i=1;i<=len[now];i++)
			f[now][i]=ifac[i]*ifac[a[l]-i]%p;
		return;
	}
	int ls,rs,mid=(l+r)>>1;
	solve(ls,l,mid);
	solve(rs,mid+1,r);
	len[now]=min(len[ls]+len[rs],k);
	f[now]=new ll[len[now]+1];
	int v=1;
	while(v<=len[ls]+len[rs])
		v<<=1;
	for(int i=0;i<v;i++)
		a1[i]=(i<=len[ls]?f[ls][i]:0);
	for(int i=0;i<v;i++)
		a2[i]=(i<=len[rs]?f[rs][i]:0);
	ntt(a1,v,1);
	ntt(a2,v,1);
	for(int i=0;i<v;i++)
		a1[i]=(ll)a1[i]*a2[i]%p;
	ntt(a1,v,-1);
	for(int i=0;i<=len[now];i++)
		f[now][i]=a1[i];
	delete [] f[ls];
	delete [] f[rs];
}
void solve()
{
//	scanf("%d%d",&n,&k);
	n=rd();
	k=rd();
	for(int i=1;i<=n;i++)
		c[i]=rd();
//		scanf("%d",&c[i]);
	if(k==n)
	{
		printf("1\n");
		return;
	}
	m=0;
	cnt=0;
	memset(b,0,sizeof b);
	memset(a,0,sizeof a);
	for(int i=1;i<=n;i++)
		if(!b[i])
		{
			m++;
			for(int j=i;!b[j];j=c[j])
			{
				b[j]=1;
				a[m]++;
			}
		}
	if(k<m)
	{
		printf("0\n");
		return;
	}
	int rt;
	solve(rt,1,m);
	ll ans=f[rt][k];
	ans=ans*fp(getc(n,k),p-2)%p;
	for(int i=1;i<=m;i++)
		ans=ans*fac[a[i]]%p;
	ans=(ans+p)%p;
	printf("%lld\n",ans);
}
int main()
{
	open("a");
	inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
	for(int i=2;i<=200000;i++)
	{
		inv[i]=-p/i*inv[p%i]%p;
		fac[i]=fac[i-1]*i%p;
		ifac[i]=ifac[i-1]*inv[i]%p;
	}
	int t;
//	scanf("%d",&t);
	t=rd();
	while(t--)
		solve();
	return 0;
}

解法二

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	s=c-'0';
	while((c=getchar())>='0'&&c<='9')
		s=s*10+c-'0';
	return s;
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
const int p=998244353;
const int g=3;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll inv[300010];
ll fac[300010];
ll ifac[300010];
namespace ntt
{
	int rev[600000];
	void ntt(int *a,int n,int t)
	{
		for(int i=1;i<n;i++)
		{
			rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
			if(i>rev[i])
				swap(a[i],a[rev[i]]);
		}
		for(int i=2;i<=n;i<<=1)
		{
			int wn=fp(g,(p-1)/i*(t==1?1:i-1));
			for(int j=0;j<n;j+=i)
			{
				int w=1;
				for(int k=j;k<j+i/2;k++)
				{
					int u=a[k];
					int v=(ll)a[k+i/2]*w%p;
					a[k]=(u+v)%p;
					a[k+i/2]=(u-v)%p;
					w=(ll)w*wn%p;
				}
			}
		}
		if(t==-1)
		{
			int inv=fp(n,p-2);
			for(int i=0;i<n;i++)
				a[i]=(ll)a[i]*inv%p;
		}
	}
	void getinv(int *a,int *b,int n)
	{
		if(n==1)
		{
			b[0]=fp(a[0],p-2);
			return;
		}
		getinv(a,b,n>>1);
		static int a1[600000],a2[600000];
		for(int i=0;i<n;i++)
			a1[i]=a[i];
		for(int i=n;i<n<<1;i++)
			a1[i]=0;
		for(int i=0;i<n>>1;i++)
			a2[i]=b[i];
		for(int i=n>>1;i<n<<1;i++)
			a2[i]=0;
		ntt(a1,n<<1,1);
		ntt(a2,n<<1,1);
		for(int i=0;i<n<<1;i++)
			a1[i]=a2[i]*(2-(ll)a1[i]*a2[i]%p)%p;
		ntt(a1,n<<1,-1);
		for(int i=0;i<n;i++)
			b[i]=a1[i];
	}
	void getln(int *a,int *b,int n)
	{
		static int a1[600000],a2[600000];
		for(int i=1;i<n;i++)
			a1[i-1]=(ll)a[i]*i%p;
		a1[n-1]=0;
		getinv(a,a2,n);
		for(int i=n;i<n<<1;i++)
			a1[i]=a2[i]=0;
		ntt(a1,n<<1,1);
		ntt(a2,n<<1,1);
		for(int i=0;i<n<<1;i++)
			a1[i]=(ll)a1[i]*a2[i]%p;
		ntt(a1,n<<1,-1);
		for(int i=1;i<n;i++)
			b[i]=(ll)a1[i-1]*inv[i]%p;
		b[0]=0;
	}
	void getexp(int *a,int *b,int n)
	{
		if(n==1)
		{
			b[0]=1;
			return;
		}
		getexp(a,b,n>>1);
		static int a1[600000],a2[600000],a3[600000];
		for(int i=n>>1;i<n;i++)
			b[i]=0;
		getln(b,a3,n);
		for(int i=0;i<n>>1;i++)
		{
			a1[i]=b[i];
			a2[i]=(a[i+(n>>1)]-a3[i+(n>>1)])%p;
		}
		for(int i=n>>1;i<n;i++)
			a1[i]=a2[i]=0;
		ntt(a1,n,1);
		ntt(a2,n,1);
		for(int i=0;i<n;i++)
			a1[i]=(ll)a1[i]*a2[i]%p;
		ntt(a1,n,-1);
		for(int i=0;i<n>>1;i++)
			b[i+(n>>1)]=a1[i];
	}
}
int a[200010];
int n,m,k;
int c[200010];
int b[200010];
int cnt;
ll ans;
int d[300010];
int s[300010];
int f[300010];
ll getc(int x,int y)
{
	if(y>x||y<0)
		return 0;
	return fac[x]*ifac[y]%p*ifac[x-y]%p;
}
void dfs(int x,int y,int v)
{
	if(x>m)
	{
		ans=(ans+v*getc(y,k))%p;
		return;
	}
	dfs(x+1,y,v);
	dfs(x+1,y+a[x],-v);
}
void solve()
{
//	scanf("%d%d",&n,&k);
	n=rd();
	k=rd();
	for(int i=1;i<=n;i++)
		c[i]=rd();
//		scanf("%d",&c[i]);
	if(k==n)
	{
		printf("1\n");
		return;
	}
	m=0;
	cnt=0;
	memset(b,0,sizeof b);
	memset(a,0,sizeof a);
	for(int i=1;i<=n;i++)
		if(!b[i])
		{
			m++;
			for(int j=i;!b[j];j=c[j])
			{
				b[j]=1;
				a[m]++;
			}
		}
	if(k<m)
	{
		printf("0\n");
		return;
	}
	memset(d,0,sizeof d);
	memset(s,0,sizeof s);
	for(int i=1;i<=m;i++)
		d[a[i]]++;
	for(int i=1;i<=n;i++)
		if(d[i])
			for(int j=1;i*j<=n;j++)
				s[i*j]=(s[i*j]-inv[j]*d[i])%p;
	int l=1;
	while(l<=n-k)
		l<<=1;
	s[0]=1;
	ntt::getexp(s,f,l);
	ans=0;
	for(int i=0;i<=n-k;i++)
		ans=(ans+f[i]*getc(n-i,k))%p;
//		ans=(ans+f[i]*getc(i,k))%p;
	ans=ans*fp(getc(n,k),p-2)%p;
//	if(m&1)
//		ans=-ans;
	ans=(ans+p)%p;
	printf("%lld\n",ans);
}
int main()
{
	open("a");
	inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
	for(int i=2;i<=300000;i++)
	{
		inv[i]=-p/i*inv[p%i]%p;
		fac[i]=fac[i-1]*i%p;
		ifac[i]=ifac[i-1]*inv[i]%p;
	}
	int t;
//	scanf("%d",&t);
	t=rd();
	while(t--)
		solve();
	return 0;
}
posted @ 2018-03-13 19:57  ywwyww  阅读(377)  评论(0编辑  收藏  举报