「解题报告」[省选联考 2022] 卡牌

放假上午想出的做法,写了一下 TLE 35 分。

以为有更高级的复杂度,然后刚看了看题解发现题解就是这个复杂度,呃呃,卡常吧。


考虑将每个数写成它所包含的质因子的集合,写成一个 01 串的形式。那么我们可以看作是选若干数,或起来,求有多少种方案使得结果与上某次询问等于这个询问。

也就是求:

\[\sum_{T \subset S} \left[q \mathop{\rm and} \left(\mathop{\rm or}_{x \in T} x\right) = q \right] \]

首先我们求出每个数能被多少集合或出来,可以看作有 \(n\) 个幂级数满足 \(a_0 = a_x = 1\),然后求这 \(n\) 个幂级数的或卷积。最后求 \(\sum_{x \mathop{\rm and} q = q} a_x\),这就是与卷积的 FWT(或者说高维后缀和)。

第一步与卷积我们肯定不能卷 \(n\) 次,我们可以先将每个式子 FWT 出来,发现得出的数肯定是若干个 \(1\),若干个 \(2\)。那么,既然我们要求所有数的对应点值处的乘积,我们可以先取一个 \(\log_2\),这样我们只需要求每个数的对应点值处的和,而发现取完对数之后其实还是一个高维前缀和的形式,而我们可以将所有的高维前缀和放到一起来做,于是只需要令 \(a_x \gets a_x + 1\),做一遍高维前缀和,再 \(a_x \gets 2^{a_x}\),再高维差分(IFWT)回去即可。

那么这道题就做完了..吗?

值域为 \(2000\),有约 \(300\) 个质数,还是不能直接卷。

我们先梳理一遍我们需要支持的操作:

  1. \(a_x \gets a_x + 1\)
  2. \(a\) 高维前缀和
  3. \(a_x \gets 2^{a_x}\)
  4. \(a\) 高维差分
  5. \(a\) 高维后缀和

虽然有约 \(300\) 个质数,但是众所周知,大于根号的质数最多只出现 \(1\) 次,所以我们可以将每个数拆成高位和低位来看。而小于等于根号的质数只有 \(14\) 个,\(2^{14}\) 的卷积看起来就很可做。

我们先按照最高位分组,对每一组做一下前三个操作。

考虑此时最高位多于一个 \(1\) 的位置:发现高维前缀和后,最高位不是 \(1\) 的位置就是将它高位每个 \(1\) 加起来
再加上高位为 \(0\) 的答案,而进行 \(a_x \gets 2^{a_x}\) 后就变成了乘积。

第四步和第五步都是 FWT 的形式,我们可以将两步合起来,对每一位进行考虑。

拿最高位举例子,我们有两个数 \(a_{0x}, a_{1x}\) 表示最高位为 \(0/1\),剩下的位都是 \(x\),那么进行差分与后缀和后,两个数分别就变成了 \(a_{1x}, a_{1x} - a_{0x}\)

而我们上面得出了每个数可以拆成若干个数的乘积的结论,于是我们可以拆为 \(a_{0x} a_{10}, a_{0x}(a_{10} - 1)\)

发现这样一直拆完之后,每个数仍然是乘积的形式,如果这一位为 \(9\) 贡献为 \(a_{10}\),为 \(1\) 贡献为 \((a_{10} - 1)\)

所以每次询问我们可以利用这个规律先对高位进行 4 5 操作,然后对低位再暴力 FWT。我们可以先处理出高位都是 \(9\) 时的结果,对于每一个 \(1\) 再处理一下贡献即可。由于 \(\sum c_i \le 18000\),总共改变的 \(1\) 的数量就是 \(\sum c_i\),这部分的复杂度为 \(O(\sum c_i 2^p)\)\(p\) 为根号内的质数个数)。而每次暴力 FWT 的复杂度为 \(O(mp 2^p)\),所以总复杂度就是 \(O(\sum c_i + mp )2^p\)

一些常数优化:

  1. 取模优化
  2. 虽然是小于等于根号的质数,但是 \(43^2 = 1849\) 可以放在大于根号的里面,因为 \(43 \times 47 = 2021\),这样 \(p=13\)
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1000005, P = 998244353;
// const int PRI[16] = { 2, 3, 5, 7 }, PCNT = 4;
// vector<int> lpri = { 11, 13, 17, 19, 23 };
const int PRI[16] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41 }, PCNT = 13;
vector<int> lpri = { 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 
    97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 
    157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 
    227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 
    283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 
    367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 
    439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 
    509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 
    599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 
    661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 
    751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 
    829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 
    919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, 
    1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 
    1063, 1069, 1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 
    1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193, 1201, 1213, 
    1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 
    1289, 1291, 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 
    1367, 1373, 1381, 1399, 1409, 1423, 1427, 1429, 1433, 1439, 
    1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, 
    1499, 1511, 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 
    1579, 1583, 1597, 1601, 1607, 1609, 1613, 1619, 1621, 1627, 
    1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721, 
    1723, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789, 
    1801, 1811, 1823, 1831, 1847, 1861, 1867, 1871, 1873, 1877, 
    1879, 1889, 1901, 1907, 1913, 1931, 1933, 1949, 1951, 1973, 
    1979, 1987, 1993, 1997, 1999 };
