数位DP题目瞎做

普通的数位DP,题目形式大概是——

在区间[L,R]里,统计满足条件的数的个数,并且暴力统计会严重超时

这样就提醒我们在每一位去枚举这一位数字的决策,递推统计方案。对于区间[L,R]的查询,可以转化为对区间[1,R][1,L1]的查询——>给出一个数x,查询小于等于x且满足条件的数的个数。

在解决这个问题的时候,会将数x按一棵树去处理。假设数xn位,分别是Xn,Xn1,...,X1。这个处理过程如下图:

image

对于数的第p位,这一位左子树我们统计的是XnXn1...Xp+100000...0XnXn1...Xp+1(Xp1)9999...9里面符合条件的数。也就是1p1位的数都可以放飞自我随意取值。那么左子树的值就可以通过前p1位的递推值以及XnXn1...Xp+1这个前缀综合起来求解,最终我们留下的就是一个从根到最右端节点的链还没有统计过,也就是x本身,统计x本身显然是一个非常简单的问题。需要注意的是第一棵左子树[0,Xn1],里面会包含前导0,有的题目不允许前导0,需要特殊判断。

一般来说,数位DP有两种写法,递推预处理法和记忆化搜索法。前者可以认为是自低位向高位的顺推,后者可以认为是自高位向低位的逆推。

顺推法会提前处理好所有情况下的dp状态值,接下来根据查询的数,在左子树里面按需取值。更适合处理低位决策会影响高位决策的情况,例如状态中会记录低位产生的进位。

逆推法更适合处理高位会影响低位决策的情况,例如状态中会出现数字的大小关系(高位如果小了,低位不会改变大小关系)。整个记忆化搜索的过程就像上图这棵树的过程一样,非常直观。

一般记忆化搜索的模板如下:

点击查看代码
ll solve(int n, state..., bool zero, bool limit, const std::vector<int> &digit)
{
    if (n == 0 && 满足条件)
        return 1;
    if (n == 0)
        return 0;
    if (!zero && !limit && dp[n][state] != -1)
        return dp[n][state];
    ll res = 0;
    int mx = limit ? digit[n - 1] : 9;
    for (int i = 0; i <= mx; ++i)
        res += solve(n - 1, newstate, zero&&(i==0), limit & (i == digit[n - 1]), digit);
    if (!zero && !limit)
        dp[n][state] = res;
    return res;
}

zero表示当前是不是前导0,limit表示当前是不是在树的最右端的链上,如果在最右端的链上,那么下一位的取值就不能超过x本身在下一位的值。一旦进入某颗左子树,limit就会变成false。在zero和limit是true的情况下,不能更新dp的值。因为如果题目查询的是区间[L,R],我们既要对L1求解,也要对R求解,那么两次求解过程中,zero和limit的值跟当前求解的x相关,属于特殊数据,不能计入通用的dp值。当然,假如题目只要求求解一组[1,n]的话,这两个状态就可以计入dp。

具体的结合题目来分析。

题目

LightOJ-1068-Investigation

给出区间[A,B],和一个数K,统计区间中能被K整除且数位之和也能被K整除的数。1A,B<231,0K10000

首先这个K是唬人的,因为数位之和不超过90。。。那么就可以把余数计入状态,dp[i][m1][m2]分别表示前i位,数位之和%Km1且数%Km2的数量,枚举当前位0~9转移即可。

点击查看代码
void init()
{
    for (int i = 1; i <= MAXN; ++i)
        for (int j = 0; j < 10; ++j)
            for (int x = 0; x < K; ++x)
                for (int y = 0; y < K; ++y)
                    dp[i][j][x][y] = 0;
    for (int i = 0; i < 10; ++i)
        dp[1][i][i % K][i % K] = 1;
    for (int i = 2, v = 1; i <= MAXN; ++i)
    {
        v = (v * 10) % K;
        for (int j = 0; j < 10; ++j)
        {
            for (int k = 0; k < 10; ++k)
                for (int x = 0; x < K; ++x)
                    for (int y = 0; y < K; ++y)
                    {
                        dp[i][j][(x + v * j) % K][(y + j) % K] += dp[i - 1][k][x][y];
                    }
        }
    }
}

