数位DP题目瞎做

普通的数位DP,题目形式大概是——

在区间\([L,R]\)里,统计满足条件的数的个数,并且暴力统计会严重超时

这样就提醒我们在每一位去枚举这一位数字的决策,递推统计方案。对于区间\([L,R]\)的查询,可以转化为对区间\([1,R]\)\([1,L-1]\)的查询——>给出一个数\(x\),查询小于等于\(x\)且满足条件的数的个数。

在解决这个问题的时候,会将数\(x\)按一棵树去处理。假设数\(x\)\(n\)位,分别是\(X_n,X_{n-1},...,X_1\)。这个处理过程如下图:

image

对于数的第\(p\)位,这一位左子树我们统计的是\(X_nX_{n-1}...X_{p+1}00000...0\)\(X_nX_{n-1}...X_{p+1}(X_p-1)9999...9\)里面符合条件的数。也就是\(1\)\(p-1\)位的数都可以放飞自我随意取值。那么左子树的值就可以通过前\(p-1\)位的递推值以及\(X_nX_{n-1}...X_{p+1}\)这个前缀综合起来求解,最终我们留下的就是一个从根到最右端节点的链还没有统计过,也就是\(x\)本身,统计\(x\)本身显然是一个非常简单的问题。需要注意的是第一棵左子树\([0,X_n-1]\),里面会包含前导0,有的题目不允许前导0,需要特殊判断。

一般来说,数位DP有两种写法,递推预处理法和记忆化搜索法。前者可以认为是自低位向高位的顺推,后者可以认为是自高位向低位的逆推。

顺推法会提前处理好所有情况下的dp状态值,接下来根据查询的数,在左子树里面按需取值。更适合处理低位决策会影响高位决策的情况,例如状态中会记录低位产生的进位。

逆推法更适合处理高位会影响低位决策的情况,例如状态中会出现数字的大小关系(高位如果小了,低位不会改变大小关系)。整个记忆化搜索的过程就像上图这棵树的过程一样,非常直观。

一般记忆化搜索的模板如下:

点击查看代码
ll solve(int n, state..., bool zero, bool limit, const std::vector<int> &digit)
{
    if (n == 0 && 满足条件)
        return 1;
    if (n == 0)
        return 0;
    if (!zero && !limit && dp[n][state] != -1)
        return dp[n][state];
    ll res = 0;
    int mx = limit ? digit[n - 1] : 9;
    for (int i = 0; i <= mx; ++i)
        res += solve(n - 1, newstate, zero&&(i==0), limit & (i == digit[n - 1]), digit);
    if (!zero && !limit)
        dp[n][state] = res;
    return res;
}

zero表示当前是不是前导0,limit表示当前是不是在树的最右端的链上,如果在最右端的链上,那么下一位的取值就不能超过\(x\)本身在下一位的值。一旦进入某颗左子树,limit就会变成false。在zero和limit是true的情况下,不能更新dp的值。因为如果题目查询的是区间\([L,R]\),我们既要对\(L-1\)求解,也要对\(R\)求解,那么两次求解过程中,zero和limit的值跟当前求解的\(x\)相关,属于特殊数据,不能计入通用的dp值。当然,假如题目只要求求解一组\([1,n]\)的话,这两个状态就可以计入dp。

具体的结合题目来分析。

题目

LightOJ-1068-Investigation

给出区间\([A,B]\),和一个数\(K\),统计区间中能被\(K\)整除且数位之和也能被\(K\)整除的数。\(1 \le A,B < 2^{31}, 0 \le K \le 10000\)

首先这个K是唬人的,因为数位之和不超过90。。。那么就可以把余数计入状态,\(dp[i][m1][m2]\)分别表示前\(i\)位,数位之和\(\%K\)\(m1\)且数\(\%K\)\(m2\)的数量,枚举当前位0~9转移即可。

