【题解报告】杭电多校 Round 2

补题链接:hdu vjudge

1004 -- I love counting

题意简介

给你一个数组 \(c_i\) ,和 \(q\) 个询问,问你对于一个区间 \([l,r]\) 里有多少种不同\(c\) 值,使得 \(c \oplus a \leq b\)

思路分析

巧了,第一场有一道 Trie 和一道莫队,俩揉起来改一改就过了。

因为 Trie 和莫队的分析上一场写过了,这一次就不写了。还是那个分类讨论的思路。

需要注意的是,这次要维护的东西是种数。

参考杭电第一场的 1006 和 1010 。具体实现看代码。

解题代码

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int N = 1e5 + 5;
int n, c[N], q, blen, Ans[N], to[N * 18][2], cnt[N * 18], cn = 1, kind[N * 18];
struct Query {
    int l, r, a, b, i, l_bid;
    Query() {
        l = r = a = b = i = 0;
        l_bid = 0;
    }
    bool operator < (const Query &B) const {
        if(l_bid != B.l_bid) return l_bid < B.l_bid;
        if(l_bid & 1) return r < B.r;
        return r > B.r;
    }
}Q[N];
int tot;
inline int insert(int x) {
    int u = 1;
    for(int i = 17; i >= 0; i--) {
        int bit = 0;
        if((1 << i) & x) bit = 1;
        if(!to[u][bit]) to[u][bit] = ++cn;
        u = to[u][bit]; cnt[u]++;
    }
    if(cnt[u] == 1) {
        u = 1;
        for(int i = 17; i >= 0; i--) {
            int bit = 0;
            if((1 << i) & x) bit = 1;
            u = to[u][bit]; kind[u]++;
        }
    }
    return cnt[u];
}
inline int decrease(int x) {
    int u = 1;
    for(int i = 17; i >= 0; i--) {
        int bit = 0;
        if((1 << i) & x) bit = 1;
        if(!to[u][bit]) to[u][bit] = ++cn;
        u = to[u][bit]; cnt[u]--;
    }
    if(cnt[u] == 0){
        u = 1;
        for(int i = 17; i >= 0; i--) {
            int bit = 0;
            if((1 << i) & x) bit = 1;
            u = to[u][bit]; kind[u]--;
        }
    }
    return cnt[u];
}
inline void add(int x) {
    if(insert(x) == 1) tot++;
}
inline void sub(int x) {
    if(decrease(x) == 0) tot--;
}
inline int getAns(int a, int b) {
    int u = 1, cc = 0;
    for(int i = 17; i >= 0 && u; i--) {
        int bit_a, bit_b;
        bit_b = (b & (1 << i)) ? 1 : 0;
        bit_a = (a & (1 << i)) ? 1 : 0;
        if(bit_a < bit_b) {
            u = to[u][1];
            continue;
        }
        if(bit_a > bit_b) {
            cc += kind[to[u][0]];
            u = to[u][1];
            continue;
        }
        if(bit_a == bit_b && bit_a == 1) {
            u = to[u][0];
            continue;
        }
        if(bit_a == bit_b) {
            cc += kind[to[u][1]];
            u = to[u][0];
            continue;
        }
    }
    return tot - cc;
}
int main() {
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) {
        scanf("%d", c + i);
    }
    blen = (int) sqrt(n + 0.5) + 1;
    scanf("%d", &q);
    for(int i = 1; i <= q; i++) {
        scanf("%d%d%d%d", &Q[i].l, &Q[i].r, &Q[i].a, &Q[i].b);
        Q[i].l_bid = Q[i].l / blen + 1;
        Q[i].i = i;
    }
    sort(Q + 1, Q + q + 1);
    int l = 1, r = 0;
    for(int i = 1; i <= q; i++) {
        while(r < Q[i].r) add(c[++r]);
        while(l > Q[i].l) add(c[--l]);
        while(r > Q[i].r) sub(c[r--]);
        while(l < Q[i].l) sub(c[l++]);
        Ans[Q[i].i] = getAns(Q[i].a, Q[i].b);
    }
    for(int i = 1; i <= q; i++) {
        printf("%d\n", Ans[i]);
    }
    return 0;
}

