【BZOJ3771】Triple 生成函数 FFT 容斥原理
题目大意
有\(n\)把斧头,不同斧头的价值都不同且都是\([0,m]\)的整数。你可以选\(1\)~\(3\)把斧头,总价值为这三把斧头的价值之和。请你对于每种可能的总价值,求出有多少种选择方案。
选\(2\)把斧头时,\((a,b)\)和\((b,a)\)视为一种方案。选\(3\)把斧头时,\((a,b,c),(b,c,a),(c,a,b),(c,b,a),(b,a,c),(a,c,b)\)视为一种方案。
\(m\leq 40000\).
题解
考虑生成函数。
设\(X\)是每种斧头取一个的生成函数,\(Y\)是每种斧头取两个的生成函数,\(Z\)是每种斧头取三个的生成函数,\(A\)是只取一个斧头的答案的生成函数,\(B\)是取两个斧头的答案的生成函数,\(C\)是取三个斧头的答案的生成函数。
容斥一下。
\[A=X\\
B=\frac{X^2-Y}2\\
C=\frac{X^3-3XY+2Z}6
\]
我来讲解一下第三个式子
下文中第一项代表\(X^3\),第二项代表\(XY\),第三项代表\(Z\)
对于方案\((a,a,a)\),会在第一项中出现\(1\)次,在第二项中出现\(3\)次,在第三项中出现\(1\)次。
对于方案\((a,a,b)\),会在第一项中出现\(3\)次,在第二项中出现\(1\)次,在第三项中出现\(0\)次。
对于方案\((a,b,c)\),会在第一项中出现\(6\)次,在第二项中出现\(0\)次,在第三项中出现\(0\)次。
这样\((a,a,a)\)和\((a,a,b)\)的贡献就会全部被消掉,所以答案是对的。
然后用FFT加速。
时间复杂度:\(O(m\log m)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
double pi=acos(-1);
int n=131072;
struct cp
{
double x,y;
cp(double a=0,double b=0)
{
x=a;
y=b;
}
};
cp operator +(cp a,cp b){return cp(a.x+b.x,a.y+b.y);}
cp operator -(cp a,cp b){return cp(a.x-b.x,a.y-b.y);}
cp operator *(cp a,cp b){return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
cp operator *(cp a,double b){return cp(a.x*b,a.y*b);}
cp operator /(cp a,double b){return cp(a.x/b,a.y/b);}
cp a[200010];
cp b[200010];
cp c[200010];
cp w1[200010];
cp w2[200010];
int rev[200010];
void fft(cp *a,int t)
{
int i,j,k;
cp w,wn,u,v;
for(i=0;i<n;i++)
if(rev[i]<i)
swap(a[i],a[rev[i]]);
for(i=2;i<=n;i<<=1)
{
wn=(t==1?w1[i]:w2[i]);
for(j=0;j<n;j+=i)
{
w=cp(1,0);
for(k=j;k<j+i/2;k++)
{
u=a[k];
v=a[k+i/2]*w;
a[k]=u+v;
a[k+i/2]=u-v;
w=w*wn;
}
}
}
if(t==-1)
for(i=0;i<n;i++)
a[i]=a[i]/n;
}
int main()
{
int i;
for(i=2;i<=n;i<<=1)
{
w1[i]=cp(cos(2*pi/i),sin(2*pi/i));
w2[i]=cp(cos(2*pi/i),-sin(2*pi/i));
}
rev[0]=0;
for(i=1;i<n;i++)
rev[i]=(rev[i>>1]>>1)|(i&1?n/2:0);
int m;
scanf("%d",&m);
int x;
for(i=1;i<=m;i++)
{
scanf("%d",&x);
a[x].x+=1;
b[2*x].x+=1;
c[3*x].x+=1;
}
fft(a,1);
fft(b,1);
fft(c,1);
for(i=0;i<n;i++)
a[i]=a[i]+(a[i]*a[i]-b[i])/2+(a[i]*a[i]*a[i]-a[i]*b[i]*3+c[i]*2)/6;
fft(a,-1);
for(i=0;i<n;i++)
{
ll s=round(a[i].x);
if(s>0)
printf("%d %lld\n",i,s);
}
return 0;
}