点击查看代码
void init()
{
    for (int i = 1; i <= MAXN; ++i)
        for (int j = 0; j < 10; ++j)
            for (int x = 0; x < K; ++x)
                for (int y = 0; y < K; ++y)
                    dp[i][j][x][y] = 0;
    for (int i = 0; i < 10; ++i)
        dp[1][i][i % K][i % K] = 1;
    for (int i = 2, v = 1; i <= MAXN; ++i)
    {
        v = (v * 10) % K;
        for (int j = 0; j < 10; ++j)
        {
            for (int k = 0; k < 10; ++k)
                for (int x = 0; x < K; ++x)
                    for (int y = 0; y < K; ++y)
                    {
                        dp[i][j][(x + v * j) % K][(y + j) % K] += dp[i - 1][k][x][y];
                    }
        }
    }
}

int solve(int n)
{
    if (n == 0)
        return 1;
    std::vector<int> digit;
    for (int x = n; x; x /= 10)
        digit.push_back(x % 10);
    int res = 0;
    int v = 1;
    for (int i = 1; i < digit.size(); ++i)
        v *= 10;
    for (int i = digit.size() - 1; i >= 1; --i)
        for (int j = 1; j < 10; ++j)
            res += dp[i][j][0][0];
    res++;
    int m[2] = {0, 0};
    for (int i = digit.size() - 1; i >= 0; --i, v /= 10)
    {
        // printf("%d\n", v);
        for (int j = (i == digit.size() - 1); j < digit[i]; ++j)
            res += dp[i + 1][j][(K - m[0]) % K][(K - m[1]) % K];
        m[0] = (m[0] + 1ll * digit[i] * v % K) % K;
        m[1] = (m[1] + digit[i]) % K;
    }
    if (m[0] == 0 && m[1] == 0)
        res++;
    return res;
}

Logan and DIGIT IMMUNE numbers

在区间\([A,B]\)里面找到第\(K\)个“仅由奇数组成,且不能被任意一位数整除的”数。\(1 \le A,B < 10^{18}, 1 \le K \le 10^{18}\)

这里要求第\(K\)大了,好像跟普通的数位DP问题不太一样。不过可以通过二分答案来转化成原问题。即找到第一个\(m\),区间\([A,m]\)中恰好有\(K\)个满足条件的数。
注意到问题要求不能被任意一位数整除,那么还得记录一下当前出现过哪些数,以及它们对应的余数。一共有\(4\)个可能出现的数,\(3,5,7,9\),用状态压缩记录一下,然后剩下四个状态表示余数就可以了。

点击查看代码
void init()
{
    dp[0][0][0][0][0][0] = 1;
    dp[1][1][0][3 % 5][3 % 7][3 % 9] = 1;
    dp[1][2][5 % 3][0][5 % 7][5 % 9] = 1;
    dp[1][4][7 % 3][7 % 5][0][7 % 9] = 1;
    dp[1][8][9 % 3][9 % 5][9 % 7][0] = 1;
    ll v = 1;
    /*for (int i = 2; i <= 2; ++i)
    {
        v = 10;
        for (int mask = 5; mask < 6; ++mask)
        {
            for (int a = 0; a <= 0; ++a)
                for (int b = 3; b <= 3; ++b)
                    for (int c = 3; c <= 3; ++c)
                        for (int d = 3; d <= 3; ++d)
                            if (mask & 1)
                                dp[i][mask][(v * 7 + a) % 3][(v * 7 + b) % 5][(v * 7 + c) % 7][(v * 7 + d) % 9] = dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 4][a][b][c][d];
        }
    }*/
    v = 1;
    for (int i = 2; i <= MAXN - 1; ++i)
    {
        v *= 10;
        for (int mask = 1; mask < 16; ++mask)
        {
            for (int a = 0; a < 3; ++a)
                for (int b = 0; b < 5; ++b)
                    for (int c = 0; c < 7; ++c)
                        for (int d = 0; d < 9; ++d)
                        {
                            if (mask & 1)
                                dp[i][mask][a][(v * 3 + b) % 5][(v * 3 + c) % 7][(v * 3 + d) % 9] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 1][a][b][c][d];
                            if (mask & 2)
                                dp[i][mask][(v * 5 + a) % 3][b][(v * 5 + c) % 7][(v * 5 + d) % 9] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 2][a][b][c][d];
                            if (mask & 4)
                                dp[i][mask][(v * 7 + a) % 3][(v * 7 + b) % 5][c][(v * 7 + d) % 9] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 4][a][b][c][d];
                            if (mask & 8)
                                dp[i][mask][(v * 9 + a) % 3][(v * 9 + b) % 5][(v * 9 + c) % 7][d] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 8][a][b][c][d];
                        }
        }
    }
}

