【BZOJ3456】城市规划-多项式求逆

测试地址:城市规划
题目大意:n个点带标号简单无向连通图(即无重边,无自环)的数目。
做法:本题需要用到多项式求逆。
如果不要求连通,这题就是水题了,答案显然为2Cn2(即枚举每条边选或不选)。
然而这题显然没那么简单,我们令f(n)为我们要求的答案,g(n)=2Cn2,有如下递推式:
g(n)=i=1n1Cn1i1f(i)g(ni)
上面式子的意义实质上是枚举点1所在的连通块的大小i进行转移。
当然,我们要求的不是g而是f,所以我们需要找到一个方式用g表示出f
把上面的组合数拆开,得到:
g(n)=(n1)!i=1n1f(i)(i1)!g(ni)(ni)!
两边同除一个(n1)!,这就是一个卷积形式的式子,于是我们可以用三个生成函数(多项式)来表示这三个部分:
A(x)=B(x)C(x)
C为包含f(i)的那个部分,于是C(x)=A(x)B1(x),其中C1(x)是多项式C的逆。
于是接下来就涉及多项式求逆的问题了,在这里,令C的最高次项的次数为n,那么它的逆C1(x)应该满足:
C(x)C1(x)1(modxn+1)
于是问题转化成求一个模意义下的多项式D(x),使得:
C(x)D(x)1(modxn)
假设我们已经求出一个D(x)满足:
C(x)D(x)1(modxn2)
于是我们有:
D(x)D(x)0(modxn2)
两边平方,把模数扩展到n(也有可能是n+1,反正不小于n,不管):
D2(x)2D(x)D(x)D2(x)(modxn)
两边再乘一个C(x),有:
D(x)2D(x)C(x)D2(x)(modxn)
而我们知道,C(x)在模x意义下的逆元,就等于它常数项的逆元。于是这样我们就能倍增+NTT算出多项式的逆了,算法的时间复杂度为:
T(n)=T(n2)+O(nlogn)=O(nlogn)
这样我们就解决了这一题。注意因为上面的所有计算都是在模意义下进行的,所以每一步都要将多项式取模(其实就是截取前面的某一段)。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=1004535809;
const ll g=3;
ll n,rev[600010]={0},fac[600010],inv[600010];
ll A[600010]={0},B[600010]={0},C[600010]={0};
ll p[600010]={0},s[600010]={0};

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    b=(b%(mod-1)+mod-1)%(mod-1);
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

ll NTT(ll *a,ll type,ll n)
{
    for(ll i=0;i<n;i++)
        if (i<rev[i]) swap(a[i],a[rev[i]]);
    for(ll mid=1;mid<n;mid<<=1)
    {
        ll W=power(g,type*(mod-1)/(mid<<1));
        for(ll l=0;l<n;l+=(mid<<1))
        {
            ll w=1;
            for(ll k=0;k<mid;k++,w=w*W%mod)
            {
                ll x=a[l+k],y=w*a[l+mid+k]%mod;
                a[l+k]=(x+y)%mod;
                a[l+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if (type==-1)
    {
        ll inv=power(n,mod-2);
        for(ll i=0;i<=n;i++)
            a[i]=a[i]*inv%mod;
    }
}

ll calc_rev(ll limit)
{
    ll bit=0,x=1;
    while(x<=limit) bit++,x<<=1;
    rev[0]=0;
    for(ll i=1;i<=x;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    return x;
}

void calc_inv(ll *a,ll len)
{
    if (len==1)
    {
        s[0]=power(a[0],mod-2);
        return;
    }
    calc_inv(a,(len+1)>>1);

    ll x=calc_rev((len<<1)-1);
    memset(p,0,sizeof(p));
    for(ll i=0;i<len;i++)
        p[i]=a[i];
    NTT(s,1,x),NTT(p,1,x);
    for(ll i=0;i<=x;i++)
        s[i]=(2ll*s[i]-p[i]*s[i]%mod*s[i]%mod+mod)%mod;
    NTT(s,-1,x);
    for(ll i=len;i<=x;i++)
        s[i]=0;
}

int main()
{
    scanf("%lld",&n);
    fac[0]=fac[1]=1;
    for(ll i=1;i<=n;i++)
        fac[i]=fac[i-1]*i%mod;
    inv[n]=power(fac[n],mod-2);
    for(ll i=n;i>=1;i--)
        inv[i-1]=inv[i]*i%mod;

    for(ll i=0;i<=n;i++)
    {
        if (i>0) A[i]=power(2ll,i*(i-1)/2ll)*inv[i-1]%mod;
        B[i]=power(2ll,i*(i-1)/2ll)*inv[i]%mod;
    }
    calc_inv(B,n+1);

    ll x=calc_rev(n);
    NTT(s,1,x),NTT(A,1,x);
    for(ll i=0;i<=x;i++)
        C[i]=s[i]*A[i]%mod;
    NTT(C,-1,x);
    printf("%lld",C[n]*fac[n-1]%mod);

    return 0;
}
posted @ 2018-06-20 19:17  Maxwei_wzj  阅读(115)  评论(0编辑  收藏  举报