【HDU4609】3-idiots-FFT+生成函数

测试地址:3-idiots
题目大意:n条线段,问从里面随机取3条线段,能组成三角形的概率。
做法:本题需要用到FFT+生成函数。
首先,求概率就是用合法方案数除以总方案数,这里总方案数显然是n(n1)(n2)/6,因此我们只需要求合法方案数。最暴力的思路就是O(n3)枚举,但是这样显然无法通过本题,需要考虑优化。
设选出的三条线段长为a,b,c,我们发现若令0<abc,那么一定有a+c>bb+c>a,所以只需要保证a+b>c即可,这是三条线段能组成三角形的充要条件。
因此我们想到枚举c的长度,然后累加能使得0<abca+b>c的线段a,b的对数到答案中。为了方便,考虑分类讨论:
1.当三角形中仅含有一条长度为c的边时,令sum1为长度小于c的线段数,sum2为长度和小于等于c的无序对数,s为长度为c的线段数,可知这种情况的方案有s[sum1(sum11)/2sum2]种。
2.当三角形中含有两条长度为c的边时,延续上述定义,则可知这种情况的方案有s(s1)/2×sum1种。
3.当三角形中含有三条长度为c的边时,延续上述定义,则可知这种情况的方案有s(s1)(s2)/6种。
那么如果我们可以用较好的方法预处理出sum1sum2,我们就可以O(n)求出方案数。sum1显然用前缀和就可以了,问题是sum2,这时候就要用生成函数的思想来解决这个问题。
令多项式A(x)xi项的系数为长度为i的线段数,那么A(x)2xi项的系数就为长度和为i的线段对数。这里我们还要去重,因为这里包含了同一条线段取两次的情况,因此我们要减去一个A(x2)。去掉这种情况后,因为我们要求的是无序对数,而这里剩下的是有序对数,因此要再除以2。那么[A(x)2A(x2)]/2就是我们要求的生成函数了,很显然求A(x)2可以用FFT优化到O(nlogn),而其他的多项式运算都是O(n)的,所以这一步骤的时间复杂度为O(nlogn)。我们求出这个序列后,对其求前缀和,就是前面的sum2了。
最后本题的时间复杂度为O(nlogn),成功地解决了这道题。
以下是本人代码:

#include <bits/stdc++.h>
#define ll long long
using namespace std;
int T,n,r[400010];
const double pi=acos(-1.0);
ll v1[400010],v2[400010],ans,sum1,sum2;
struct Complex
{
    double x,y;
}a[400010],b[400010];
Complex operator + (Complex a,Complex b) {Complex s;s.x=a.x+b.x,s.y=a.y+b.y;return s;}
Complex operator - (Complex a,Complex b) {Complex s;s.x=a.x-b.x,s.y=a.y-b.y;return s;}
Complex operator * (Complex a,Complex b) {Complex s;s.x=a.x*b.x-a.y*b.y,s.y=a.x*b.y+a.y*b.x;return s;}

void FFT(Complex *a,int type)
{
    for(int i=0;i<n;i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1)
    {
        Complex W={cos(pi/mid),type*sin(pi/mid)};
        for(int l=0;l<n;l+=(mid<<1))
        {
            Complex w={1.0,0.0};
            for(int k=0;k<mid;k++,w=w*W)
            {
                Complex x=a[l+k],y=w*a[l+mid+k];
                a[l+k]=x+y;
                a[l+mid+k]=x-y;
            }
        }
    }
    if (type==-1)
    {
        for(int i=0;i<n;i++)
            a[i].x/=n;
    }
}

int main()
{
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        memset(a,0,sizeof(a));
        int mx=0;
        ll saved=(ll)n;
        for(int i=1;i<=n;i++)
        {
            int x;
            scanf("%d",&x);
            a[x].x+=1.0;
            mx=max(mx,x);
        }

        n=mx+1;
        n=(n<<1)-1;
        int bit=0,x=1;
        while(x<n) bit++,x<<=1;
        n=x;
        for(int i=0;i<n;i++)
            r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));

        for(int i=0;i<n;i++) b[i]=a[i];
        FFT(a,1);
        for(int i=0;i<n;i++) a[i]=a[i]*a[i];
        FFT(a,-1);

        for(int i=0;i<n;i++)
        {
            v1[i]=(ll)(b[i].x+0.5);
            v2[i]=(ll)(a[i].x+0.5);
            if (!(i%2)) v2[i]-=v1[i>>1];
            v2[i]>>=1;
        }
        ans=sum1=sum2=0;
        for(int i=1;i<=mx;i++)
        {
            sum1+=v1[i-1];
            sum2+=v2[i];
            ans+=(((sum1*(sum1-1))>>1)-sum2)*v1[i];
            ans+=((v1[i]*(v1[i]-1))>>1)*sum1;
            ans+=v1[i]*(v1[i]-1)*(v1[i]-2)/6;
        }
        printf("%.7lf\n",(double)ans/(double)((saved*(saved-1)*(saved-2))/6));
    }

    return 0;
}
posted @ 2018-02-15 21:41  Maxwei_wzj  阅读(107)  评论(0编辑  收藏  举报