ll solve(ll n)
{
    if (n < 10)
        return 0;

    std::vector<int> digit;
    for (ll x = n; x; x /= 10)
        digit.push_back(x % 10);
    ll res = 0;
    ll v = 1;
    for (int i = digit.size() - 1; i >= 1; --i)
    {
        v *= 10;
        for (int mask = 1; mask < 16; ++mask)
        {
            for (int a = (mask & 1) == 1; a < 3; ++a)
                for (int b = (mask & 2) == 2; b < 5; ++b)
                    for (int c = (mask & 4) == 4; c < 7; ++c)
                        for (int d = (mask & 8) == 8; d < 9; ++d)
                            res += dp[i][mask][a][b][c][d];
        }
    }
    int last = 0;
    ll num = 0;
    int d[4] = {3, 5, 7, 9};
    for (int i = digit.size() - 1; i >= 0; --i)
    {
        // printf("%lld\n", res);
        for (int k = 0; k < 4; ++k)
        {
            if (d[k] >= digit[i])
                break;
            ll tmp = num + 1ll * d[k] * v;
            int s0 = last | (1 << k);
            for (int s = (i != 0); s < 16; ++s)
            {
                int mask = s | s0;
                for (int a = (mask & 1) == 1; a < 3; ++a)
                    for (int b = (mask & 2) == 2; b < 5; ++b)
                        for (int c = (mask & 4) == 4; c < 7; ++c)
                            for (int d = (mask & 8) == 8; d < 9; ++d)
                            {
                                int na = (a + 3 - (tmp % 3)) % 3, nb = (b + 5 - (tmp % 5)) % 5;
                                int nc = (c + 7 - (tmp % 7)) % 7, nd = (d + 9 - (tmp % 9)) % 9;
                                res += dp[i][s][na][nb][nc][nd];
                            }
            }
        }
        if (digit[i] != 3 && digit[i] != 7 && digit[i] != 9 && digit[i] != 5)
            return res;
        if (digit[i] == 3)
            last |= 1;
        if (digit[i] == 5)
            last |= 2;
        if (digit[i] == 7)
            last |= 4;
        if (digit[i] == 9)
            last |= 8;
        num += 1ll * digit[i] * v;
        v /= 10;
    }
    // printf("solve %lld = %lld\n", n, res);
    for (int k = 0; k < 4; ++k)
    {
        if (last & (1 << k))
            if (num % d[k] == 0)
                return res;
    }
    res++;
    return res;
}

void Main()
{
    ll l, r, k;
    read(l, r, k);
    ll base = solve(l - 1);
    k += base;
    ll mx = solve(r);
    if (mx < k)
    {
        // printf("solve(%d) = %lld\n", r, mx);
        printf("-1\n");
        return;
    }
    while (l <= r)
    {
        ll mid = (l + r) >> 1;
        ll res = solve(mid);
        // printf("solve(%d) = %lld\n", mid, res);
        if (res >= k)
            r = mid - 1;
        else
            l = mid + 1;
    }
    printf("%lld\n", l);
}

Count Pairs: Used

\([1,n]\)中统计数对\((x,y)\)的数量。\((x,y)\)需要满足\(x<y\)\(S(x) <S(y)\)\(S(x)\)代表\(x\)的数位之和
\(n \le 10^{250}\)

这道题要统计的是数对,那么每一位既要枚举\(x\),也要枚举\(y\),并且还得保证\(x<y\)。因为这里要保证数的大小,用记忆化搜索从高位向低位查询是比较方便的做法。同时还要维护一下两个数的数位之和。这样的话复杂度就是\(O(250*2500*2500*2)\),有点高了。注意到我们并不想要知道两个数具体数位和是多少,而是想知道大小关系,那么维护一个差值就好了。因此dp状态就是\(dp[n][diff][flag]\),表示第\(n\)位,当前\(S(y)-S(x)\)\(diff\)\(flag\)代表更高位是否已经满足了\(x<y\),这样的数对数量。枚举的时候根据flag调整\(x\)在这一位上的枚举上限。
进一步地,由于这里只需要求解一次,那么记忆化搜索中的limit也可以加入到dp状态中。

