UVA_12297

    在研究了好一阵子标程之后终于把标程的解体思路弄懂了,其实关键之处就在于标程给出的递推式:f(n, k) = f(n-k, k) + f(n-k, k-1) * 4 + f(n-k, k-2) * 6 + f(n-k, k-3) * 4 + f(n-k, k-4)。

    这里f(n, k)表示用k张牌组成和为N的方案数,在递推的时候考虑一共有多少张1。①考虑有0张1:这时就相当于用k张没有任何限制的牌组成和为n-k,然后将每张牌的点数+1,这样自然就没有1了,这部分的方案数是f(n-k, k);②考虑有1张1:这时就相当于用k-1张没有任何限制的牌组成和为n-k,然后将每张牌的点数+1,这样也就没有1了,这时总和是n-1,再从4张1种任选一张1就可以使总和为n,这样就恰好用了k张牌,而且k张牌中只有1个1,这部分的方案数是f(n-k, k-1) * C(4,1),也就是f(n-k, k-1) * 4;③考虑有2张1:……(剩下的情况类似,就不再赘述了)。

    有了这个递推式后就会发现是可以用矩阵加速运算了,由于那个矩阵确实有点大,就不再贴图了,在纸上画一画之后还是不难构造出来的。

    此外还有一种解法,但是始终没看懂什么意思,如果各位仁兄看懂了的话还望教小弟一下,链接:http://hi.baidu.com/wjbzbmr/blog/item/88c8b93aed71fb38b9998f45.html

 

View Code (标程)
#include <set>
#include <map>
#include <list>
#include <cmath>
#include <ctime>
#include <deque>
#include <queue>
#include <stack>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long i64;

const int MOD = 1000000009;
const int NN = 256;

// recurrence: f(n, k) = f(n-k, k) + f(n-k, k-1) * 4 + f(n-k, k-2) * 6 + f(n-k, k-3) * 4 + f(n-k, k-4)

int n, k, N, pos[18][18];;
int base[NN][NN], res[NN][NN], temp[NN][NN];

int multiply( int A[NN][NN], int B[NN][NN], int C[NN][NN] ) {
    for( int i = 0; i < N; i++ ) for( int j = 0; j < N; j++ ) {
        temp[i][j] = 0;
        for( int k = 0; k < N; k++ ) if( A[i][k] && B[k][j] ) temp[i][j] = ( temp[i][j] + A[i][k] * (i64)B[k][j] ) % MOD;
    }
    for( int i = 0; i < N; i++ ) for( int j = 0; j < N; j++ ) C[i][j] = temp[i][j];
}

void buildBase() {
    for( int i = 0; i < N; i++ ) {
        for( int j = 0; j < N; j++ ) base[i][j] = res[i][j] = 0;
        res[i][i] = 1;
    }
    // build base
    int x = 0;
    for( int j = k; j >= 1; j-- ) for( int i = k + 1; i > 1; i-- ) {
        if( pos[i][j] == -1 ) {
            int p = i - j, q = j;
            if( p >= 1 ) {
                assert( pos[p][q] > -1 ); base[x][ pos[p][q] ] = 1;
                q--;
                if( q > 0 ) {
                    assert( pos[p][q] > -1 ); base[x][ pos[p][q] ] = 4;
                    q--;
                    if( q > 0 ) {
                        assert( pos[p][q] > -1 ); base[x][ pos[p][q] ] = 6;
                        q--;
                        if( q > 0 ) {
                            assert( pos[p][q] > -1 ); base[x][ pos[p][q] ] = 4;
                            q--;
                            if( q > 0 ) {
                                assert( pos[p][q] > -1 ); base[x][ pos[p][q] ] = 1;
                            }
                        }
                    }
                }
            }
        }
        else base[x][ pos[i][j] ] = 1;
        x++;
    }
}