1005 -- I love string

题意简介

给你 \(n\) 个字符,要求你按从左往右插入一个空字符串中,除第一次插入外,插入时可以插在字符串左端,也可以插在右端,问你有多少种插入方法,可以字符串是所有能得到的字符串中字典序最大的。

思路分析

很显然,当第一个不同的字符被要求插入的那一刻起,想要保证字符串字典序最大,后面的字符的插入方式就被固定下来了。

因为想要保证一个字符在插入后,字符串字典序最大,唯一能有两种插法的情况是前缀字符等于后缀。

嗯,自己手玩一下就看出来了。

解题代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5, mod = 1e9 + 7;
char s[N]; int T;
int main() {
    scanf("%d", &T);
    while(T--) {
        bool equal = 1;
        int n; scanf("%d %s", &n, s + 1);
        ll ans = 1;
        for(int i = 2; i <= n; i++) {
            if(!equal) break;
            if(s[i] == s[i-1]) ans = ans * 2 % mod;
            else equal = 0;
        }
        printf("%lld\n", ans);
    }
    return 0;
}

1007 -- I love data structure

题意简介

给你两个数组 \(A_i, B_i\) ,然后有四种操作:

  1. 将区间 \([l, r]\) 的所有 \(A_i\) ( \(B_i\) ) 加上 \(x\)

  2. 将区间 \([l, r]\) 的所有 \(A_i\) 变成 \(3\times A_i + 2\times B_i\)\(B_i\) 变成 \(3\times A_i - 2 \times B_i\)

  3. 将区间 \([l, r]\)\(A, B\) 互换

  4. 查询区间 \(\sum A_i\times B_i\)

思路分析

不,你不喜欢。

线段树。可以发现,对于操作 \(2\)\(3\) 可以视为对一个二维向量乘上一个 \(2\) 阶矩阵。

可以发现,对于操作二,执行后,\(a'\times b' = 9a^2 - 4b^2\) 。所以,我们考虑多维护 \(a^2, b^2\)

对于区间加,我们则可以通过在最后一列加一个 \(1\) 来转移。

因此,其实我们所有的修改操作都可以用一个 \(6\times 6\) 的矩阵来维护。

复杂度大致算了一下,5s 呢,万一过了呢。

然后 T 了。但出题人题解说放过了 6 阶的做法,也许是我自带大常数吧。

然后仔细一想,\(6\) 阶矩阵似乎无法用来直接简单地维护加法。因为我们存的都是区间和,区间加时,我们对当前值的维护是需要乘上区间长度。

所以,既然无论如何加法都是需要特判的,我们不妨把加法搞出来。

于是,有了 5 阶矩阵的做法。

再仔细一想,其实两阶就够了。事实上,针对操作 2、3,我们可以维护一个两阶的转移矩阵。

单次操作的具体矩阵参考代码,不会 latex 。

我们记这个矩阵每行从左往右的四个元素为 \(i, j, k, w\) ,则,我们有:

\(a' = i \times a + k \times b, b' = j \times a + w \times b\)

于是,对于我们需要维护的东西,我们可以推出:

\(a'^2 = i^2 \times a^2 + k^2 \times b^2 + 2\times i\times k\times a\times b\)

\(b^2, ab\) 同理,我就不写出来了。

接下来考虑操作 1。

对于两种不同操作的懒标记,我们通常需要给这两种标记一个固定的顺序。且为了保证正确性,我们通常需要让其中一种操作对另一种操作的懒标记产生影响。

这道题,显然是矩阵乘法对加法懒标记产生影响更合适。

也就是说,当加法覆盖在乘法上时,我们直接算。当乘法覆盖在加法上时,我们需要更新加法的标记。然后在向下转移时,我们先转移乘法,再转移加法。

简单地推一下,记当前 \(a' = a + x, b' = b + y\)

我们有 \(a^2 = a^2 + 2 x a + x^2\) ,其它同理。

需要注意的一点是,我们维护的所有的值,实际上都是区间和,所以当你推式子发现一个常数项时,这个常数项需要乘以区间的长度

而且由于很多数字都可能是 \(1e9\) 级别的,连乘要记得取模

我因为有个地方没乘常数项查了两小时,然后因为补上去后忘了取模又查了两个小时。

解题代码

#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef unsigned long long ll;
const ll N = 2e5 + 5, mod = 1e9 + 7;
#define ls (u << 1)
#define rs ((u << 1) + 1)
inline char getc() {
    static char buf[1 << 14], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 14, stdin), p1 == p2) ? EOF : *p1++;
}
inline int scan() {
    int x = 0; char ch = 0;
    while(ch < 48) ch = getc();
    while(ch >= 48) x = x * 10 - 48 + ch;
    return x;
}
const ll init[2][2] = {{1, 0}, {0, 1}};
struct Node{
    ll len, val[2], add[2], laz[2][2], pow_a, pow_b, ab; 
    bool flag;
    Node(ll a = 0, ll b = 0) {
        val[0] = a; val[1] = b;
        pow_a = a * a % mod;
        pow_b = b * b % mod;
        ab = a * b % mod;
        add[0] = add[1] = laz[1][0] = laz[0][1] = 0;
        laz[0][0] = laz[1][1] = 1;
        flag = 0; len = 1;
    }
}tr[N << 2];
inline void maintain(int u) {
    tr[u].pow_a = (tr[ls].pow_a + tr[rs].pow_a) % mod;
    tr[u].pow_b = (tr[ls].pow_b + tr[rs].pow_b) % mod;
    tr[u].ab = (tr[ls].ab + tr[rs].ab) % mod;
    tr[u].val[0] = (tr[ls].val[0] + tr[rs].val[0]) % mod;
    tr[u].val[1] = (tr[ls].val[1] + tr[rs].val[1]) % mod;
}
inline void mult1(ll *A, ll B[][2]) {
    ll a = A[0], b = A[1];
    A[0] = (a * B[0][0] % mod + b * B[1][0] % mod) % mod;
    A[1] = (a * B[0][1] % mod + b * B[1][1] % mod) % mod;
}
inline void mult2(ll A[][2], ll B[][2]) {
    static ll temp[2][2];
    memcpy(temp, A, sizeof(temp));
    for(int i = 0; i < 2; i++) for(int j = 0; j < 2; j++) {
        A[i][j] = 0;
        for(int k = 0; k < 2; k++) A[i][j] = (A[i][j] + temp[i][k] * B[k][j] % mod) % mod;
    }
}
inline void modify(int u, ll zl[][2]) {
    ll a, b, aa, bb, ab, i, j, k, w;
    a = tr[u].val[0], b = tr[u].val[1];
    aa = tr[u].pow_a, bb = tr[u].pow_b, ab = tr[u].ab;
    i = zl[0][0], j = zl[0][1], k = zl[1][0], w = zl[1][1];
    tr[u].pow_a = (i * i % mod * aa % mod + i * k % mod * ab % mod * 2ll % mod + k * k % mod * bb % mod) % mod;
    tr[u].pow_b = (j * j % mod * aa % mod + j * w % mod * ab % mod * 2ll % mod + w * w % mod * bb % mod) % mod;
    tr[u].ab = (i * j % mod * aa % mod + k * w % mod * bb % mod + (i * w % mod + k * j % mod) % mod * ab % mod) %mod;
    mult1(tr[u].val, zl);
    mult1(tr[u].add, zl);
    mult2(tr[u].laz, zl);
    tr[u].flag = 1;
}
inline void increase(int u, ll ad[2]) {
    ll a = tr[u].val[0], b = tr[u].val[1], aa = tr[u].pow_a, bb = tr[u].pow_b;
    tr[u].val[0] = (a + ad[0] * tr[u].len % mod) % mod;
    tr[u].val[1] = (b + ad[1] * tr[u].len % mod) % mod;
    tr[u].add[0] = (tr[u].add[0] + ad[0]) % mod;
    tr[u].add[1] = (tr[u].add[1] + ad[1]) % mod;
    // 记 得 这 里 要 乘 tr[u].len !!!!!! 乘 完 记 得 取 模 !!!!!!
    tr[u].pow_a = (aa + 2ll * ad[0] % mod * a % mod + tr[u].len * ad[0] % mod * ad[0] % mod) % mod;
    tr[u].pow_b = (bb + 2ll * ad[1] % mod * b % mod + tr[u].len * ad[1] % mod * ad[1] % mod) % mod;
    tr[u].ab = (tr[u].ab + ad[0] * b % mod + ad[1] * a % mod + tr[u].len * ad[0] % mod * ad[1] % mod) % mod;
}
inline void pushdown(int u) {
    // ll test0 = tr[u].add[0], test1 = tr[u].add[1];
    if(tr[u].flag) {
        tr[u].flag = 0;
        modify(ls, tr[u].laz), modify(rs, tr[u].laz);
        // if(tr[u].ab != (tr[ls].ab + tr[rs].ab) % mod) puts("err"), exit(0);
        memcpy(tr[u].laz, init, sizeof(init));
    }
    if(tr[u].add[0] || tr[u].add[1]) {
        increase(ls, tr[u].add);
        increase(rs, tr[u].add);
        // if(tr[u].pow_a != (tr[ls].pow_a + tr[rs].pow_a) % mod) puts("err"), exit(0);
        // if(tr[u].pow_b != (tr[ls].pow_b + tr[rs].pow_b) % mod) puts("err"), exit(0);
        // if(tr[u].ab != (tr[ls].ab + tr[rs].ab) % mod) puts("err"), exit(0);
        // if(tr[u].val[1] != (tr[ls].val[1] + tr[rs].val[1]) % mod) puts("err"), exit(0);
        // if(tr[u].add[0] != test0 || tr[u].add[1] != test1) puts("err"), exit(0);
        tr[u].add[0] = tr[u].add[1] = 0;
    }
}
void build(int u, int l, int r) {
    if(l == r) {
        ll a, b; 
        scanf("%lld%lld", &a, &b);
        tr[u] = Node(a, b);
    }else {
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);
        tr[u].len = r - l + 1;
        // tr[u].len = tr[ls].len + tr[rs].len;
        maintain(u);
    }
}
int opt, ql, qr, mark; ll xval;
void deal(int u) {
    static ll ADD[2] = {0, 0};
    static ll CHANGE[2][2] = {{3, 3}, {2, mod - 2}};
    static ll SWAP[2][2] = {{0, 1}, {1, 0}};
    switch(opt) {
        case 1: {
            ADD[mark] = xval, ADD[mark ^ 1] = 0;
            increase(u, ADD);
            break;
        }
        case 2: {
            modify(u, CHANGE);
            break;
        }
        case 3: {
            modify(u, SWAP);
            break;
        }
        default: break;
    }
}
void update(int u, int l, int r) {
    if(l > qr || r < ql) return;
    if(ql <= l && r <= qr) return deal(u);
    int mid = (l + r) >> 1;
    pushdown(u);
    update(ls, l, mid), update(rs, mid + 1, r);
    maintain(u);
}
ll query(int u, int l, int r) {
    if(l > qr || r < ql) return 0;
    if(ql <= l && r <= qr) return tr[u].ab;
    int mid = (l + r) >> 1;
    pushdown(u);
    return (query(ls, l, mid) + query(rs, mid + 1, r)) % mod;
}
void debug(int u, int l, int r) {
    if(l == r) {
        // printf("a=%lld b=%lld ab=%lld a^2=%lld b^2=%lld\n", tr[u].val[0], tr[u].val[1], tr[u].ab, tr[u].pow_a, tr[u].pow_b);
        return;
    }
    int mid = (l + r) >> 1;
    // printf("For [%d, %d]: ", l, r);
    // printf("a=%lld b=%lld ab=%lld a^2=%lld b^2=%lld\n", tr[u].val[0], tr[u].val[1], tr[u].ab, tr[u].pow_a, tr[u].pow_b);
    pushdown(u);
    debug(ls, l, mid), debug(rs, mid + 1, r);
}
int main() {
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    int n, q; 
    scanf("%d", &n);
    build(1, 1, n);
    scanf("%d", &q);
    while(q--) {
        scanf("%d", &opt);
        if(opt == 1) {
            scanf("%d%d%d%lld", &mark, &ql, &qr, &xval);
            update(1, 1, n);
            // debug(1, 1, n);
        }else if(opt != 4){
            scanf("%d%d", &ql, &qr);
            update(1, 1, n);
            // debug(1, 1, n);
        }else {
            scanf("%d%d", &ql, &qr);
            printf("%lld\n", query(1, 1, n));
        }
    }
    return 0;
}