int solve(int n)
{
    if (n == 0)
        return 1;
    std::vector<int> digit;
    for (int x = n; x; x /= 10)
        digit.push_back(x % 10);
    int res = 0;
    int v = 1;
    for (int i = 1; i < digit.size(); ++i)
        v *= 10;
    for (int i = digit.size() - 1; i >= 1; --i)
        for (int j = 1; j < 10; ++j)
            res += dp[i][j][0][0];
    res++;
    int m[2] = {0, 0};
    for (int i = digit.size() - 1; i >= 0; --i, v /= 10)
    {
        // printf("%d\n", v);
        for (int j = (i == digit.size() - 1); j < digit[i]; ++j)
            res += dp[i + 1][j][(K - m[0]) % K][(K - m[1]) % K];
        m[0] = (m[0] + 1ll * digit[i] * v % K) % K;
        m[1] = (m[1] + digit[i]) % K;
    }
    if (m[0] == 0 && m[1] == 0)
        res++;
    return res;
}

Logan and DIGIT IMMUNE numbers

在区间[A,B]里面找到第K个“仅由奇数组成,且不能被任意一位数整除的”数。1A,B<1018,1K1018

这里要求第K大了,好像跟普通的数位DP问题不太一样。不过可以通过二分答案来转化成原问题。即找到第一个m,区间[A,m]中恰好有K个满足条件的数。
注意到问题要求不能被任意一位数整除,那么还得记录一下当前出现过哪些数,以及它们对应的余数。一共有4个可能出现的数,3,5,7,9,用状态压缩记录一下,然后剩下四个状态表示余数就可以了。

点击查看代码
void init()
{
    dp[0][0][0][0][0][0] = 1;
    dp[1][1][0][3 % 5][3 % 7][3 % 9] = 1;
    dp[1][2][5 % 3][0][5 % 7][5 % 9] = 1;
    dp[1][4][7 % 3][7 % 5][0][7 % 9] = 1;
    dp[1][8][9 % 3][9 % 5][9 % 7][0] = 1;
    ll v = 1;
    /*for (int i = 2; i <= 2; ++i)
    {
        v = 10;
        for (int mask = 5; mask < 6; ++mask)
        {
            for (int a = 0; a <= 0; ++a)
                for (int b = 3; b <= 3; ++b)
                    for (int c = 3; c <= 3; ++c)
                        for (int d = 3; d <= 3; ++d)
                            if (mask & 1)
                                dp[i][mask][(v * 7 + a) % 3][(v * 7 + b) % 5][(v * 7 + c) % 7][(v * 7 + d) % 9] = dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 4][a][b][c][d];
        }
    }*/
    v = 1;
    for (int i = 2; i <= MAXN - 1; ++i)
    {
        v *= 10;
        for (int mask = 1; mask < 16; ++mask)
        {
            for (int a = 0; a < 3; ++a)
                for (int b = 0; b < 5; ++b)
                    for (int c = 0; c < 7; ++c)
                        for (int d = 0; d < 9; ++d)
                        {
                            if (mask & 1)
                                dp[i][mask][a][(v * 3 + b) % 5][(v * 3 + c) % 7][(v * 3 + d) % 9] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 1][a][b][c][d];
                            if (mask & 2)
                                dp[i][mask][(v * 5 + a) % 3][b][(v * 5 + c) % 7][(v * 5 + d) % 9] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 2][a][b][c][d];
                            if (mask & 4)
                                dp[i][mask][(v * 7 + a) % 3][(v * 7 + b) % 5][c][(v * 7 + d) % 9] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 4][a][b][c][d];
                            if (mask & 8)
                                dp[i][mask][(v * 9 + a) % 3][(v * 9 + b) % 5][(v * 9 + c) % 7][d] += dp[i - 1][mask][a][b][c][d] + dp[i - 1][mask ^ 8][a][b][c][d];
                        }
        }
    }
}