点击查看代码
int solve(int n, int diff, bool less, bool limit)
{
    if (n == 0 && diff > MAXM)
        return 1;
    else if (n == 0)
        return 0;
    if (dp[n][diff][less][limit] != -1)
        return dp[n][diff][less][limit];
    int &res = dp[n][diff][less][limit];
    res = 0;
    for (int y = limit ? digit[n - 1] : 9; y >= 0; --y)
    {
        for (int x = less ? 9 : y; x >= 0; --x)
        {
            res += solve(n - 1, diff + (y - x), less | (x < y), limit & (y == digit[n - 1]));
            res %= mod;
        }
    }
    return res;
}

P2657 [SCOI2009] windy 数

标准的数位DP例题,状态里记录一下每一位的数是啥,然后转移就行了。代码不放了...

P3413 SAC#1 - 萌数

题目中要求存在长度为2的回文串的数。不如直接求不存在回文串的数,用总量一减就行了。那么第n位的数不能跟第n-1位相同,也不能跟第n-2位相同,这样就确保没有回文串。DP状态里记录一下前两位的数就行了。

P4127 [AHOI2009]同类分布

枚举数字的和sum,问题就变成有多少数字和为sum且能被sum整除的数,变成一个背包问题。

P2602 [ZJOI2010] 数字计数

标准数位DP例题,枚举每一位的决策,将下一位得到的方案数累加在该决策的答案里。

P4317 花神的数论题

把问题转化一下。假设出现了\(k\)\(1\)的数有\(sum[k]\)个,那么最终答案就是\(\prod_{i=1}^{log_2{n}}i^{sum[i]}\)
那么就求\(sum[k]\),就是普通的数位DP了。

CF55D

Volodya 认为一个数字 \(x\) 是美丽的,当且仅当 \(x\in\mathbb{Z^+}\)并且对于 \(x\) 的每一个非零位上的数 \(y\),都有 \(y|x\)。你需要帮助他算出在区间 \([l,r]\) 中有多少个数是美丽的。\(1≤l≤r≤9×10^{18}\)

转换一下题意就是说,数\(x\)要被其数位上的最小公倍数整除,\(lcm\)最大不超过\(2520\)。一开始想到了状压,确定了数字出现的状态之后就可以直接求出\(lcm\)。然后发现可能的\(lcm\)不超过\(50\)个,好像是\(49\)个,那么枚举这些\(lcm\),对每次枚举都做一次数位dp。注意到\(lcm\)的值很大,但是个数很小,可以以离散化的形式设计状态。假设枚举了一个\(lcm\)\(dp[n][cur][rem]\)表示当前n位数的最小公倍数是\(cur\),对\(lcm\)取余值为\(rem\)的数字个数。这个\(cur\)可以以离散之后的值代替。

点击查看代码
ll solve(int n, int rem, int cur, int LCM, bool zero, bool limit, const std::vector<int> &digit)
{
    if (n == 0 && rem == 0 && cur == LCM)
        return 1;
    if (n == 0)
        return 0;
    if (!zero && !limit && dp[n][mp[cur]][rem] != -1)
        return dp[n][mp[cur]][rem];
    ll res = 0;
    int mx = limit ? digit[n - 1] : 9;
    res += solve(n - 1, rem, cur, LCM, zero, limit && (0 == digit[n - 1]), digit);
    for (int i = 1; i <= mx; ++i)
    {
        if (LCM % i == 0)
            res += solve(n - 1, (rem + (pw[n - 1] * i % LCM)) % LCM, std::lcm(cur, i), LCM, false, limit & (i == digit[n - 1]), digit);
    }
    if (!zero && !limit)
        dp[n][mp[cur]][rem] = res;
    return res;
}

CF628D