int mp[2002];
int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
int n, m, s[MAXN];
int a[300][17000];
int ainv[300][17000];
int pow2[MAXN], pinv2[MAXN];
int b[17000];
int c[17000];
int cnt[MAXN];
void add(int &x, int y) {
    x += y;
    if (x >= P) x -= P;
}
void del(int &x, int y) {
    x -= y;
    if (x < 0) x += P;
}
void fwtor(int a[], int n, bool rev) {
    for (int mid = 1; mid < n; mid <<= 1) {
        for (int l = 0; l < n; l += (mid << 1)) {
            for (int i = 0; i < mid; i++) {
                a[l + i + mid] = (a[l + i + mid] + (rev ? P - 1ll : 1ll) * a[l + i]) % P;
            }
        }
    }
}
void fwt2(int a[], int n) {
    for (int mid = 1; mid < n; mid <<= 1) {
        for (int l = 0; l < n; l += (mid << 1)) {
            for (int i = 0; i < mid; i++) {
                swap(a[l + i], a[l + i + mid]), a[l + i + mid] = (a[l + i] - a[l + i + mid] + P) % P;
            }
        }
    }
}
int read() {
    int x = 0;
    char ch = getchar();
    while (ch < '0') ch = getchar();
    while (ch >= '0') x = (x * 10 + ch - '0'), ch = getchar();
    return x;
}
int main() {
    for (int i = 0; i < PCNT; i++) {
        mp[PRI[i]] = i;
    }
    for (int i = 0; i < lpri.size(); i++) {
        mp[lpri[i]] = i;
    }
    n = read();
    for (int i = 1; i <= n; i++) {
        s[i] = read();
        int w = 0;
        for (int j = 0; j < PCNT; j++) {
            if (s[i] % PRI[j] == 0) {
                while (s[i] % PRI[j] == 0) s[i] /= PRI[j];
                w |= 1 << j;
            }
        }
        if (s[i] != 1) {
            if (s[i] == 1849) s[i] = 43;
            a[mp[s[i]] + 1][w]++;
        } else {
            a[0][w]++;
        }
    }
    for (int i = 0; i <= lpri.size(); i++) {
        for (int mid = 1; mid < (1 << PCNT); mid <<= 1) {
            for (int l = 0; l < (1 << PCNT); l += (mid << 1)) {
                for (int j = 0; j < mid; j++) {
                    add(a[i][l | j | mid], a[i][l | j]);
                }
            }
        }
    }
    pow2[0] = pinv2[0] = 1;
    pow2[1] = 2, pinv2[1] = (P + 1) / 2;
    for (int i = 2; i <= n; i++) {
        pow2[i] = 2ll * pow2[i - 1] % P;
        pinv2[i] = 1ll * pinv2[i - 1] * pinv2[1] % P;
    }
    for (int i = 0; i <= lpri.size(); i++) {
        for (int j = 0; j < (1 << PCNT); j++) {
            int ori = a[i][j];
            a[i][j] = pow2[a[i][j]];
            ainv[i][j] = 1ll * pinv2[ori] * (a[i][j] - 1 + P) % P;
        }
    }
    for (int x = 0; x < (1 << PCNT); x++) {
        b[x] = a[0][x];
        for (int i = 1; i <= lpri.size(); i++) {
            b[x] = 1ll * b[x] * a[i][x] % P;
        }
    }
    m = read();
    for (int i = 1; i <= m; i++) {
        int cc = read();
        bool flag = false;
        int x = 0;
        vector<int> viss;
        for (int j = 1; j <= cc; j++) {
            int w = read(); 
            if (w <= PRI[PCNT - 1]) {
                x |= 1 << mp[w];
            } else {
                viss.push_back(mp[w] + 1);
            }
        }
        for (int x = 0; x < (1 << PCNT); x++) {
            c[x] = b[x];
        }
        for (int x = 0; x < (1 << PCNT); x++) {
            for (int i : viss) {
                c[x] = 1ll * c[x] * ainv[i][x] % P;
            }
        }
        for (int mid = 1; mid < (1 << PCNT); mid <<= 1) {
            for (int l = 0; l < (1 << PCNT); l += (mid << 1)) {
                for (int i = 0; i < mid; i++) {
                    swap(c[l | i], c[l | i | mid]), c[l | i | mid] = c[l | i] - c[l | i | mid];
                    if (c[l | i | mid] < 0) c[l | i | mid] += P;
                }
            }
        }
        printf("%d\n", c[x]);
    }
    return 0;
}

posted @ 2023-02-07 10:25  APJifengc  阅读(51)  评论(0编辑  收藏  举报