1008 -- I love exam

题意简介

给你 \(n \leq 50\) 个科目,然后你一开始每个科目都是 \(0\) 分。

再给你 \(m \leq 15000\) 个复习材料。每个复习材料只能学一次,且只有一个对应的科目,学它需要花 \(y\leq 10\) 天,学完这一科涨 \(x\leq 10\) 分。

然后你现在还剩下 \(t\leq 500\) 的时间可以复习。

如果某一科分数严格小于 \(60\) ,就是挂科。你最多只能挂 \(p\leq 3\) 科,挂多了会被麻麻打洗。

问你在保证挂科数量不大于 \(p\) 的情形下,总分最大能是多少。

思路分析

一眼背包,然后复杂度越看越假。还手滑wa了一发

首先,我们可以预处理出对于科目 \(i\),学 \(j\) 天,最多能拿的分数 \(g(i, j)\)

预处理出这个之后,我们就可以搞出这个转移方程:

\(f(o, d, p, 0/1)\) 表示当前学到科目 \(o\) ,已经学了 \(d\) 天,算上这一科共挂了 \(p\) 科,当前这一科有没有挂 的最大总分。

于是,\(g(o, k) < 60\) 时 ,有 \(f(o, d, p, 1) = \max f(o - 1, d - k, p - 1, 0/1) + g(o, k)\)