ll solve(ll n)
{
    if (n < 10)
        return 0;

    std::vector<int> digit;
    for (ll x = n; x; x /= 10)
        digit.push_back(x % 10);
    ll res = 0;
    ll v = 1;
    for (int i = digit.size() - 1; i >= 1; --i)
    {
        v *= 10;
        for (int mask = 1; mask < 16; ++mask)
        {
            for (int a = (mask & 1) == 1; a < 3; ++a)
                for (int b = (mask & 2) == 2; b < 5; ++b)
                    for (int c = (mask & 4) == 4; c < 7; ++c)
                        for (int d = (mask & 8) == 8; d < 9; ++d)
                            res += dp[i][mask][a][b][c][d];
        }
    }
    int last = 0;
    ll num = 0;
    int d[4] = {3, 5, 7, 9};
    for (int i = digit.size() - 1; i >= 0; --i)
    {
        // printf("%lld\n", res);
        for (int k = 0; k < 4; ++k)
        {
            if (d[k] >= digit[i])
                break;
            ll tmp = num + 1ll * d[k] * v;
            int s0 = last | (1 << k);
            for (int s = (i != 0); s < 16; ++s)
            {
                int mask = s | s0;
                for (int a = (mask & 1) == 1; a < 3; ++a)
                    for (int b = (mask & 2) == 2; b < 5; ++b)
                        for (int c = (mask & 4) == 4; c < 7; ++c)
                            for (int d = (mask & 8) == 8; d < 9; ++d)
                            {
                                int na = (a + 3 - (tmp % 3)) % 3, nb = (b + 5 - (tmp % 5)) % 5;
                                int nc = (c + 7 - (tmp % 7)) % 7, nd = (d + 9 - (tmp % 9)) % 9;
                                res += dp[i][s][na][nb][nc][nd];
                            }
            }
        }
        if (digit[i] != 3 && digit[i] != 7 && digit[i] != 9 && digit[i] != 5)
            return res;
        if (digit[i] == 3)
            last |= 1;
        if (digit[i] == 5)
            last |= 2;
        if (digit[i] == 7)
            last |= 4;
        if (digit[i] == 9)
            last |= 8;
        num += 1ll * digit[i] * v;
        v /= 10;
    }
    // printf("solve %lld = %lld\n", n, res);
    for (int k = 0; k < 4; ++k)
    {
        if (last & (1 << k))
            if (num % d[k] == 0)
                return res;
    }
    res++;
    return res;
}

void Main()
{
    ll l, r, k;
    read(l, r, k);
    ll base = solve(l - 1);
    k += base;
    ll mx = solve(r);
    if (mx < k)
    {
        // printf("solve(%d) = %lld\n", r, mx);
        printf("-1\n");
        return;
    }
    while (l <= r)
    {
        ll mid = (l + r) >> 1;
        ll res = solve(mid);
        // printf("solve(%d) = %lld\n", mid, res);
        if (res >= k)
            r = mid - 1;
        else
            l = mid + 1;
    }
    printf("%lld\n", l);
}

Count Pairs: Used

[1,n]中统计数对(x,y)的数量。(x,y)需要满足x<yS(x)<S(y)S(x)代表x的数位之和
n10250

这道题要统计的是数对,那么每一位既要枚举x,也要枚举y,并且还得保证x<y。因为这里要保证数的大小,用记忆化搜索从高位向低位查询是比较方便的做法。同时还要维护一下两个数的数位之和。这样的话复杂度就是O(250250025002),有点高了。注意到我们并不想要知道两个数具体数位和是多少,而是想知道大小关系,那么维护一个差值就好了。因此dp状态就是dp[n][diff][flag],表示第n位,当前S(y)S(x)diffflag代表更高位是否已经满足了x<y,这样的数对数量。枚举的时候根据flag调整x在这一位上的枚举上限。
进一步地,由于这里只需要求解一次,那么记忆化搜索中的limit也可以加入到dp状态中。

点击查看代码
int solve(int n, int diff, bool less, bool limit)
{
    if (n == 0 && diff > MAXM)
        return 1;
    else if (n == 0)
        return 0;
    if (dp[n][diff][less][limit] != -1)
        return dp[n][diff][less][limit];
    int &res = dp[n][diff][less][limit];
    res = 0;
    for (int y = limit ? digit[n - 1] : 9; y >= 0; --y)
    {
        for (int x = less ? 9 : y; x >= 0; --x)
        {
            res += solve(n - 1, diff + (y - x), less | (x < y), limit & (y == digit[n - 1]));
            res %= mod;
        }
    }
    return res;
}

P2657 [SCOI2009] windy 数

标准的数位DP例题,状态里记录一下每一位的数是啥,然后转移就行了。代码不放了...

P3413 SAC#1 - 萌数

题目中要求存在长度为2的回文串的数。不如直接求不存在回文串的数,用总量一减就行了。那么第n位的数不能跟第n-1位相同,也不能跟第n-2位相同,这样就确保没有回文串。DP状态里记录一下前两位的数就行了。

