【BZOJ3771】Triple 生成函数 FFT 容斥原理

题目大意

  有\(n\)把斧头,不同斧头的价值都不同且都是\([0,m]\)的整数。你可以选\(1\)~\(3\)把斧头,总价值为这三把斧头的价值之和。请你对于每种可能的总价值,求出有多少种选择方案。

  选\(2\)把斧头时,\((a,b)\)\((b,a)\)视为一种方案。选\(3\)把斧头时,\((a,b,c),(b,c,a),(c,a,b),(c,b,a),(b,a,c),(a,c,b)\)视为一种方案。

  \(m\leq 40000\).

题解

  考虑生成函数。

​ 设\(X\)是每种斧头取一个的生成函数,\(Y\)是每种斧头取两个的生成函数,\(Z\)是每种斧头取三个的生成函数,\(A\)是只取一个斧头的答案的生成函数,\(B\)是取两个斧头的答案的生成函数,\(C\)是取三个斧头的答案的生成函数。

  容斥一下。

\[A=X\\ B=\frac{X^2-Y}2\\ C=\frac{X^3-3XY+2Z}6 \]

  我来讲解一下第三个式子

  下文中第一项代表\(X^3\),第二项代表\(XY\),第三项代表\(Z\)

  对于方案\((a,a,a)\),会在第一项中出现\(1\)次,在第二项中出现\(3\)次,在第三项中出现\(1\)次。

  对于方案\((a,a,b)\),会在第一项中出现\(3\)次,在第二项中出现\(1\)次,在第三项中出现\(0\)次。

  对于方案\((a,b,c)\),会在第一项中出现\(6\)次,在第二项中出现\(0\)次,在第三项中出现\(0\)次。

  这样\((a,a,a)\)\((a,a,b)\)的贡献就会全部被消掉,所以答案是对的。

  然后用FFT加速。

  时间复杂度:\(O(m\log m)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
double pi=acos(-1);
int n=131072;
struct cp
{
	double x,y;
	cp(double a=0,double b=0)
	{
		x=a;
		y=b;
	}
};
cp operator +(cp a,cp b){return cp(a.x+b.x,a.y+b.y);}
cp operator -(cp a,cp b){return cp(a.x-b.x,a.y-b.y);}
cp operator *(cp a,cp b){return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
cp operator *(cp a,double b){return cp(a.x*b,a.y*b);}
cp operator /(cp a,double b){return cp(a.x/b,a.y/b);}
cp a[200010];
cp b[200010];
cp c[200010];
cp w1[200010];
cp w2[200010];
int rev[200010];
void fft(cp *a,int t)
{
	int i,j,k;
	cp w,wn,u,v;
	for(i=0;i<n;i++)
		if(rev[i]<i)
			swap(a[i],a[rev[i]]);
	for(i=2;i<=n;i<<=1)
	{
		wn=(t==1?w1[i]:w2[i]);
		for(j=0;j<n;j+=i)
		{
			w=cp(1,0);
			for(k=j;k<j+i/2;k++)
			{
				u=a[k];
				v=a[k+i/2]*w;
				a[k]=u+v;
				a[k+i/2]=u-v;
				w=w*wn;
			}
		}
	}
	if(t==-1)
		for(i=0;i<n;i++)
			a[i]=a[i]/n;
}
int main()
{
	int i;
	for(i=2;i<=n;i<<=1)
	{
		w1[i]=cp(cos(2*pi/i),sin(2*pi/i));
		w2[i]=cp(cos(2*pi/i),-sin(2*pi/i));
	}
	rev[0]=0;
	for(i=1;i<n;i++)
		rev[i]=(rev[i>>1]>>1)|(i&1?n/2:0);
	int m;
	scanf("%d",&m);
	int x;
	for(i=1;i<=m;i++)
	{
		scanf("%d",&x);
		a[x].x+=1;
		b[2*x].x+=1;
		c[3*x].x+=1;
	}
	fft(a,1);
	fft(b,1);
	fft(c,1);
	for(i=0;i<n;i++)
		a[i]=a[i]+(a[i]*a[i]-b[i])/2+(a[i]*a[i]*a[i]-a[i]*b[i]*3+c[i]*2)/6;
	fft(a,-1);
	for(i=0;i<n;i++)
	{
		ll s=round(a[i].x);
		if(s>0)
			printf("%d %lld\n",i,s);
	}
	return 0;
}
posted @ 2018-03-05 20:22  ywwyww  阅读(416)  评论(0编辑  收藏  举报