这一科没挂,则有 \(f(o, d, p, 0) = \max f(o - 1, d - k, p, 0/1) + g(o, k)\)

复杂度是 \(O(nt^2)\) ,可以接受。

于是问题就在于如何求 \(g(i, j)\)

(这里我当时想复杂le)

因为 \(m \leq 15000\) ,直接用 01背包 去求的话,复杂度会是 \(O(n + mt)\) 左右,其实完全不会 T,但是我当时被这个数据范围吓坏了,没去仔细算,于是乎满脑子优化。

我们可以观察到,材料的种数其实最多有 \(10 \times 10\) 种,于是我们可以用一个二维数组来存它。

然后我们可以发现,它可能会有大量重复的材料。如果重复的材料数比 \(t/y\) 还多,意味着我们就算一直学这个也没关系。所以对于这种物品,我们当成无限物品的背包来看,只跑一遍。

此外,我们不难看出,对于同一种学习材料,如果前面的第某次扫时已经不会更新最大值,那后面的一样不会再更新最大值。所以可以直接 break 出来。

当然还有其它的优化思路就是了。虽然其实都不需要用上qwq

解题代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 55, D = 505;
int T, n, to[N * 20][26], id[N * 20], cn, m, t, p; 
char s[20];
// g 表示第 i 门课,学了 d 天的最大分数
// f 表示到第 i 门课,已经学了 d 天,挂了 c 门,当前这门的状态
int cnt[N][12][12], g[N][D], f[N][D][4][2];
inline void insert(int x) {
    int len = strlen(s + 1), u = 1;
    for(int i = 1; i <= len; i++) {
        int d = s[i] - 'a';
        if(!to[u][d]) to[u][d] = ++cn;
        u = to[u][d];
    }
    id[u] = x;
}
inline int find() {
    int len = strlen(s + 1), u = 1;
    for(int i = 1; i <= len; i++) {
        int d = s[i] - 'a';
        if(!to[u][d]) to[u][d] = ++cn;
        u = to[u][d];
    }
    return id[u];
}
inline void calc(int x) {
    // i 是天数,j 是分数
    for(int i = 1; i <= 10; i++) {
        for(int j = 10; j >= 1; j--) {
            if(cnt[x][i][j] >= t / i) {
                for(int k = 1; k <= t; k++) {
                    if(k - i >= 0) {
                        g[x][k] = max(g[x][k], g[x][k - i] + j);
                    }
                }
            }else {
                int cc = cnt[x][i][j];
                for(int c = 1; c <= cc; c++) {
                    bool change = 0;
                    for(int k = t; k >= 1; k--) {
                        if(k - i >= 0) {
                            if(g[x][k] < g[x][k - i] + j) {
                                change = 1;
                                g[x][k] = g[x][k - i] + j;
                            }
                        }
                    }
                    if(!change) break;
                }
            }
        }
    }
    /*printf("For subject %d:\n", x);
    for(int i = 1; i <= t; i++) {
        printf("%d ", g[x][i]);
    }
    puts("");*/
}
int main() {
    scanf("%d", &T);
    while(T--) {
        scanf("%d", &n);
        cn = 1;
        memset(to, 0, sizeof(to));
        memset(f, 0, sizeof(f));
        memset(g, 0, sizeof(g));
        memset(cnt, 0, sizeof(cnt));
        memset(id, 0, sizeof(id));
        for(int i = 1; i <= n; i++) {
            scanf(" %s", s + 1);
            insert(i);
        }
        scanf("%d", &m);
        for(int i = 1; i <= m; i++) {
            int x, y, o;
            scanf(" %s%d%d", s + 1, &x, &y);
            o = find();
            cnt[o][y][x]++;
        }
        scanf("%d%d", &t, &p);
        for(int i = 1; i <= n; i++) calc(i);
        for(int i = 1; i <= n; i++) {
            for(int j = 1; j <= t; j++) {
                for(int d = 0; d <= j; d++) {
                    int sco = g[i][d], del = 0;
                    if(sco > 100) sco = 100;
                    if(sco < 60) del = 1;
                    for(int l = del; l <= p; l++) {
                        f[i][j][l][del] = max(f[i - 1][j - d][l - del][0] + sco, f[i][j][l][del]);
                        f[i][j][l][del] = max(f[i - 1][j - d][l - del][1] + sco, f[i][j][l][del]);
                    }
                }
            }
        }
        int res = 0;
        for(int i = 0; i <= p; i++) {
            for(int j = 0; j < 2; j++) {
                for(int d = 1; d <= t; d++) {
                    res = max(res, f[n][d][i][j]);
                }
            }
        }
        /*puts("print f(i, d, gua, now):");
        for(int i = 1; i <= n; i++) {
            for(int j = 1; j <= t; j++) {
                printf("day = %d, sub = %d:\n", j, i);
                for(int gua = 0; gua <= p; gua++) {
                    printf("  {%d %d}", f[i][j][gua][0], f[i][j][gua][1]);
                }
                puts("");
            }
        }*/
        if(res >= n * 60) printf("%d\n", res);
        else printf("-1\n");
    }
    return 0;
}