P4127 [AHOI2009]同类分布

枚举数字的和sum,问题就变成有多少数字和为sum且能被sum整除的数,变成一个背包问题。

P2602 [ZJOI2010] 数字计数

标准数位DP例题,枚举每一位的决策,将下一位得到的方案数累加在该决策的答案里。

P4317 花神的数论题

把问题转化一下。假设出现了k1的数有sum[k]个,那么最终答案就是i=1log2nisum[i]
那么就求sum[k],就是普通的数位DP了。

CF55D

Volodya 认为一个数字 x 是美丽的,当且仅当 xZ+并且对于 x 的每一个非零位上的数 y,都有 y|x。你需要帮助他算出在区间 [l,r] 中有多少个数是美丽的。1lr9×1018

转换一下题意就是说,数x要被其数位上的最小公倍数整除,lcm最大不超过2520。一开始想到了状压,确定了数字出现的状态之后就可以直接求出lcm。然后发现可能的lcm不超过50个,好像是49个,那么枚举这些lcm,对每次枚举都做一次数位dp。注意到lcm的值很大,但是个数很小,可以以离散化的形式设计状态。假设枚举了一个lcmdp[n][cur][rem]表示当前n位数的最小公倍数是cur,对lcm取余值为rem的数字个数。这个cur可以以离散之后的值代替。

点击查看代码
ll solve(int n, int rem, int cur, int LCM, bool zero, bool limit, const std::vector<int> &digit)
{
    if (n == 0 && rem == 0 && cur == LCM)
        return 1;
    if (n == 0)
        return 0;
    if (!zero && !limit && dp[n][mp[cur]][rem] != -1)
        return dp[n][mp[cur]][rem];
    ll res = 0;
    int mx = limit ? digit[n - 1] : 9;
    res += solve(n - 1, rem, cur, LCM, zero, limit && (0 == digit[n - 1]), digit);
    for (int i = 1; i <= mx; ++i)
    {
        if (LCM % i == 0)
            res += solve(n - 1, (rem + (pw[n - 1] * i % LCM)) % LCM, std::lcm(cur, i), LCM, false, limit & (i == digit[n - 1]), digit);
    }
    if (!zero && !limit)
        dp[n][mp[cur]][rem] = res;
    return res;
}

CF628D

给你 4 个数 m,d,l,r保证 l,r位数相同。问满足以下条件的数 x 的个数:

  1. lxr
  2. x 的偶数位是 d,奇数位不是 d。 (这里定义偶数位为从高位往低位的数的偶数位)
  3. m|x

答案对 1000000007 取模。
1m2000,0d9,1lr102000

dp[n][rem]表示n位,对m取余为rem的数字个数。值得注意的是这里奇偶位从高到低,所以记忆化搜索比较好写,而且要额外判断一个前导0的状态,因为前导0会影响高位的奇偶。

点击查看代码
int solve(int n, int rem, int dep, bool zero, bool limit, int digit[])
{
    if (n == 0 && rem == 0)
        return 1;
    if (n == 0)
        return 0;
    if (!limit && !zero && dp[n][rem][dep] != -1)
        return dp[n][rem][dep];
    int mx = limit ? digit[n - 1] : 9;
    int res = 0;
    if (dep)
    {
        for (int i = 0; i <= mx; ++i)
        {
            if (d == 0)
            {
                if (!zero && i == d)
                    continue;
            }
            else if (d == i)
                continue;
            int ndep = (zero && (i == 0)) ? 1 : 0;
            res += solve(n - 1, (rem + 1ll * i * pw[n - 1] % m) % m, ndep, zero && (i == 0), limit && (i == digit[n - 1]), digit);
            res %= mod;
        }
    }
    else
    {
        if (d <= mx)
            res = solve(n - 1, (rem + 1ll * d * pw[n - 1] % m) % m, 1, zero && (d == 0), limit && (d == digit[n - 1]), digit);
    }
    if (!zero && !limit)
        dp[n][rem][dep] = res;
    return res;
}

CF1073E

给定K,L,R,求L~R之间最多不包含超过K个数码的数的和。
K10L,R118

