【Cf #449 C】Willem, Chtholly and Seniorious(set维护线段)
这里介绍以个小$trick$,民间流传为$Old Driver Tree$,实质上就是$set$维护线段。
我们将所有连续一段权值相同的序列合并成一条线段,扔到$set$里去,于是$set$里的所有线段的并就是原序列,并且都不相交。
我们在操作的时候很暴力,每次把$[l, r]$的线段抠出来,暴力枚举一遍算答案。对于每一个非区间赋值的操作,最多断两条线段,新加两条线段。
实现起来很方便,思路也非常简单,但是局限性也很明显,因为其复杂度是基于随机的,并且必须存在区间赋值的操作。
但$set$维护线段的技巧还是很常见的,我用$set$和$map$都实现的一遍,发现这里用$map$相当好写。
$set$的实现:
#include <set> #include <cstdio> #include <algorithm> using namespace std; typedef long long LL; const int N = 100005; int n, m, vmax, tp; pair<LL, int> st[N]; struct Seg { int l, r; LL v; friend bool operator < (Seg a, Seg b) { return (a.l != b.l)? (a.l < b.l) : (a.r < b.r); } }; multiset<Seg> S; typedef multiset<Seg>::iterator Saber; LL Pow(LL x, int b, int p, LL re = 1) { for (x %= p; b; b >>= 1, x = x * x % p) if (b & 1) re = re * x % p; return re; } namespace R { const int mod = 1e9 + 7; int seed, ret; int rnd() { ret = seed; seed = (seed * 7LL + 13) % mod; return ret; } } int main() { scanf("%d%d%d%d", &n, &m, &R::seed, &vmax); for (int i = 1; i <= n; ++i) S.insert((Seg){ i, i, R::rnd() % vmax + 1 }); for (int op, l, r, x, y; m; --m) { op = (R::rnd() & 3) + 1; l = R::rnd() % n + 1; r = R::rnd() % n + 1; if (l > r) swap(l, r); if (op == 3) x = R::rnd() % (r - l + 1) + 1; else x = R::rnd() % vmax + 1; if (op == 4) y = R::rnd() % vmax + 1; Saber p = --S.lower_bound((Seg){ l + 1, 0, 0 }); Saber q = --S.lower_bound((Seg){ r + 1, 0, 0 }); if ((*p).l < l) S.insert((Seg){ (*p).l, l - 1, (*p).v }); if (r < (*q).r) S.insert((Seg){ r + 1, (*q).r, (*q).v }); if (op == 2) { for (Saber z = p, rin; z != q; ) { rin = z; ++z; S.erase(rin); } S.erase(q); S.insert((Seg){ l, r, x }); continue; } Saber np, nq; if (p == q) { np = nq = S.insert((Seg){ l, r, (*p).v }); S.erase(p); } else { np = S.insert((Seg){ l, (*p).r, (*p).v }); nq = S.insert((Seg){ (*q).l, r, (*q).v }); S.erase(p); S.erase(q); } if (op == 1) { Seg ins; for (Saber z = np, rin; z != nq; ) { rin = z; ++z; ins = *rin; ins.v += x; S.erase(rin); S.insert(ins); } ins = *nq; ins.v += x; S.erase(nq); S.insert(ins); } if (op == 3) { tp = 0; for (Saber z = np; z != nq; ++z) st[++tp] = make_pair((*z).v, (*z).r - (*z).l + 1); st[++tp] = make_pair((*nq).v, (*nq).r - (*nq).l + 1); sort(st + 1, st + 1 + tp); for (int i = 1; i <= tp; x -= st[i].second, ++i) { if (st[i].second >= x) { printf("%lld\n", st[i].first); x = 0; break; } } if (x > 0) { puts("I love you, love you forever."); return 0; } } if (op == 4) { int ans = 0; for (Saber z = np; z != nq; ++z) ans = (ans + Pow((*z).v, x, y) * ((*z).r - (*z).l + 1)) % y; ans = (ans + Pow((*nq).v, x, y) * ((*nq).r - (*nq).l + 1)) % y; printf("%d\n", ans); } } return 0; }
$map$的实现:
#include <map> #include <cstdio> #include <algorithm> using namespace std; typedef long long LL; const int N = 100005; int n, m, vmax, tp; pair<LL, int> st[N]; map<int, LL> M; LL Pow(LL x, int b, int p, LL re = 1) { for (x %= p; b; b >>= 1, x = x * x % p) if (b & 1) re = re * x % p; return re; } namespace R { const int mod = 1e9 + 7; int seed, ret; int rnd() { ret = seed; seed = (seed * 7LL + 13) % mod; return ret; } } int main() { scanf("%d%d%d%d", &n, &m, &R::seed, &vmax); for (int i = 1; i <= n; ++i) M[i] = R::rnd() % vmax + 1; M[n + 1] = 0; for (int op, l, r, x, y; m; --m) { op = (R::rnd() & 3) + 1; l = R::rnd() % n + 1; r = R::rnd() % n + 1; if (l > r) swap(l, r); if (op == 3) x = R::rnd() % (r - l + 1) + 1; else x = R::rnd() % vmax + 1; if (op == 4) y = R::rnd() % vmax + 1; auto p = --M.upper_bound(l); auto q = M.upper_bound(r); if (p->first < l) M[l] = p->second, ++p; if (r + 1 < q->first) --q, M[r + 1] = q->second, ++q; if (op == 1) { for (; p != q; ++p) p->second += x; } if (op == 2) { while (p != q) M.erase(p++); M[l] = x; } if (op == 3) { tp = 0; for (auto rin = p; p != q; ++p) { st[++tp] = { p->second, (++rin)->first - p->first }; } sort(st + 1, st + 1 + tp); for (int i = 1; i <= tp; x -= st[i].second, ++i) { if (x <= st[i].second) { printf("%lld\n", st[i].first); break; } } } if (op == 4) { int ans = 0; for (auto rin = p; p != q; ++p) { ans = (ans + Pow(p->second, x, y) * ((++rin)->first - p->first)) % y; } printf("%d\n", ans); } } return 0; }