1011 -- I love max and multiply

题意简介

给你俩长度为 \(n\) 的数组 \(A_i, B_i\) ,要你对每个位置 \(k\) ,求出 \(C_k = \max A_i\times B_j\) 满足 \(i & j \geq k\)

然后输出 \(\sum C_i\)

注意:

  1. 题目的两个数组含负数

  2. 下标从 0 开始

思路分析

简单 dp 。

我们先考虑什么情况下有 \(i \& j \geq k\)

第一种显而易见的情况是,\(k\) 的所有 1 位都与 \(i, j\) 有重叠。

比如 \(k = (10101)_2, i = (11101)_2, j = (10111)_2\)

第二种,虽然没有完全重叠,但 \(i\&j\) 多出来的那一位 1 使得自己比 \(k\) 更大。

比如 \(k = (10101)_2, i = (11001)_2, j = (11010)_2\)

我们看看要求的是什么。是积的最大值。显然地,我们让 \(A_i\)\(B_j\) 分别达到最大或最小(因为有负数)不就行了。

我们先对情况 1,我们这样维护(数组 \(B_i\) 、最小值维护类似):

\(fa(k)\)\(i \& k = k\) 的所有 \(A_i\) 的最大值,有 \(fa(k) = \max \{ fa(1<<p | k), A_k\}\)