给你 \(4\) 个数 \(m,d,l,r\)保证 \(l,r\)位数相同。问满足以下条件的数 \(x\) 的个数:

  1. \(l \leq x\leq r\)
  2. \(x\) 的偶数位是 \(d\),奇数位不是 \(d\)。 (这里定义偶数位为从高位往低位的数的偶数位)
  3. \(m|x\)

答案对 \(1000000007\) 取模。
\(1\le m \le 2000,0\le d \le 9,1\le l \le r \le 10^{2000}\)

\(dp[n][rem]\)表示\(n\)位,对\(m\)取余为\(rem\)的数字个数。值得注意的是这里奇偶位从高到低,所以记忆化搜索比较好写,而且要额外判断一个前导0的状态,因为前导0会影响高位的奇偶。

点击查看代码
int solve(int n, int rem, int dep, bool zero, bool limit, int digit[])
{
    if (n == 0 && rem == 0)
        return 1;
    if (n == 0)
        return 0;
    if (!limit && !zero && dp[n][rem][dep] != -1)
        return dp[n][rem][dep];
    int mx = limit ? digit[n - 1] : 9;
    int res = 0;
    if (dep)
    {
        for (int i = 0; i <= mx; ++i)
        {
            if (d == 0)
            {
                if (!zero && i == d)
                    continue;
            }
            else if (d == i)
                continue;
            int ndep = (zero && (i == 0)) ? 1 : 0;
            res += solve(n - 1, (rem + 1ll * i * pw[n - 1] % m) % m, ndep, zero && (i == 0), limit && (i == digit[n - 1]), digit);
            res %= mod;
        }
    }
    else
    {
        if (d <= mx)
            res = solve(n - 1, (rem + 1ll * d * pw[n - 1] % m) % m, 1, zero && (d == 0), limit && (d == digit[n - 1]), digit);
    }
    if (!zero && !limit)
        dp[n][rem][dep] = res;
    return res;
}

CF1073E

给定\(K,L,R\),求\(L\)~\(R\)之间最多不包含超过\(K\)个数码的数的和。
\(K\le10,L,R\le 1^{18}\)

求所有数的和不太好计算。不如单独考虑每一位的贡献。假设第n位填i有k个数,它对答案的贡献就是\(k * i * 10^{n-1}\),那么不单要对数字的和做dp,还要dp一下数字的数量。
\(dp[n][mask][0]\)表示前n位,目前出现的数字状态是\(mask\)的数字数量,以\(dp[n][mask][1]\)表示数字的和。每枚举到一个数码i,\(dp[n][mask][1] = i * 10^{n-1} * dp[n-1][nmask][0] + dp[n-1][nmask][1]\)

点击查看代码
std::pair<int, int> solve(int n, int mask, int target, bool zero, bool limit, int digit[])
{
    if (n == 0 && mask == target)
        return std::make_pair(1, 0);
    if (n == 0)
        return std::make_pair(0, 0);
 
    if (!zero && !limit && dp[n][mask][0] != -1)
        return std::make_pair(dp[n][mask][0], dp[n][mask][1]);
 
    std::pair<int, int> res;
    int mx = limit ? digit[n] : 9;
    for (int i = 0; i <= mx; ++i)
    {
        if ((target & (1 << i)) || (zero && (i == 0)))
        {
            int nmask = mask | (1 << i);
            if (i == 0 && zero)
                nmask = mask;
            std::pair<int, int> ret = solve(n - 1, nmask, target, zero && (i == 0), limit && (i == digit[n]), digit);
            res.first = (res.first + ret.first) % mod;
            res.second += ((1ll * ret.first * pw[n - 1] % mod) * i % mod + ret.second) % mod;
            res.second %= mod;
        }
    }
 
    if (!limit && !zero)
        dp[n][mask][0] = res.first, dp[n][mask][1] = res.second;
    return res;
}

CF1710C

给你一个数 \(n\),问:有多少对数 \(0\leq a,b,c \leq n\)满足\(a \oplus b,b \oplus c,a \oplus c\)。三个数字构成了一个非退化三角形,也就是两条短边之和大于第三边的长度。

