[BZOJ 3771]Triple(FFT+容斥原理)
Description
我们讲一个悲伤的故事。
从前有一个贫穷的樵夫在河边砍柴。
这时候河里出现了一个水神,夺过了他的斧头,说:
“这把斧头,是不是你的?”
樵夫一看:“是啊是啊!”
水神把斧头扔在一边,又拿起一个东西问:
“这把斧头,是不是你的?”
樵夫看不清楚,但又怕真的是自己的斧头,只好又答:“是啊是啊!”
水神又把手上的东西扔在一边,拿起第三个东西问:
“这把斧头,是不是你的?”
樵夫还是看不清楚,但是他觉得再这样下去他就没法砍柴了。
于是他又一次答:“是啊是啊!真的是!”
水神看着他,哈哈大笑道:
“你看看你现在的样子,真是丑陋!”
之后就消失了。
樵夫觉得很坑爹,他今天不仅没有砍到柴,还丢了一把斧头给那个水神。
于是他准备回家换一把斧头。
回家之后他才发现真正坑爹的事情才刚开始。
水神拿着的的确是他的斧头。
但是不一定是他拿出去的那把,还有可能是水神不知道怎么偷偷从他家里拿走的。
换句话说,水神可能拿走了他的一把,两把或者三把斧头。
樵夫觉得今天真是倒霉透了,但不管怎么样日子还得过。
他想统计他的损失。
樵夫的每一把斧头都有一个价值,不同斧头的价值不同。总损失就是丢掉的斧头价值和。
他想对于每个可能的总损失,计算有几种可能的方案。
注意:如果水神拿走了两把斧头a和b,(a,b)和(b,a)视为一种方案。拿走三把斧头时,(a,b,c),(b,c,a),(c,a,b),(c,b,a),(b,a,c),(a,c,b)视为一种方案。
Solution
比较容易能看出来是FFT,但是去重不太好处理
可以用三个多项式a,b,c分别表示一把斧头构成的多项式、两把相同的斧头构成的多项式、三把相同的斧头构成的多项式
于是取一把斧头就是 a
取两把斧头为 (a2-b)/2
取三把斧头为 (a3-a*b*3+2*c)/6 (在减去a*b*3时其实减了三遍c,所以要再加上两个)
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #define MAXN 40000 #define pi acos(-1) using namespace std; int n; struct cp { double r,i; cp(double r=0,double i=0):r(r),i(i){} cp operator + (const cp& x) const {return cp(r+x.r,i+x.i);} cp operator - (const cp& x) const {return cp(r-x.r,i-x.i);} cp operator * (const cp& x) const {return cp(r*x.r-i*x.i,r*x.i+i*x.r);} }a[MAXN*4],b[MAXN*4],c[MAXN*4]; int read() { int x=0,f=1;char c=getchar(); while(c<'0'||c>'9'){if(x=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();} return x*f; } void brc(cp* x,int len) { int k=len/2; for(int i=1;i<len-1;i++) { if(i<k)swap(x[i],x[k]); int j=len/2; while(k>=j) { k-=j; j>>=1; } if(k<j)k+=j; } } void fft(cp* x,int len,int on) { brc(x,len); for(int h=2;h<=len;h<<=1) { cp wn=cp(cos(2*on*pi/h),sin(2*on*pi/h)); for(int i=0;i<len;i+=h) { cp w=cp(1,0); for(int j=i;j<i+h/2;j++) { cp u=x[j]; cp t=w*x[j+h/2]; x[j]=u+t; x[j+h/2]=u-t; w=w*wn; } } } if(on==-1)for(int i=0;i<len;i++)x[i].r/=len; } int main() { n=read();int maxn=0; for(int i=1;i<=n;i++) { int x=read(); maxn=max(x,maxn); a[x].r++,b[x*2].r++,c[x*3].r++; } int len=1; while(len<maxn*3)len<<=1; fft(a,len,1),fft(b,len,1),fft(c,len,1); for(int i=0;i<len;i++) a[i]=a[i]+(a[i]*a[i]-b[i])*cp(0.5,0)+(a[i]*a[i]*a[i]-a[i]*b[i]*cp(3,0)+c[i]*cp(2,0))*cp(1.0/6.0,0); fft(a,len,-1); for(int i=0;i<len;i++) { int x=(int)(a[i].r+0.1); if(x)printf("%d %d\n",i,x); } return 0; }