【HDU4609】3-idiots-FFT+生成函数
测试地址:3-idiots
题目大意:有条线段,问从里面随机取条线段,能组成三角形的概率。
做法:本题需要用到FFT+生成函数。
首先,求概率就是用合法方案数除以总方案数,这里总方案数显然是,因此我们只需要求合法方案数。最暴力的思路就是枚举,但是这样显然无法通过本题,需要考虑优化。
设选出的三条线段长为,我们发现若令,那么一定有和,所以只需要保证即可,这是三条线段能组成三角形的充要条件。
因此我们想到枚举的长度,然后累加能使得且的线段的对数到答案中。为了方便,考虑分类讨论:
1.当三角形中仅含有一条长度为的边时,令为长度小于的线段数,为长度和小于等于的无序对数,为长度为的线段数,可知这种情况的方案有种。
2.当三角形中含有两条长度为的边时,延续上述定义,则可知这种情况的方案有种。
3.当三角形中含有三条长度为的边时,延续上述定义,则可知这种情况的方案有种。
那么如果我们可以用较好的方法预处理出和,我们就可以求出方案数。显然用前缀和就可以了,问题是,这时候就要用生成函数的思想来解决这个问题。
令多项式的项的系数为长度为的线段数,那么的项的系数就为长度和为的线段对数。这里我们还要去重,因为这里包含了同一条线段取两次的情况,因此我们要减去一个。去掉这种情况后,因为我们要求的是无序对数,而这里剩下的是有序对数,因此要再除以。那么就是我们要求的生成函数了,很显然求可以用FFT优化到,而其他的多项式运算都是的,所以这一步骤的时间复杂度为。我们求出这个序列后,对其求前缀和,就是前面的了。
最后本题的时间复杂度为,成功地解决了这道题。
以下是本人代码:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
int T,n,r[400010];
const double pi=acos(-1.0);
ll v1[400010],v2[400010],ans,sum1,sum2;
struct Complex
{
double x,y;
}a[400010],b[400010];
Complex operator + (Complex a,Complex b) {Complex s;s.x=a.x+b.x,s.y=a.y+b.y;return s;}
Complex operator - (Complex a,Complex b) {Complex s;s.x=a.x-b.x,s.y=a.y-b.y;return s;}
Complex operator * (Complex a,Complex b) {Complex s;s.x=a.x*b.x-a.y*b.y,s.y=a.x*b.y+a.y*b.x;return s;}
void FFT(Complex *a,int type)
{
for(int i=0;i<n;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1)
{
Complex W={cos(pi/mid),type*sin(pi/mid)};
for(int l=0;l<n;l+=(mid<<1))
{
Complex w={1.0,0.0};
for(int k=0;k<mid;k++,w=w*W)
{
Complex x=a[l+k],y=w*a[l+mid+k];
a[l+k]=x+y;
a[l+mid+k]=x-y;
}
}
}
if (type==-1)
{
for(int i=0;i<n;i++)
a[i].x/=n;
}
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
memset(a,0,sizeof(a));
int mx=0;
ll saved=(ll)n;
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
a[x].x+=1.0;
mx=max(mx,x);
}
n=mx+1;
n=(n<<1)-1;
int bit=0,x=1;
while(x<n) bit++,x<<=1;
n=x;
for(int i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<n;i++) b[i]=a[i];
FFT(a,1);
for(int i=0;i<n;i++) a[i]=a[i]*a[i];
FFT(a,-1);
for(int i=0;i<n;i++)
{
v1[i]=(ll)(b[i].x+0.5);
v2[i]=(ll)(a[i].x+0.5);
if (!(i%2)) v2[i]-=v1[i>>1];
v2[i]>>=1;
}
ans=sum1=sum2=0;
for(int i=1;i<=mx;i++)
{
sum1+=v1[i-1];
sum2+=v2[i];
ans+=(((sum1*(sum1-1))>>1)-sum2)*v1[i];
ans+=((v1[i]*(v1[i]-1))>>1)*sum1;
ans+=v1[i]*(v1[i]-1)*(v1[i]-2)/6;
}
printf("%.7lf\n",(double)ans/(double)((saved*(saved-1)*(saved-2))/6));
}
return 0;
}