求所有数的和不太好计算。不如单独考虑每一位的贡献。假设第n位填i有k个数,它对答案的贡献就是ki10n1,那么不单要对数字的和做dp,还要dp一下数字的数量。
dp[n][mask][0]表示前n位,目前出现的数字状态是mask的数字数量,以dp[n][mask][1]表示数字的和。每枚举到一个数码i,dp[n][mask][1]=i10n1dp[n1][nmask][0]+dp[n1][nmask][1]

点击查看代码
std::pair<int, int> solve(int n, int mask, int target, bool zero, bool limit, int digit[])
{
    if (n == 0 && mask == target)
        return std::make_pair(1, 0);
    if (n == 0)
        return std::make_pair(0, 0);
 
    if (!zero && !limit && dp[n][mask][0] != -1)
        return std::make_pair(dp[n][mask][0], dp[n][mask][1]);
 
    std::pair<int, int> res;
    int mx = limit ? digit[n] : 9;
    for (int i = 0; i <= mx; ++i)
    {
        if ((target & (1 << i)) || (zero && (i == 0)))
        {
            int nmask = mask | (1 << i);
            if (i == 0 && zero)
                nmask = mask;
            std::pair<int, int> ret = solve(n - 1, nmask, target, zero && (i == 0), limit && (i == digit[n]), digit);
            res.first = (res.first + ret.first) % mod;
            res.second += ((1ll * ret.first * pw[n - 1] % mod) * i % mod + ret.second) % mod;
            res.second %= mod;
        }
    }
 
    if (!limit && !zero)
        dp[n][mask][0] = res.first, dp[n][mask][1] = res.second;
    return res;
}

CF1710C

给你一个数 n,问:有多少对数 0a,b,cn满足ab,bc,ac。三个数字构成了一个非退化三角形,也就是两条短边之和大于第三边的长度。

f=ab,g=bc,h=ac,分别用三个bit表示是否已经满足f+g>h,f+h>g,g+h>f,因为这里出现了比较大小,更适合记搜去写。然而这里出现了一个问题,就是f+g产生的进位有可能会影响到高位的大小比较结果,这种情况出现在f+g末尾有一串1的情况。然而用真值表发现,f,g,h要么都是0,要么有两个1一个0,那么f+g在本位总是0,即便低一有进位1,也不会传递到更高位。对于其他两种组合同理。那么就安心记搜就好了。

点击查看代码
int solve(int n, bool alimit, bool blimit, bool climit, bool tag1, bool tag2, bool tag3)
{
    if (n == 0 && tag1 && tag2 && tag3)
        return 1;
    if (n == 0)
        return 0;
    int mask1 = (alimit << 2) | (blimit << 1) | climit, mask2 = (tag1 << 2) | (tag2 << 1) | tag3;
    if (dp[n][mask1][mask2] != -1)
        return dp[n][mask1][mask2];
 
    int &res = dp[n][mask1][mask2];
    res = 0;
    int mxa = alimit ? digit[n] : 1, mxb = blimit ? digit[n] : 1, mxc = climit ? digit[n] : 1;
    for (int a = 0; a <= mxa; ++a)
        for (int b = 0; b <= mxb; ++b)
            for (int c = 0; c <= mxc; ++c)
            {
                res += solve(n - 1, alimit && (a == digit[n]), blimit && (b == digit[n]), climit && (c == digit[n]),
                             tag1 | ((a ^ b) + (b ^ c) > (a ^ c)), tag2 | ((a ^ b) + (a ^ c) > (b ^ c)),
                             tag3 | ((a ^ c) + (b ^ c) > (a ^ b)));
                if (res > mod)
                    res -= mod;
            }
    return res;
}

CF1036C

定义一个数字是“好数”,当且仅当它的十进制表示下有不超过319
给定[l,r],问有多少个x使得lxrx是“好数”
1lr1018

太裸了,注意前导0就行了。

CF1734F

S 是一个 Thue-Morse序列。它是一个由以下方式生成的无限长 01 字符串:
最初,令 S 为 "0"。
随后进行以下操作无穷多次:将 S 与各位取反后的 S 连接。以前 4 次操作为例:
次数 S 取反后的 S 操作后得到的S
1 01
2 0110
3 01101001
4 0110100110010110
给定 2 个正整数 nm,求 S0S1Sm1SnSn+1Sn+m1有几位不同。
1n,m1018

