【BZOJ3771】Triple(生成函数,多项式运算)
【BZOJ3771】Triple(生成函数,多项式运算)
题面
有\(n\)个价值\(w\)不同的物品
可以任意选择\(1,2,3\)个组合在一起
输出能够组成的所有价值以及方案数。
\(n,w<=40000\)
题解
对于每一个出现的价值,就在对应的位置上\(+1\)
于是我们就有了一个生成函数\(A(x)\),代表着出现了一次的价值。
设\(B(x),C(x)\)分别代表着两个物品组成的价值和三个物品组成的价值,我们不难得到以下式子。
\[B(x)=A(x)*A(x)-D(x),C(x)=A(x)*A(x)*A(x)-E(x)
\]
而式子是怎么来的,请回想背包是怎么做的。
其中\(D(x),E(x)\)是算重复的部分,容斥计算。
所以多项式乘法+容斥即可。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define ll long long
#define RG register
#define MAX 155555
const double Pi=acos(-1);
inline int read()
{
RG int x=0,t=1;RG char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
struct Complex{double a,b;}W[MAX],a[MAX],b[MAX],c[MAX],ans[MAX];
Complex operator+(Complex a,Complex b){return (Complex){a.a+b.a,a.b+b.b};}
Complex operator-(Complex a,Complex b){return (Complex){a.a-b.a,a.b-b.b};}
Complex operator*(Complex a,Complex b){return (Complex){a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a};}
Complex operator*(Complex a,double b){return (Complex){a.a*b,a.b*b};}
int N,r[MAX],mx,n,m,l;
void FFT(Complex *P,int opt)
{
for(int i=0;i<N;++i)if(i<r[i])swap(P[i],P[r[i]]);
for(int i=1;i<N;i<<=1)
for(int p=i<<1,j=0;j<N;j+=p)
for(int k=0;k<i;++k)
{
Complex w=(Complex){W[N/i*k].a,W[N/i*k].b*opt};
Complex X=P[j+k],Y=w*P[i+j+k];
P[j+k]=X+Y;P[i+j+k]=X-Y;
}
if(opt==-1)for(int i=0;i<N;++i)P[i].a/=1.0*N;
}
int main()
{
n=read();
for(int i=1,x;i<=n;++i)
{
x=read();
a[x].a+=1;b[x+x].a+=1;c[x+x+x].a+=1;
mx=max(mx,x);
}
m=mx*3;
for(N=1;N<=m;N<<=1)++l;
for(int i=0;i<N;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
for(int i=1;i<N;i<<=1)
for(int k=0;k<i;++k)W[N/i*k]=(Complex){cos(1.0*k*Pi/i),sin(1.0*k*Pi/i)};
FFT(a,1);FFT(b,1);FFT(c,1);
for(int i=0;i<N;++i)
{
ans[i]=ans[i]+(a[i]*a[i]*a[i]-b[i]*a[i]*3+c[i]*2)*(1.0/6.0);
ans[i]=ans[i]+(a[i]*a[i]-b[i])*0.5;
ans[i]=ans[i]+a[i];
}
FFT(ans,-1);
for(int i=0;i<N;++i)
{
int x=(int)(ans[i].a+0.5);
if(x)printf("%d %d\n",i,x);
}
return 0;
}