\(f = a \oplus b,g = b \oplus c,h = a \oplus c\),分别用三个bit表示是否已经满足\(f+g>h, f+h>g, g+h>f\),因为这里出现了比较大小,更适合记搜去写。然而这里出现了一个问题,就是\(f+g\)产生的进位有可能会影响到高位的大小比较结果,这种情况出现在\(f+g\)末尾有一串1的情况。然而用真值表发现,\(f,g,h\)要么都是0,要么有两个1一个0,那么\(f+g\)在本位总是0,即便低一有进位1,也不会传递到更高位。对于其他两种组合同理。那么就安心记搜就好了。

点击查看代码
int solve(int n, bool alimit, bool blimit, bool climit, bool tag1, bool tag2, bool tag3)
{
    if (n == 0 && tag1 && tag2 && tag3)
        return 1;
    if (n == 0)
        return 0;
    int mask1 = (alimit << 2) | (blimit << 1) | climit, mask2 = (tag1 << 2) | (tag2 << 1) | tag3;
    if (dp[n][mask1][mask2] != -1)
        return dp[n][mask1][mask2];
 
    int &res = dp[n][mask1][mask2];
    res = 0;
    int mxa = alimit ? digit[n] : 1, mxb = blimit ? digit[n] : 1, mxc = climit ? digit[n] : 1;
    for (int a = 0; a <= mxa; ++a)
        for (int b = 0; b <= mxb; ++b)
            for (int c = 0; c <= mxc; ++c)
            {
                res += solve(n - 1, alimit && (a == digit[n]), blimit && (b == digit[n]), climit && (c == digit[n]),
                             tag1 | ((a ^ b) + (b ^ c) > (a ^ c)), tag2 | ((a ^ b) + (a ^ c) > (b ^ c)),
                             tag3 | ((a ^ c) + (b ^ c) > (a ^ b)));
                if (res > mod)
                    res -= mod;
            }
    return res;
}

CF1036C

定义一个数字是“好数”,当且仅当它的十进制表示下有不超过\(3\)\(1 \sim 9\)
给定\([l,r]\),问有多少个\(x\)使得\(l \le x \le r\)\(x\)是“好数”
\(1 \le l \le r \le 10^{18}\)

太裸了,注意前导0就行了。

CF1734F

S 是一个 Thue-Morse序列。它是一个由以下方式生成的无限长 01 字符串:
最初,令 \(S\) 为 "0"。
随后进行以下操作无穷多次:将 \(S\) 与各位取反后的 \(S\) 连接。以前 4 次操作为例:
次数 \(S\) 取反后的 \(S\) 操作后得到的\(S\)
1 01
2 0110
3 01101001
4 0110100110010110
给定 2 个正整数 \(n\)\(m\),求 \(S_0S_1 \cdots S_{m-1}\)\(S_{n}S_{n+1} \cdots S_{n+m-1}\)有几位不同。
\(1≤n,m≤10^{18}\)

一个很巧妙的题!这个字符串每扩充\(2^i\)的长度,就会翻转一次。也就是说,对一个下标\(x\),每在这个\(x\)的二进制串的最前面加一个1,字符就翻转一次。下标二进制串里面1的数量的奇偶性每变化一次都会翻转一下字符。于是题目就变成了在\([1,m]\)里有多少数\(x\)\(x+n\)\(x\)里面1的数量的奇偶性是不一样的。
因为这里面是比较\(x\)\(x+n\),加法会产生进位,所以适合递推法来写。
\(dp[n][x][odd][p]\)表示前n位,当前填\(x\)\(1\)的奇偶性为\(odd\),后面的进位为\(p\)的数量。枚举一下\(n-1\)的决策就可以转移了。

点击查看代码
void init()
{
    memset(dp, 0, sizeof dp);
    dp[1][0][0][0] = ((n & 1) == 0);
    dp[1][1][0][0] = ((n & 1) == 0);
    dp[1][0][1][0] = (n & 1);
    dp[1][1][1][0] = (n & 1);
    for (int i = 2; i <= 60; ++i)
    {
        int bit = (n & (1ll << (i - 1))) != 0;
        int pbit = (n & (1ll << (i - 2))) != 0;
        for (int x = 0; x < 2; ++x)
        {
            for (int y = 0; y < 2; ++y)
            {
                for (int p = 0; p < 2; ++p)
                {
                    int pre = y + pbit + p;
                    int cur = x + bit + (pre > 1);
                    for (int odd = 0; odd < 2; ++odd)
                        dp[i][x][odd][pre > 1] += dp[i - 1][y][(cur & 1) ^ x ^ odd][p];
                }
            }
        }
    }
}
 