一个很巧妙的题!这个字符串每扩充2i的长度,就会翻转一次。也就是说,对一个下标x,每在这个x的二进制串的最前面加一个1,字符就翻转一次。下标二进制串里面1的数量的奇偶性每变化一次都会翻转一下字符。于是题目就变成了在[1,m]里有多少数xx+nx里面1的数量的奇偶性是不一样的。
因为这里面是比较xx+n,加法会产生进位,所以适合递推法来写。
dp[n][x][odd][p]表示前n位,当前填x1的奇偶性为odd,后面的进位为p的数量。枚举一下n1的决策就可以转移了。

点击查看代码
void init()
{
    memset(dp, 0, sizeof dp);
    dp[1][0][0][0] = ((n & 1) == 0);
    dp[1][1][0][0] = ((n & 1) == 0);
    dp[1][0][1][0] = (n & 1);
    dp[1][1][1][0] = (n & 1);
    for (int i = 2; i <= 60; ++i)
    {
        int bit = (n & (1ll << (i - 1))) != 0;
        int pbit = (n & (1ll << (i - 2))) != 0;
        for (int x = 0; x < 2; ++x)
        {
            for (int y = 0; y < 2; ++y)
            {
                for (int p = 0; p < 2; ++p)
                {
                    int pre = y + pbit + p;
                    int cur = x + bit + (pre > 1);
                    for (int odd = 0; odd < 2; ++odd)
                        dp[i][x][odd][pre > 1] += dp[i - 1][y][(cur & 1) ^ x ^ odd][p];
                }
            }
        }
    }
}
 
ll solve(ll m)
{
    if (m == 0)
        return __builtin_popcountll(0) % 2 != __builtin_popcountll(n) % 2;
 
    std::vector<int> digit;
    for (int i = 0; i < 60; ++i)
        digit.push_back((m & (1ll << i)) != 0);
 
    ll res = 0;
    ll last = 0, pren = 0;
    for (int i = 59; i >= 0; --i)
    {
        int nbit = (n & (1ll << i)) != 0;
        if (digit[i] == 1)
        {
            for (int p = 0; p < 2; ++p)
            {
                ll pre = last + pren;
 
                if (p + nbit > 1)
                    pre += (1ll << (i + 1)); 
                bool odd = (__builtin_popcountll(pre) % 2) ^ (__builtin_popcountll(last) % 2);
                res += dp[i + 1][0][odd ^ 1][p];
            }
            last += (1ll << i);
        }
        if (nbit)
            pren += (1ll << i);
    }
    if (__builtin_popcountll(m) % 2 != __builtin_popcountll(m + n) % 2)
        res++;
    return res;
}

CF1290F

n个向量组成一个凸包,让这些凸包落在mm的正方形里,有多少种方案. n5,m109

要保证是凸包,可以将向量按逆时针排列,这样就会发现,假设向量v在凸包里出现了k次,那么它一定是连续的k次。因此如果kivi这些向量可以构成一个凸包,这个凸包就是唯一的。换句话说要求的是n元组ki的数量,使得kivi构成一个凸包。

首先构成凸包一定要闭合。那么i=1nkivi=(0,0)。把横纵坐标单独拿出来就是i=1nkixi=0,i=1nkiyi=0。再将xi里面的正负拆出来,就是i=1nkixi(xi0)=i=1nki(xi)(xi<0),对于y同理。并且这个凸包还要落在mm的矩阵里,那么i=1nkixi(xi0)m, 对于y同理。

接下来就是数位DP,按位枚举n元组的每一位。注意到这里面我们可以选择任意进制来数位DP,进制越小,位数越多,但是每一位的决策越少。枚举n元组每一位,i=1nkixi(xi0)会产生向上的进位,其他同理,因此状态里还要记录四个进位(x0,x<0,y0,y<0四种情况),并且时刻维护着小于m的限制。dp[n][px][py][nx][ny][fx][fy]分别表示前n位,正的x进位为px,正的y进位为py,负进位为nxnyfx表示当前正x的和是否小于mfy表示当前正y的和是否小于m。暴力转移即可。
如果选择2进制来做数位DP,时间复杂度是O(log2m20442n)

