bzoj3771:Triple
传送门
生成函数
设生成函数\(f(x)\),可以将系数定为选的方案数,指数定为代价
那么
\[f(x)=\sum_{i=1}^{n}x^{w_i}
\]
然后答案就是\(f^3(x)+f^2(x)+f(x)\)然后去掉重复的情况
然后我们设
\[A(x)=\sum_{i=1}^{n}x^{2w_i}\\
B(x)=\sum_{i=1}^{n}x^{3w_i}
\]
重复的情况是哪些呢,由于\((a,b,c)\)和\((a,c,b),(b,c,a),(b,a,c),(c,a,b),(c,b,a)\)全部都看作一种情况
所以选三个的情况也就是\(\frac{f^3(x)-3A(x)f(x)+2B(x)}{6}\)
选两个的情况是\(\frac{f^2(x)-A(x)}{2}\)
选一个的情况显然就是\(f(x)\)
因为有多项式乘法,直接做\(O(n^2)\)会TLE,然而NTT需要注意爆模数的问题,所以还是用FFT加速
代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
void read(int &x) {
char ch; bool ok;
for(ok=0,ch=getchar(); !isdigit(ch); ch=getchar()) if(ch=='-') ok=1;
for(x=0; isdigit(ch); x=x*10+ch-'0',ch=getchar()); if(ok) x=-x;
}
#define rg register
const int maxn=4e5+10;
const double pi=acos(-1);
int n,m,r[maxn],len;
double ans[maxn];
struct complex{double x,y;}a[maxn],b[maxn],c[maxn],d[maxn];
complex operator-(complex a,complex b){return (complex){a.x-b.x,a.y-b.y};}
complex operator+(complex a,complex b){return (complex){a.x+b.x,a.y+b.y};}
complex operator*(complex a,complex b){return (complex){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
void fft(complex *a,int f){
for(rg int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(rg int i=1;i<n;i<<=1){
complex wn=(complex){cos(pi/i),f*sin(pi/i)};
for(rg int j=0;j<n;j+=(i<<1)){
complex w=(complex){1,0};
for(rg int k=0;k<i;k++){
complex x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y,a[j+k+i]=x-y,w=w*wn;
}
}
}
if(f==-1)for(rg int i=0;i<n;i++)a[i].x/=n;
}
int main()
{
read(n);complex now,i3={3,0},i2={2,0};m=n;
for(rg int i=1,x;i<=n;i++)read(x),a[x].x++,b[x*2].x++,c[x*3].x++,m=max(x*3,m);
for(n=1;n<=m;n<<=1)len++;
for(rg int i=0;i<n;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
fft(a,1),fft(b,1),fft(c,1);
for(rg int i=0;i<n;i++)d[i]=a[i]*a[i]*a[i];
fft(d,-1);
for(rg int i=0;i<n;i++)ans[i]=ans[i]+d[i].x/6.0;
for(rg int i=0;i<n;i++)d[i]=i3*a[i]*b[i];
fft(d,-1);
for(rg int i=0;i<n;i++)ans[i]=ans[i]-d[i].x/6.0;
for(rg int i=0;i<n;i++)d[i]=i2*c[i];
fft(d,-1);
for(rg int i=0;i<n;i++)ans[i]=ans[i]+d[i].x/6.0;
for(rg int i=0;i<n;i++)d[i]=a[i]*a[i];
fft(d,-1);
for(rg int i=0;i<n;i++)ans[i]=ans[i]+d[i].x/2.0;
fft(b,-1),fft(a,-1);
for(rg int i=0;i<n;i++)ans[i]=ans[i]-b[i].x/2.0+a[i].x;
for(rg int i=0;i<n;i++)if(ans[i]+0.1>1)printf("%d %d\n",i,(int)(ans[i]+0.1));
}