ll solve(ll m)
{
    if (m == 0)
        return __builtin_popcountll(0) % 2 != __builtin_popcountll(n) % 2;
 
    std::vector<int> digit;
    for (int i = 0; i < 60; ++i)
        digit.push_back((m & (1ll << i)) != 0);
 
    ll res = 0;
    ll last = 0, pren = 0;
    for (int i = 59; i >= 0; --i)
    {
        int nbit = (n & (1ll << i)) != 0;
        if (digit[i] == 1)
        {
            for (int p = 0; p < 2; ++p)
            {
                ll pre = last + pren;
 
                if (p + nbit > 1)
                    pre += (1ll << (i + 1)); 
                bool odd = (__builtin_popcountll(pre) % 2) ^ (__builtin_popcountll(last) % 2);
                res += dp[i + 1][0][odd ^ 1][p];
            }
            last += (1ll << i);
        }
        if (nbit)
            pren += (1ll << i);
    }
    if (__builtin_popcountll(m) % 2 != __builtin_popcountll(m + n) % 2)
        res++;
    return res;
}

CF1290F

\(n\)个向量组成一个凸包,让这些凸包落在\(m*m\)的正方形里,有多少种方案. \(n \le 5, m \le 10^9\)

要保证是凸包,可以将向量按逆时针排列,这样就会发现,假设向量\(v\)在凸包里出现了\(k\)次,那么它一定是连续的\(k\)次。因此如果\(k_i * v_i\)这些向量可以构成一个凸包,这个凸包就是唯一的。换句话说要求的是\(n\)元组\(k_i\)的数量,使得\(k_i * v_i\)构成一个凸包。

首先构成凸包一定要闭合。那么\(\sum_{i=1}^{n}k_i * v_i = (0,0)\)。把横纵坐标单独拿出来就是\(\sum_{i=1}^{n}k_i * x_i = 0, \sum_{i=1}^{n}k_i * y_i = 0\)。再将\(x_i\)里面的正负拆出来,就是\(\sum_{i=1}^{n}k_i * x_i *(x_i \ge 0) = \sum_{i=1}^{n}k_i*(-x_i)*(x_i<0)\),对于\(y\)同理。并且这个凸包还要落在\(m*m\)的矩阵里,那么\(\sum_{i=1}^{n}k_i * x_i *(x_i \ge 0) \le m\), 对于\(y\)同理。

接下来就是数位DP,按位枚举\(n\)元组的每一位。注意到这里面我们可以选择任意进制来数位DP,进制越小,位数越多,但是每一位的决策越少。枚举\(n\)元组每一位,\(\sum_{i=1}^{n}k_i * x_i*(x_i \ge 0)\)会产生向上的进位,其他同理,因此状态里还要记录四个进位(\(x \ge 0, x < 0, y \ge 0, y < 0\)四种情况),并且时刻维护着小于\(m\)的限制。\(dp[n][px][py][nx][ny][fx][fy]\)分别表示前n位,正的\(x\)进位为px,正的\(y\)进位为py,负进位为\(nx\)\(ny\)\(fx\)表示当前正\(x\)的和是否小于\(m\)\(fy\)表示当前正\(y\)的和是否小于\(m\)。暴力转移即可。
如果选择2进制来做数位DP,时间复杂度是\(O(log_2{m}*20^4*4*2^n)\)

