Gym - 100341C FFT优化DP

题目链接传送门

 

题解:

  设 dp[i][j] 为高度为 i、恰好有 j 个节点的 AVL 树的个数（模 786433）。代码中单个节点的高度记为 0，因此初值为 dp[0][1] = 1，dp[1][2] = 2，dp[1][3] = 1。

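  根节点的两棵子树高度要么都是 h-1，要么一棵是 h-1、另一棵是 h-2（左右两种顺序，故乘 2）；去掉根后剩下的 n-1 个节点分给两棵子树，所以转移方程为 $$dp[h][n]=\sum_{i=0}^{n-1}\Bigl(dp[h-1][i]\cdot dp[h-1][n-1-i]+2\,dp[h-1][i]\cdot dp[h-2][n-1-i]\Bigr)$$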

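  右边的求和是关于节点数的卷积：记 $F_h(x)=\sum_n dp[h][n]\,x^n$，则 $F_h = x\,\bigl(F_{h-1}^2 + 2F_{h-1}F_{h-2}\bigr)$。

  因此可以用 NTT（模数 786433 = 3·2^18 + 1，取 G = 10）由 dp[h-1]、dp[h-2] 快速求出 dp[h]，避免 O(n^2) 的暴力卷积。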

 

下面是AC代码

#include<bits/stdc++.h>
using namespace std;
// Linker directive requesting a ~100 MB stack. It is only honoured by MSVC's
// toolchain; other compilers generally ignore this pragma.
#pragma comment(linker, "/STACK:102400000,102400000")
// Contest-template shorthands (segment-tree child indices / midpoint, pair
// helpers). NOTE(review): none of them are used by the code in this file.
#define ls i<<1
#define rs ls | 1
#define mid ((ll+rr)>>1)
#define pii pair<int,int>
#define MP make_pair
typedef long long LL;
typedef unsigned long long ULL;
// Contest-template constants.
const long long INF = 1e18+1LL;
const double pi = acos(-1.0);
const int N = 7e5+10, M = 1e3+20,inf = 2e9;

// NTT-friendly prime: P - 1 = 786432 = 3 * 2^18, so power-of-two transform
// lengths up to 2^18 have roots of unity mod P. `mod` is the same value.
const long long P=786433LL,mod = 786433LL;
// 10 is a quadratic non-residue mod P (P = 1 mod 8 and P = 3 mod 5), so
// G^((P-1)/2) = -1 and G^((P-1)/2^k) has order exactly 2^k for 1 <= k <= 18.
const LL G=10LL;

// Returns x * y mod P, always in [0, P), for any long long operands.
// P < 2^20, so once both operands are reduced into [0, P) their product is
// below 2^40 and the multiplication is exact in 64-bit arithmetic -- no
// floating-point quotient estimate and no signed-overflow wraparound needed.
LL mul(LL x,LL y){
    x %= P;
    if (x < 0) x += P;
    y %= P;
    if (y < 0) y += P;
    return x * y % P;
}

// Returns x^k mod P by binary exponentiation (k >= 0; a negative k yields 1
// instead of looping forever).
LL qpow(LL x,LL k){
    LL ret=1;
    for (; k > 0; k >>= 1) {
        if (k & 1) ret = mul(ret, x);
        x = mul(x, x);
    }
    return ret;
}

// wn[i] = G^((P-1)/2^i), a primitive 2^i-th root of unity mod P; the NTT uses
// wn[i] for its butterfly stage of length 2^i. Filled by getwn().
LL wn[50];
// Fills wn[i] for every i with 2^i dividing P - 1 (i = 1..18). Larger orders
// have no root of unity mod P, so those entries are left untouched (0 for
// this zero-initialised global). The shift is done in 64 bits to avoid the
// signed overflow of `1 << i` for large i.
void getwn(){
    for (int i = 1; i < 50 && (P - 1) % (1LL << i) == 0; ++i)
        wn[i] = qpow(G, (P - 1) >> i);
}

int len;
// In-place iterative radix-2 number-theoretic transform of y[0 .. len) mod P.
// `len` is this global, set by the caller beforehand; it is expected to be a
// power of two no larger than 2^18 (the largest power of two dividing P - 1),
// and every y[i] is expected to lie in [0, P).
// op == -1 runs the inverse transform, including the final division by len;
// any other op runs the forward transform.
void NTT(LL y[],int op){
    // Bit-reversal permutation. `rev` holds the bit reversal of `i`; stepping to
    // i + 1 increments `rev` from its most significant end: clear the leading
    // run of 1 bits, then set the first 0 bit below them.
    int rev = len >> 1;
    for (int i = 1; i < len - 1; ++i) {
        if (i < rev) std::swap(y[i], y[rev]);
        int bit = len >> 1;
        for (; rev >= bit; bit >>= 1) rev -= bit;
        rev += bit;
    }

    // Butterfly stages: blocks of size 2*half, twiddle root wn[stage] where
    // 2*half == 2^stage.
    for (int half = 1, stage = 1; (half << 1) <= len; half <<= 1, ++stage) {
        const LL root = wn[stage];
        for (int start = 0; start < len; start += half << 1) {
            LL w = 1;
            for (int k = start; k < start + half; ++k) {
                const LL u = y[k];
                const LL t = mul(y[k + half], w);
                const LL sum = u + t;
                const LL diff = u - t;
                y[k] = (sum >= P) ? sum - P : sum;
                y[k + half] = (diff < 0) ? diff + P : diff;
                w = mul(w, root);
            }
        }
    }

    if (op == -1) {
        // Inverse = forward transform with entries 1..len-1 reversed, then
        // scaled by len^{-1} (Fermat inverse, P prime).
        for (int lo = 1, mirror = len / 2; lo < mirror; ++lo) std::swap(y[lo], y[len - lo]);
        const LL invLen = qpow(len, P - 2);
        for (int i = 0; i < len; ++i) y[i] = mul(y[i], invLen);
    }
}
LL dp[20][N],tmp[N];
int n,h;
int main() {
    freopen("avl.in", "r", stdin);
    freopen("avl.out", "w", stdout);
    getwn();
    scanf("%d%d",&n,&h);
    dp[0][1] = 1,dp[1][2] = 2,dp[1][3] = 1;
    len = 1;
    for(int i = 2; i <= h; ++i) {
        len = (1<<(i+1));
        NTT(dp[i-2],1);NTT(dp[i-1],1);
        for(int j = 0; j < len; ++j)
            tmp[j] = (dp[i-1][j] * dp[i-2][j])%mod;
        NTT(tmp,-1);
        for(int j = 1; j < len; ++j)
            dp[i][j] = 2LL*tmp[j-1] % mod;
         for(int j = 0; j < len; ++j)
            tmp[j] = (dp[i-1][j] * dp[i-1][j])%mod;
         NTT(tmp,-1);
         for(int j = 1; j < len; ++j)
            dp[i][j] += tmp[j-1], dp[i][j] %= mod;
        NTT(dp[i-2],-1);
        NTT(dp[i-1],-1);
    }
    printf("%lld\n",dp[h][n]);
    return 0;
}

 

posted @ 2017-08-02 21:04  meekyan  阅读(599)  评论(0编辑  收藏  举报