蓝桥杯省赛[新手向题解] 2019 第十届 C/C++ A组 第十题 = BZOJ4737 数论(lucas)+DP

蓝桥杯省赛[新手向题解] 2019 第十届 C/C++ A组 第十题 = BZOJ4737[传送门:https://www.lydsy.com/JudgeOnline/problem.php?id=4737]


更新

  屎一样的代码写完了,能跑样例。求hack数据

Problem Description

  给定\(N\)\(M\),求有多少对\(i\)\(j\),满足\(i≤N,j≤M\)\(C^{j}_{i}\)\(k\)的整数倍,\(k\)是给定素数(忘了题目有没有素数这条)

Input

  第一行俩整数\(t\)\(k\)
  接下来t行每行俩整数\(n\)\(m\)
  (数据范围忘了,似乎\(n,m≤10^{18}\)

Output

  对每组\(n\)\(m\),输出满足条件的方案数

Sample Input 1

1 2
3 3

Sample Output 1

1

Sample Input 2

2 5
4 5
6 7

Sample Output 2

0
7

Sample Input 3

3 23
23333333 23333333
233333333 233333333
2333333333 2333333333

Sample Output 3

851883128
959557926
680723120


更新题意:

  确定k是素数
  结果对\(10^{9}+7\)取模
    \(1 ≤ k ≤ 10^{8}\)
    \(1 ≤ t ≤ 10^{5}\)
    \(1 ≤ n, m ≤ 10^{18}\)


并不想按格式写:

赛时:

  还剩40分钟的分钟的时候摸到最后两题,感觉这题可能要莫比乌斯懒得推其实是不会做,果断光速骗分然后做倒数第二题(还是DP写得快)
  出考场走在路上听到有人说卢卡斯定理,秒懂,上体系结构课摸鱼做了一下(入坑学的第一个数论定理,丢人)

Lucas定理部分:

对于素数k,有:
  \(C^{m}_{n}\ mod\ k=C^{m_0}_{n_0}C^{m_1}_{n_2}......C^{m_p}_{n_p}\ mod\ k\)
其中:
  \(n=n_0+n_1k+n_2k^2+......+n_pk^p\)
  \(m=m_0+m_1k+m_2k^2+......+m_pk^p\)

  有一个通俗的理解:\(n_i\)\(n\)\(k\)进制表示下各位的数,\(m_i\)同理。

举个栗子🌰:
  \(n=23,m=7,k=5\)
  此时23和7转为5进制的结果分别为43和12,根据卢卡斯定理有:

\[C^{7}_{23}\ mod\ 5=C^2_3C^1_4\ mod\ 5 = (C^2_3\ mod\ 5)(C^1_4\ mod\ 5)\ mod\ 5=3*\ mod\ 5=2$$  ($C_3^2$的3和2分别来自5进制下$n$和$m$的第1位(从右到左),$C^1_4$的4和1分别来自5进制下$n$和$m$的第2位)   与直接计算的结果相同:$$C^{7}_{23}\ mod\ 5=245157\ mod\ 5=2$$   \]

  考虑到本题要求是k的倍数,也就是说对k取模结果为0,既然卢卡斯定理展开后的各项是相乘的关系,只要保证\(C^{m_0}_{n_0}\),\(C^{m_1}_{n_2}\),......\(C^{m_p}_{n_p}\)中至少有一项对k取模为0就行了。

  • 很显然,令\(n_i<m_i\)是一种可行的方案,即\(C^{m_i}_{n_i}=0\).(因为从2个球里取4个球的方案数为0)
  • 另一种方案是令\(C^{m_i}_{n_i}\)结果为k的倍数,但这样是不可能的,因为在k进制下,\(n_i<k\)\(m_i<k\),也就是说\(C^{m_i}_{n_i}\)展开后\(\frac{n_{i+1}!}{m_{i+1}!(n_{i+1}-m_{i+1})!}\),分子分母都是一些小于k的数相乘,因为k是质数,小于k的数进行乘除是得不到k的倍数的,所以只会有\(C^{m_i}_{n_i}=0\)的情况。
现在问题简化成了,求有多少对\(i,j\)符合\(i>j\),且\(i\)\(k\)进制下至少有一位小于\(j\)\(k\)进制下对应位置的数

举个小一点的例子:\(n=7,m=5,k=3\),即5进制下\(n=21,m=12\)
    取\(i=21\)\(j=12\),对应的十进制为7和2,即\(C_{7}^2\ mod\)\(5=C_1^2C_2^1\ mod\ 5\),其中的\(C_1^2\)为0
    取\(i=21\)\(j=02\),对应的十进制为7和5,即\(C_{7}^5\ mod\)\(5=C_1^2C_2^0\ mod\ 5\),其中的\(C_1^2\)为0
    取\(i=20\)\(j=12\),对应的十进制为6和5,即\(C_{6}^5\ mod\)\(5=C_0^2C_2^1\ mod\ 5\),其中的\(C_0^2\)为0
    取\(i=20\)\(j=11\),对应的十进制为6和4,即\(C_{6}^4\ mod\)\(5=C_0^1C_2^1\ mod\ 5\),其中的\(C_0^1\)为0
    取\(i=20\)\(j=02\),对应的十进制为6和2,即\(C_{6}^2\ mod\)\(5=C_0^2C_2^0\ mod\ 5\),其中的\(C_0^2\)为0
    取\(i=20\)\(j=01\),对应的十进制为6和1,即\(C_{6}^1\ mod\)\(5=C_0^1C_2^0\ mod\ 5\),其中的\(C_0^1\)为0
    取\(i=11\)\(j=02\),对应的十进制为4和2,即\(C_{4}^2\ mod\)\(5=C_1^2C_1^0\ mod\ 5\),其中的\(C_1^2\)为0
    取\(i=10\)\(j=02\),对应的十进制为3和2,即\(C_{3}^2\ mod\)\(5=C_0^2C_1^0\ mod\ 5\),其中的\(C_0^2\)为0
    取\(i=10\)\(j=01\),对应的十进制为3和1,即\(C_{3}^1\ mod\)\(5=C_0^1C_1^0\ mod\ 5\),其中的\(C_0^1\)为0
  (这例子好像一点也不小)

  各位置数字大小的方案数问题,经典的数位DP

数位DP部分:

  给定\(n,m\)的上界和质数\(k\),问能构造出多少对\(n,m\),使\(n>m\)且在\(k\)进制下至少有一个位置\(p\)满足\(n_p<m_p\),其中\(n_p\)表示\(n\)\(k\)进制下从左往右第p位数,\(m_p\)同理(n和m右对齐,高位补零)。
  先不考虑\(n\)\(m\)的上界,从高位向低位考虑第\(i+1\)个位置的数。前面\(i\)位的合法分布情况,有以下3种:
    1.\(n\)\(m\)的前i位相等,记作\(dp[i][0]\)
    2.\(n\)的前\(i\)位分别大于或等于\(m\)的前\(i\)位,记作\(dp[i][1]\)
    3.\(n\)的前\(i\)位中至少有1位小于\(m\)的对应位置,但是\(n\)要大于\(m\)以保证方案合法,记作\(dp[i][2]\)
  显然第3种状态即题目条件,所以需要用到的结果为\(dp[len][2]\)(len为n在k进制下的长度)

状态转移:

更新:

求对立的,也就是对于所有位置上都保证n≥m就行了,最后答案就是\(n≥m\)的总数减去dp求出的方案即可

Codes:


#include <bits/stdc++.h>
using namespace std;

const long long MOD = 1e9 + 7;

long long dp[100][2][2];

long long ksm(long long a, long long x)
{
    long long ans = 1;
    while(x)
    {
        if(x % 2)
        {
            ans *= a;
            ans %= MOD;
        }

        a *= a;
        a %= MOD;
        x /= 2;
    }
    return ans;
}

long long inv = ksm(2,MOD-2);
long long cal(long long topn, long long topm)
{
    topm = min(topn,topm);
    topn %= MOD;
    topm %= MOD;
    return ((topm+2)*(topm+1) % MOD * inv % MOD + (topn-topm+MOD)*(topm+1) % MOD) % MOD;
}

int main()
{
    int t;
    long long k;
    cin >> t >> k;

    while(t--)
    {
        long long n,m;
        cin >> n >> m;
        if(m > n)
            m = n;

        vector<long long> vn;
        vector<long long> vm;
        vn.clear();
        vm.clear();

        int len = 0;
        long long numn = n;
        long long numm = m;
        while(numn)
        {
            len++;
            vn.push_back(numn % k);
            numn /= k;

            if(numm)
            {
                vm.push_back(numm % k);
                numm /= k;
            }
            else
                vm.push_back(0);
        }

        for(int i=0; i<len+1; ++i)
            for(int j=0; j<2; ++j)
                for(int k=0; k<2; ++k)
                    dp[i][j][k] = 0;

        dp[len][1][1] = 1;
        for(int i=len-1; i>=0; --i)
        {
            long long nown = vn[i];
            long long nowm = vm[i];

            dp[i][0][0] += dp[i+1][0][0] * cal(k-1,k-1) % MOD
                            + dp[i+1][0][1] * cal(k-1,nowm-1) % MOD 
                            + dp[i+1][1][0] * cal(nown-1,k-1) % MOD 
                            + dp[i+1][1][1] * cal(nown-1,nowm-1) % MOD;
            dp[i][0][0] %= MOD;

            if(nown >= nowm)
            {
                dp[i][0][1] += dp[i+1][0][1] * (k-nowm) % MOD
                            + dp[i+1][1][1] * (nown-nowm) % MOD;
                dp[i][0][1] %= MOD;

                dp[i][1][0] += dp[i+1][1][0] * (nown+1) % MOD
                            + dp[i+1][1][1] * nowm % MOD;
                dp[i][1][0] %= MOD;
        
                dp[i][1][1] = dp[i+1][1][1]; 
            }
            else
            {
                dp[i][0][1] += dp[i+1][0][1] * (k-nowm) % MOD;
                dp[i][0][1] %= MOD;

                dp[i][1][0] += dp[i+1][1][0] * (nown+1) % MOD
                            + dp[i+1][1][1] * (nown+1) % MOD;
                dp[i][1][0] %= MOD;

                dp[i][1][1] = 0;
            }
        }

        long long sum = dp[0][0][0] + dp[0][0][1] + dp[0][1][0] + dp[0][1][1];
        sum %= MOD;

        long long ans = cal(n, m) - sum + MOD;
        ans %= MOD;

        printf("%lld\n",ans);
    }

    //system("pause");
    return 0;
}

原本的实现

  • 对于全相等的状态\(dp[i][0]\)
          第\(i+1\)位仍保持相等,转移到\(dp[i+1][0]\)
          第\(i+1\)位令\(n_{i+1}>m_{i+1}\),转移到\(dp[i+1][1]\)
          (注意,不能转移到\(dp[i+1][2]\),因为\(n_{i+1}<m_{i+1}\)会导致\(n<m\),使方案不合法)
  • 对于已经保证\(n>m\)的状态\(dp[i][1]\)
          第\(i+1\)位保持\(n_{i+1}≥m_{i+1}\),转移到\(dp[i+1][1]\)
          第\(i+1\)位使\(n_{i+1}<m_{i+1}\),转移到\(dp[i+1][2]\)
  • 对于已经满足条件的状态\(dp[i][2]\)
          第\(i+1\)位任意取值,转移到\(dp[i+1][2]\)
具体考虑,k进制下,每个位置能取的数有0~k-1共k种
  • \(n_{i+1}=m_{i+1}\)的取法有\(k\)
  • \(n_{i+1}>m_{i+1}\)的取法:
            \(n_{i+1}=1\)时,\(m_{i+1}\)只能取\(0\),共\(1\)
            \(n_{i+1}=2\)时,\(m_{i+1}\)可以取\(0,1\),共\(2\)
            \(n_{i+1}=3\)时,\(m_{i+1}\)可以取\(0,1,2\),共\(3\)
            ......
            \(n_{i+1}=k-1\)时,\(m_{i+1}\)可以取\(0,1,2......k-3,k-2\),共\(k-1\)
            故\(n_{i+1}>m_{i+1}\)的取法一共有\(\frac{k(k-1)}{2}\)
  • \(n_{i+1}<m_{i+1}\)的取法,把上一种反过来,就是\(\frac{k(k-1)}{2}\)
所以状态转移方程为:

    \(dp[i+1][0] = dp[i][0] * k;\)
    \(dp[i+1][1] = dp[i][0] * \frac{k^(k-1)}{2} + dp[i][1] * (k+\frac{k(k-1)}{2});\)
    \(dp[i+1][2] = dp[i][1] * \frac{k(k-1)}{2} + dp[i][2] * k^2;\)

考虑回上界的问题

  对于每个状态多加1维记录当前数是否保持和上界相等:

  • 如果到前\(i\)个位置前都是和上界相等的,那么第\(i+1\)位能取的数就不是\(k\)个,而是上界的第\(i+1\)位的数对应的数量
  • 如果已经在前面\(i\)个位置的某个位置已经小于上界了,那么第\(i+1\)位的数不管怎么取都不会超上界,转移不变

  显然状态dp[][1]和dp[][2]都不需要这么记录,有点浪费。总之特判方法挺多的

Codes:



#include <bits/stdc++.h>
using namespace std;

const int MAXN = 1e8 + 233;
const long long MOD = 1e9 + 7;

unsigned long long dp[100][3][4];
/*
    第二维
        0 : 前i位n和m相等
        1 : 前i位满足n>m且不存在某位n比m小
        2 : 前i位满足n>m且至少有一位n比m小
    第三维
        0 : 前i位都没到上界
        1 : 前i位n到上界
        2 : 前i位m到上界
        3: 前i位都到上界
*/

long long cal(long long topn, long long topm, long long compare, long long musttop)
{
    if(topn < 0 || topm < 0)
        return 0;
    //compare == 3 不考虑n和m的大小关系
    switch(musttop)
    {
        case 0 :
        {
            if(compare == 0)
                return min(topn,topm) + 1;
            if(compare == 3)
                return (topn+1) * (topm+1) % MOD;

            if(compare == 2)
                swap(topn,topm);
            if(topm >= topn)
                return (topn+1) * topn / 2 % MOD;
            else
                return (topn+topn-topm) * (topm+1) / 2 % MOD;
        }
        break;

        case 1 :
        {
            if(compare == 0)
                return topn <= topm;
            else if(compare == 3)
                return topm+1;
            else if(compare == 1)
                return min(topn,topm+1);
            else
                return max(0LL,topm-topn);
        }
        break;

        case 2 :
        {
            if(compare == 0)
                return topn >= topm;
            else if(compare == 3)
                return topn+1;
            else if(compare == 1)
                return max(0LL,topn-topm);
            else
                return min(topm,topn+1);
        }
        break;

        case 3 :
        {
            if(compare == 0)
                return topn == topm;
            else if(compare == 3)
                return 1;
            else if((compare == 1 && topn > topm) || (compare == 2 && topn < topm))
                return 1;
            else
                return 0;
        }
        break;
    }
    
}

int main()
{
    int t;
    long long k;
    cin >> t >> k;

    while(t--)
    {
        long long n,m;
        cin >> n >> m;
        if(m > n)
            m = n;

        vector<int> vn;
        vector<int> vm;
        vn.clear();
        vm.clear();

        int len = 0;
        long long numn = n;
        long long numm = m;
        while(numn)
        {
            len++;
            vn.push_back(numn % k);
            numn /= k;

            if(numm)
            {
                vm.push_back(numm % k);
                numm /= k;
            }
            else
                vm.push_back(0);
        }

        for(int i=0; i<len+1; ++i)
            for(int j=0; j<3; ++j)
                for(int k=0; k<4; ++k)
                    dp[i][j][k] = 0;

        //cout << len << endl;
        dp[len][0][3] = 1;

        for(int i=len-1; i>=0; --i)
        {
            int nn = vn[i];
            int mm = vm[i];
            int x = k-1;

            dp[i][0][0] = dp[i+1][0][0]*cal(x,x,0,0)
                        + dp[i+1][0][1]*cal(nn-1,x,0,0)
                        + dp[i+1][0][2]*cal(x,mm-1,0,0)
                        + dp[i+1][0][3]*cal(nn-1,mm-1,0,0); // 1
            dp[i][0][1] = dp[i+1][0][1]*cal(nn,x,0,1)
                        + dp[i+1][0][3]*cal(nn,mm-1,0,1); // 1
            dp[i][0][2] = dp[i+1][0][2]*cal(x,mm,0,2) 
                        + dp[i+1][0][3]*cal(nn-1,mm,0,2); // 1
            dp[i][0][3] = dp[i+1][0][3]*cal(nn,mm,0,3); // 1

            dp[i][1][0] = dp[i+1][1][0]*(cal(x,x,1,0)+cal(x,x,0,0)) + dp[i+1][0][0]*cal(x,x,1,0)
                        + dp[i+1][1][1]*(cal(nn-1,x,1,0)+cal(nn-1,x,0,0)) + dp[i+1][0][1]*cal(nn-1,x,1,0)
                        + dp[i+1][1][2]*(cal(x,mm-1,1,0)+cal(x,mm-1,0,0)) + dp[i+1][0][2]*cal(x,mm-1,1,0)
                        + dp[i+1][1][3]*(cal(nn-1,mm-1,1,0)+cal(nn-1,mm-1,0,0)) + dp[i+1][0][3]*cal(nn-1,mm-1,1,0); // 1
            dp[i][1][1] = dp[i+1][1][1]*(cal(nn,x,1,1)+cal(nn,x,0,1)) + dp[i+1][0][1]*cal(nn,x,1,1)
                        + dp[i+1][1][3]*(cal(nn,mm-1,1,1)+cal(nn,mm-1,0,1)) + dp[i+1][0][3]*cal(nn,mm-1,1,1); // 1
            dp[i][1][2] = dp[i+1][1][2]*(cal(x,mm,1,2)+cal(x,mm,0,2)) + dp[i+1][0][2]*cal(x,mm,1,2) 
                        + dp[i+1][1][3]*(cal(nn-1,mm,1,2)+cal(nn-1,mm,0,2)) + dp[i+1][0][3]*cal(nn-1,mm,1,2); // 1
            dp[i][1][3] = dp[i+1][1][3]*(cal(nn,mm,1,3)+cal(nn,mm,0,3)) + dp[i+1][0][3]*cal(nn,mm,1,3); // 1

            dp[i][2][0] = dp[i+1][2][0]*cal(x,x,3,0) + dp[i+1][1][0]*cal(x,x,2,0)
                        + dp[i+1][2][1]*cal(nn-1,x,3,0) + dp[i+1][1][1]*cal(nn-1,x,2,0)
                        + dp[i+1][2][2]*cal(x,mm-1,3,0) + dp[i+1][1][2]*cal(x,mm-1,2,0)
                        + dp[i+1][2][3]*cal(nn-1,mm-1,3,0) + dp[i+1][1][3]*cal(nn-1,mm-1,2,0); // 1
            dp[i][2][1] = dp[i+1][2][1]*cal(nn,x,3,1) + dp[i+1][1][1]*cal(nn,x,2,1)
                        + dp[i+1][2][3]*cal(nn,mm-1,3,1) + dp[i+1][1][3]*cal(nn,mm-1,2,1); // 1
            dp[i][2][2] = dp[i+1][2][2]*cal(x,mm,3,2) + dp[i+1][1][2]*cal(x,mm,2,2) 
                        + dp[i+1][2][3]*cal(nn-1,mm,3,2) + dp[i+1][1][3]*cal(nn-1,mm,2,2); // 1
            dp[i][2][3] = dp[i+1][2][3]*cal(nn,mm,3,3) + dp[i+1][1][3]*cal(nn,mm,2,3); // 1

            for(int j=0; j<3; ++j)
                for(int k=0; k<4; ++k)
                    dp[i][j][k] %= MOD;
        }

        printf("%lld\n",(dp[0][2][0] + dp[0][2][1] + dp[0][2][2] + dp[0][2][3])%MOD);
        
    }

    //system("pause");
    return 0;
}

posted @ 2019-03-27 21:19  摸鱼鱼  阅读(555)  评论(0编辑  收藏  举报