我们姑且让 \(C_k = fa(k) \times fb(k)\) ,这样我们至少保证了选出来的 \((i,j)\) 是符合题意的,即 \(i\&j \geq k\)

再看看情况 2,不难想到,情况 2 的所有可能其实都包含在了 \(C(k)\) 的后缀里了。

例如上面举例的情况,实际上在 \(k = (11000)_2\) 会被考虑到。

因此,我们只需要再对 \(C_k\) 求一个后缀最大值即可。

记得比较过程别急着取模,最后求和再去取。

解题代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 6e5 + 5, mod = 998244353, INF = 1e9;
int T; ll f1[N], f2[N], g1[N], g2[N], n, f[N];
int main() {
    scanf("%d", &T);
    while(T--) {
        for(int i = 0; i < (1<<18); i++) {
            f1[i] = f2[i] = -INF;
            g1[i] = g2[i] = INF;
        }
        scanf("%lld", &n);
        for(int i = 0; i < n; i++) {
            scanf("%lld", g1 + i);
            f1[i] = g1[i];
        }
        for(int i = 0; i < n; i++) {
            scanf("%lld", g2 + i);
            f2[i] = g2[i];
        }
        ll res = 0;
        for(int i = n - 1; i >= 0; i--) {
            for(int j = 0; j < 18; j++) {
                if((1 << j) & i) continue;
                int p = i ^ (1 << j);
                f1[i] = max(f1[i], f1[p]);
                f2[i] = max(f2[i], f2[p]);
                g1[i] = min(g1[i], g1[p]);
                g2[i] = min(g2[i], g2[p]);
            }
        }
        for(int i = 0; i < n; i++) {
            f[i] = max(f1[i] * f2[i], g1[i] * g2[i]);
            f[i] = max(f[i], f1[i] * g2[i]);
            f[i] = max(f[i], f2[i] * g1[i]);
        }
        for(int i = n - 2; i >= 0; i--) {
            f[i] = max(f[i], f[i + 1]);
            // printf("%lld ", f[i]);
        }
        // puts("");
        for(int i = 0; i < n; i++) {
            if(f[i] < 0) f[i] = (f[i] - f[i] / mod * mod + mod) % mod;
            else f[i] = f[i] % mod;
            res = (res + f[i]) % mod;
        }
        printf("%lld\n", res);
    }
    return 0;
}

(来自 Overslept everyday

posted @ 2021-07-25 10:59  喵乖乖喵  阅读(222)  评论(0编辑  收藏  举报

膜拜众神

网安院技术部     ZZY大师     Xinyang 大佬     Wjyyy