点击查看代码
void Main()
{
    int n, m;
    read(n, m);
    for (int i = 1; i <= n; ++i)
        read(x[i], y[i]);
    dp[0][0][0][0][0][0][0] = 1;
    int s = (1 << n) - 1;
    for (int i = 0; i <= s; ++i)
        for (int j = 1; j <= n; ++j)
            if (i & (1 << (j - 1)))
                (x[j] > 0 ? px : nx)[i] += std::abs(x[j]), (y[j] > 0 ? py : ny)[i] += std::abs(y[j]);
    for (int i = 0; i <= 30; ++i)
        for (int a = 0; a <= px[s]; ++a)
            for (int b = 0; b <= py[s]; ++b)
                for (int c = 0; c <= nx[s]; ++c)
                    for (int d = 0; d <= ny[s]; ++d)
                        for (int fx = 0; fx < 2; ++fx)
                            for (int fy = 0; fy < 2; ++fy)
                            {
                                if (dp[i][a][b][c][d][fx][fy] == 0)
                                    continue;
                                for (int mask = 0; mask <= s; ++mask)
                                {
                                    int A = a + px[mask], B = b + py[mask], C = c + nx[mask], D = d + ny[mask];
                                    if ((A & 1) == (C & 1) && (B & 1) == (D & 1))
                                    {
                                        int bit = (m >> i) & 1;
                                        int FX = (A & 1) == bit ? fx : ((A & 1) > bit);
                                        int FY = (B & 1) == bit ? fy : ((B & 1) > bit);
                                        dp[i + 1][A >> 1][B >> 1][C >> 1][D >> 1][FX][FY] += dp[i][a][b][c][d][fx][fy];
                                    }
                                }
                            }
    modint998244353 ans = dp[31][0][0][0][0][0][0] - 1;
    printf("%d\n", ans);
}

CCPC2022广州M

下班前偶然发现的一道题,看了看题然后骑车回家的时候脑子里口胡了一下。
首先式子是一个异或的式子,不难想到按位算贡献。
\(i\)位如果有\(x\)个1,对value的贡献就是\(2^i * x * (k -x)\)。那么做数位DP的时候枚举每一位有多少个1,把进位记录在状态里,并且跟\(n\)匹配一下,不合法的可以直接剔除掉。
接下来考虑\(a_i \le m\)的限制。一个显然的结论是用一个\(bitmask\)记录哪些数在当前位之前小于\(m\),但是这样复杂度过高了(\(2^k\))。不像上一道题五元组里每个数对应的向量是不一样的,这道题的\(a_i\)之间是没有什么差别的,那么可以把这个\(bitmask\)浓缩成一个状态——当前有\(y\)个数大于\(m\)。以\(dp[i][j][y]\)表示前\(i\)位,当前进位是\(j\), 有\(y\)个数大于\(m\)。枚举这一位填的1的数量\(x\),那么这\(x\)个数可以分给当前大于\(m\)的数一部分,另一部分分给小于等于\(m\)数的一部分,然后重新计算出新的状态里有多少个数大于\(m\),再用组合数乘一下算出转移。

点击查看代码
void Main()
{
    C[0][0] = 1;
    for (int i = 1; i <= 20; ++i)
    {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j)
            C[i][j] = C[i - 1][j - 1] + C[i - 1][j];
    }
 
    ll n, m, k;
    read(n, m, k);
 
    dp[0][0][0] = 1;
    for (int i = 0; i <= 50; ++i)
    {
        int nbit = ((n & (1ll << i)) != 0);
        int mbit = ((m & (1ll << i)) != 0);
        // printf("bit:%d, n:%d, m:%d\n", i + 1, nbit, mbit);
        for (int j = 0; j <= 100; ++j)
            for (int x = 0; x <= k; ++x)
            {
                if (dp[i][j][x] == 0)
                    continue;
                // printf("dp[%d][%d][%d] = %d\n", i, j, x, dp[i][j][x]);
                for (int y = 0; y <= k; ++y)
                {
                    int cur = y * (k - y) + j;
                    if ((cur & 1) != nbit)
                        continue;
                    int mxz = std::min(y, x);
                    for (int z = 0; z <= mxz; ++z)
                    {
                        if (mbit == 1)
                            dp[i + 1][cur >> 1][z] += C[x][z] * C[k - x][y - z] * dp[i][j][x];
                        else
                            dp[i + 1][cur >> 1][x + y - z] += C[x][z] * C[k - x][y - z] * dp[i][j][x];
                    }
                }
            }
    }
    printf("%d\n", dp[51][0][0]);
}
posted @ 2022-11-16 19:26  KSYImba  阅读(55)  评论(2)    收藏  举报