点击查看代码
void Main()
{
    int n, m;
    read(n, m);
    for (int i = 1; i <= n; ++i)
        read(x[i], y[i]);
    dp[0][0][0][0][0][0][0] = 1;
    int s = (1 << n) - 1;
    for (int i = 0; i <= s; ++i)
        for (int j = 1; j <= n; ++j)
            if (i & (1 << (j - 1)))
                (x[j] > 0 ? px : nx)[i] += std::abs(x[j]), (y[j] > 0 ? py : ny)[i] += std::abs(y[j]);
    for (int i = 0; i <= 30; ++i)
        for (int a = 0; a <= px[s]; ++a)
            for (int b = 0; b <= py[s]; ++b)
                for (int c = 0; c <= nx[s]; ++c)
                    for (int d = 0; d <= ny[s]; ++d)
                        for (int fx = 0; fx < 2; ++fx)
                            for (int fy = 0; fy < 2; ++fy)
                            {
                                if (dp[i][a][b][c][d][fx][fy] == 0)
                                    continue;
                                for (int mask = 0; mask <= s; ++mask)
                                {
                                    int A = a + px[mask], B = b + py[mask], C = c + nx[mask], D = d + ny[mask];
                                    if ((A & 1) == (C & 1) && (B & 1) == (D & 1))
                                    {
                                        int bit = (m >> i) & 1;
                                        int FX = (A & 1) == bit ? fx : ((A & 1) > bit);
                                        int FY = (B & 1) == bit ? fy : ((B & 1) > bit);
                                        dp[i + 1][A >> 1][B >> 1][C >> 1][D >> 1][FX][FY] += dp[i][a][b][c][d][fx][fy];
                                    }
                                }
                            }
    modint998244353 ans = dp[31][0][0][0][0][0][0] - 1;
    printf("%d\n", ans);
}

CCPC2022广州M

下班前偶然发现的一道题,看了看题然后骑车回家的时候脑子里口胡了一下。
首先式子是一个异或的式子,不难想到按位算贡献。
i位如果有x个1,对value的贡献就是2ix(kx)。那么做数位DP的时候枚举每一位有多少个1,把进位记录在状态里,并且跟n匹配一下,不合法的可以直接剔除掉。
接下来考虑aim的限制。一个显然的结论是用一个bitmask记录哪些数在当前位之前小于m,但是这样复杂度过高了(2k)。不像上一道题五元组里每个数对应的向量是不一样的,这道题的ai之间是没有什么差别的,那么可以把这个bitmask浓缩成一个状态——当前有y个数大于m。以dp[i][j][y]表示前i位,当前进位是j, 有y个数大于m。枚举这一位填的1的数量x,那么这x个数可以分给当前大于m的数一部分,另一部分分给小于等于m数的一部分,然后重新计算出新的状态里有多少个数大于m,再用组合数乘一下算出转移。

点击查看代码
void Main()
{
    C[0][0] = 1;
    for (int i = 1; i <= 20; ++i)
    {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j)
            C[i][j] = C[i - 1][j - 1] + C[i - 1][j];
    }
 
    ll n, m, k;
    read(n, m, k);
 
    dp[0][0][0] = 1;
    for (int i = 0; i <= 50; ++i)
    {
        int nbit = ((n & (1ll << i)) != 0);
        int mbit = ((m & (1ll << i)) != 0);
        // printf("bit:%d, n:%d, m:%d\n", i + 1, nbit, mbit);
        for (int j = 0; j <= 100; ++j)
            for (int x = 0; x <= k; ++x)
            {
                if (dp[i][j][x] == 0)
                    continue;
                // printf("dp[%d][%d][%d] = %d\n", i, j, x, dp[i][j][x]);
                for (int y = 0; y <= k; ++y)
                {
                    int cur = y * (k - y) + j;
                    if ((cur & 1) != nbit)
                        continue;
                    int mxz = std::min(y, x);
                    for (int z = 0; z <= mxz; ++z)
                    {
                        if (mbit == 1)
                            dp[i + 1][cur >> 1][z] += C[x][z] * C[k - x][y - z] * dp[i][j][x];
                        else
                            dp[i + 1][cur >> 1][x + y - z] += C[x][z] * C[k - x][y - z] * dp[i][j][x];
                    }
                }
            }
    }
    printf("%d\n", dp[51][0][0]);
}
posted @   KSYImba  阅读(32)  评论(2编辑  收藏  举报
相关博文:
阅读排行:
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
· 为什么 退出登录 或 修改密码 无法使 token 失效
点击右上角即可分享
微信分享提示