BZOJ3513[MUTC2013]idiots——FFT+生成函数
题目描述
给定n个长度分别为a_i的木棒,问随机选择3个木棒能够拼成三角形的概率。
输入
第一行T(T<=100),表示数据组数。
接下来若干行描述T组数据,每组数据第一行是n,接下来一行有n个数表示a_i。
3≤N≤10^5,1≤a_i≤10^5
输出
T行,每行一个整数,四舍五入保留7位小数。
样例输入
2
4
1 3 3 4
4
2 3 3 4
4
1 3 3 4
4
2 3 3 4
样例输出
0.5000000
1.0000000
1.0000000
提示
T<=20
N<=100000
首先开一个桶就可以得到长度分别为[1,100000]的木棒个数,只要将桶自己与自己卷积FFT一下就能得到两个木棒组成的任意长度的方案数(注意去重)。三个木棒不合法的情况当且仅当两个木棒之和小于等于第三个木棒,对桶求一个后缀和(或对方案数求一个前缀和)即可。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<bitset> #include<vector> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; const double pi=acos(-1.0); int n,T,x; ll t[400010]; struct miku { double x,y; miku(double X=0,double Y=0){x=X,y=Y;} }f[400010]; miku operator + (miku a,miku b){return miku(a.x+b.x,a.y+b.y);} miku operator - (miku a,miku b){return miku(a.x-b.x,a.y-b.y);} miku operator * (miku a,miku b){return miku(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);} int l,r[400010]; int a[100010]; int mask; inline void DFT(miku *A) { for(int i=0;i<mask;i++) { if(i<r[i]) { swap(A[i],A[r[i]]); } } for(int mid=1;mid<mask;mid<<=1) { miku id(cos(pi/mid),sin(pi/mid)); for(int i=mid<<1,j=0;j<mask;j+=i) { miku w(1,0); for(int k=0;k<mid;k++,w=w*id) { miku x=A[j+k],y=w*A[j+k+mid]; A[j+k]=x+y; A[j+k+mid]=x-y; } } } } inline void IDFT(miku *A) { for(int i=0;i<mask;i++) { if(i<r[i]) { swap(A[i],A[r[i]]); } } for(int mid=1;mid<mask;mid<<=1) { miku id(cos(pi/mid),-1.0*sin(pi/mid)); for(int i=mid<<1,j=0;j<mask;j+=i) { miku w(1,0); for(int k=0;k<mid;k++,w=w*id) { miku x=A[j+k],y=w*A[j+k+mid]; A[j+k]=x+y; A[j+k+mid]=x-y; } } } } int main() { scanf("%d",&T); mask=1; l=0; while(mask<=200000) { mask<<=1; l++; } for(int i=0;i<mask;i++) { r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); } while(T--) { scanf("%d",&n); memset(t,0,sizeof(t)); int mx=0; for(int i=0;i<mask;i++) { f[i]=0; } for(int i=1;i<=n;i++) { scanf("%d",&x); a[i]=x; f[x].x++; mx=max(mx,x); } DFT(f); for(int i=0;i<mask;i++) { f[i]=f[i]*f[i]; } IDFT(f); for(int i=0;i<mask;i++) { f[i].x/=mask; } for(int i=1;i<=n;i++) { f[a[i]<<1].x--; } for(int i=1;i<=mx;i++) { t[i]=t[i-1]+(ll)(f[i].x/2+0.1); } ll ans=0; for(int i=1;i<=n;i++) { ans+=t[a[i]]; } printf("%.7f\n",1-(1.0*ans/(1.0*n*(n-1)/2*(n-2)/3))); } }