int solve() {
    int dp[20][17] = {0};

    dp[0][0] = 1;
    for( int i = 1; i < 20; i++ ) {
        for( int j = 1; j <= k; j++ ) {
            int p = i - j, q = j;
            if( p >= 0 && q >= 0 ) dp[i][j] = (dp[i][j] + dp[p][q]) % MOD;
            q--; if( p >= 0 && q >= 0 ) dp[i][j] = (dp[i][j] + 4 * (i64) dp[p][q]) % MOD;
            q--; if( p >= 0 && q >= 0 ) dp[i][j] = (dp[i][j] + 6 * (i64) dp[p][q]) % MOD;
            q--; if( p >= 0 && q >= 0 ) dp[i][j] = (dp[i][j] + 4 * (i64) dp[p][q]) % MOD;
            q--; if( p >= 0 && q >= 0 ) dp[i][j] = (dp[i][j] + dp[p][q]) % MOD;
        }
    }
    if( n < 20 ) {
        int r = 0;
        for( int i = 1; i <= k; i++ ) r = ( r + dp[n][i] ) % MOD;
        return r;
    }
    int f[NN] = {0}, f1[NN] = {0};
    N = 0;
    memset( pos, -1, sizeof(pos) );
    for( int i = k; i >= 1; i-- ) for( int j = k; j >= 1; j-- ) {
        f[N] = dp[j][i];
        pos[j][i] = N++;
    }

    buildBase();

    int p = n - k;

    while( p ) {
        if( p & 1 ) multiply( res, base, res );
        multiply( base, base, base );
        p >>= 1;
    }

    for( int i = 0; i < N; i++ ) for( int k = 0; k < N; k++ ) if( f[k] && res[i][k] ) f1[i] = (f1[i] + res[i][k] * (i64)f[k]) % MOD;

    int r = 0;
    for( int i = 0; i < N; i += k ) r = ( r + f1[i] ) % MOD;
    return r;
}

int main() {
    double cl = clock();

    while( scanf("%d %d", &n, &k) == 2 && n ) {
        printf("%d\n", solve());
    }

    cl = clock() - cl;
    fprintf(stderr, "Total Execution Time = %lf seconds\n", cl / CLOCKS_PER_SEC);

    return 0;
}

 

View Code (My code)
#include<stdio.h>
#include<string.h>
#define MAXD 120
#define MOD 1000000009
typedef long long LL;
int temp[MAXD][MAXD], f[15][15], N, K;
int n;
const int d[] = {1, 4, 6, 4, 1};
struct matrix
{
    int a[MAXD][MAXD];
    void init(int x)
    {
        for(int i = 0; i < n; i ++)
            for(int j = 0; j < n; j ++)
                a[i][j] = x;
    }
    matrix operator * (const matrix &t) const
    {
        matrix ans;
        ans.init(0);
        for(int i = 0; i < n; i ++)
            for(int k = 0; k < n; k ++)
                if(a[i][k])
                {
                    for(int j = 0; j < n; j ++)
                        if(t.a[k][j])
                            ans.a[i][j] = (ans.a[i][j] + (LL)a[i][k] * t.a[k][j]) % MOD;
                }
        
        return ans;
    }
}mat, unit;
void prepare()
{
    memset(f[0], 0, sizeof(f[0]));
    f[0][0] = 1;
    for(int i = 1; i <= 10; i ++)
        for(int j = 1; j <= 10; j ++)
        {
            f[i][j] = 0;
            if(i >= j)
            {
                for(int k = 0; k < 5 && j - k >= 0; k ++)
                    f[i][j] = (f[i][j] + (LL)d[k] * f[i - j][j - k]) % MOD;
            }
        }
}
void build()
{
    n = K * (K + 1);
    mat.init(0);
    for(int i = 0; i < K; i ++)
        for(int j = 0; j <= K; j ++)
            mat.a[i * (K + 1) + j][0] = f[K - 1 - i][j];
    
    unit.init(0);
    for(int j = 1; j <= K; j ++)
    {
        for(int k = 0; k < 5 && j - k >= 0; k ++)
            unit.a[j][(j - 1) * (K + 1) + j - k] = d[k];
    }
    for(int i = 1; i < K; i ++)
        for(int j = 0; j <= K; j ++)
            unit.a[i * (K + 1) + j][(i - 1) * (K + 1) + j] = 1;    
}
void powmod(int n)
{
    while(n)
    {
        if(n & 1)
            mat = unit * mat;
        n >>= 1, unit = unit * unit;    
    }
}
void solve()
{
    if(N <= 10)
    {
        int ans = 0;
        for(int i = 1; i <= K; i ++)
            ans = (ans + f[N][i]) % MOD;
        printf("%d\n", ans);
        return ;
    }
    build();
    powmod(N - K + 1);
    int ans = 0;
    for(int i = 1; i <= K; i ++)
        ans = (ans + mat.a[i][0]) % MOD;
    printf("%d\n", ans);
}
int main()
{
    prepare();
    while(scanf("%d%d", &N, &K) == 2, N || K)
        solve();
    return 0;    
}

 

posted on 2012-08-03 18:06  Staginner  阅读(588)  评论(0编辑  收藏  举报