【XSY2535】整数 NTT

题目描述

  问有多少个满足以下要求的\(k\)进制数:

   1.每个数字出现的次数不超过\(n\)

   2.\(0\)没有出现过

   3.若\(g_{i,j}=0\),则\(i\)不能出现恰好\(j\)次。

  两次询问之间会修改\(g\)中一个位置的值(\(0\)\(1\)\(1\)\(0\))。

  输出所有询问的答案的和。

  \(3\leq k\leq 10,n\leq 14000,m\leq 20\)

  模数\(p=786433\),原根\(g=10\)

题解

  假设第\(i\)个数用了\(c_i\)个,答案为

\[\frac{(\sum c_i)!}{\prod c_i!} \]

  构造多项式

\[f_i(x)=\sum_{j=0}^n\frac{g_{i,j}}{j!}x^j \]

  把这\(k-1\)个多项式乘起来后,第\(i\)项乘以\(i!\)的和就是答案。

  因为求的是答案的和,所以可以在点值表达的形式下累加答案,最后IDFT回来。

​ 怎么求没修改前的答案?

  直接DFT

  怎么求修改的贡献?

  观察NTT的公式:

\[y_k=\sum_{j=0}^{n-1}a_j{(g^\frac{p-1}{n})}^{kj} \]

  对于一个单点修改操作,可以看成在某个多项式上加上一个只有一项系数不为\(0\)的多项式。这个多项式DFT后就是一个等比数列,直接加到原多项式上就完了。

  对于所有多项式的乘积:如果所有多项式的每一项都非\(0\),就直接乘以逆元。现在有\(0\),就记录每一项\(0\)的个数和非\(0\)的乘积。

  时间复杂度:\(O(nk^2\log (nk)+mnk)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
const ll p=786433;
const ll g=10;
ll inv[1000010];
ll fac[1000010];
ll ifac[1000010];
ll pg[1000010];
ll fp(ll a,ll b)
{
	ll s=1;
	while(b)
	{
		if(b&1)
			s=s*a%p;
		a=a*a%p;
		b>>=1;
	}
	return s;
}
namespace ntt
{
	int n;
	ll w1[150000];
	ll w2[150000];
	int rev[150000];
	void init(int x)
	{
		n=1;
		while(n<=x)
			n<<=1;
		int i;
		for(i=1;i<=n;i<<=1)
		{
			w1[i]=fp(g,(p-1)/i);
			w2[i]=inv[w1[i]];
		}
		rev[0]=0;
		for(i=1;i<n;i++)
			rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
	}
	void ntt(ll *a,int t)
	{
		int i,j,k;
		ll u,v,w,wn;
		for(i=0;i<=n-1;i++)
			if(rev[i]<i)
				swap(a[i],a[rev[i]]);
		for(i=2;i<=n;i<<=1)
		{
			wn=(t==1?w1[i]:w2[i]);
			for(j=0;j<n;j+=i)
			{
				w=1;
				for(k=j;k<j+i/2;k++)
				{
					u=a[k];
					v=a[k+i/2]*w%p;
					a[k]=(u+v)%p;
					a[k+i/2]=(u-v)%p;
					w=w*wn%p;
				}
			}
		}
		if(t==-1)
			for(i=0;i<n;i++)
				a[i]=a[i]*inv[n]%p;
	}
}
int &nn=ntt::n;
char s[14010];
int c[12][14010];
void init()
{
	int i;
	inv[0]=inv[1]=1;
	for(i=2;i<=p-1;i++)
		inv[i]=(-(p/i)*inv[p%i]%p+p)%p;
	fac[0]=ifac[0]=1;
	for(i=1;i<=p-1;i++)
	{
		fac[i]=fac[i-1]*i%p;
		ifac[i]=ifac[i-1]*inv[i]%p;
	}
}
ll ans;
ll d[12][150000];
ll f[150000];
ll f2[150000];
ll a[150000];
int k,n,m;
int main()
{
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
	init();
	int i,j;
	scanf("%d%d%d",&k,&n,&m);
	ntt::init((k-1)*n);
	pg[0]=1;
	for(i=1;i<=p-2;i++)
		pg[i]=pg[i-1]*g%p;
	for(i=1;i<=k-1;i++)
	{
		scanf("%s",s);
		for(j=0;j<=n;j++)
			c[i][j]=s[j]-'0';
	}
	ans=0;
	for(i=0;i<nn;i++)
		f[i]=1;
	for(i=1;i<=k-1;i++)
	{
		ll *u=d[i];
		for(j=0;j<nn;j++)
			u[j]=0;
		for(j=0;j<=n;j++)
			u[j]=c[i][j]*ifac[j]%p;
		ntt::ntt(u,1);
		for(j=0;j<nn;j++)
		{
			if(u[j]<0)
				u[j]+=p;
			if(u[j])
				f[j]=f[j]*u[j]%p;
			else
				f2[j]++;
		}
	}
	for(i=0;i<nn;i++)
		if(!f2[i])
			a[i]=(a[i]+f[i])%p;
	int x,y;
	int t;
	for(t=1;t<=m;t++)
	{
		scanf("%d%d",&x,&y);
		c[x][y]^=1;
		for(i=0;i<nn;i++)
			if(d[x][i])
				f[i]=f[i]*inv[d[x][i]]%p;
			else
				f2[i]--;
		ll s1=pg[((p-1)/nn*y)%(p-1)],s2=ifac[y];
		if(!c[x][y])
			s2=p-s2;
		for(i=0;i<nn;i++)
		{
			d[x][i]+=s2;
			if(d[x][i]>=p)
				d[x][i]-=p;
			s2=s2*s1%p;
		}
		for(i=0;i<nn;i++)
		{
			if(d[x][i])
				f[i]=f[i]*d[x][i]%p;
			else
				f2[i]++;
			if(!f2[i])
				a[i]=(a[i]+f[i])%p;
		}
	}
	ntt::ntt(a,-1);
	for(i=1;i<nn;i++)
		ans=(ans+a[i]*fac[i])%p;
	ans=(ans%p+p)%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-03-06 10:55  ywwyww  阅读(185)  评论(0编辑  收藏  举报