Luogu 4899 [IOI2018] werewolf 狼人
UOJ 407
Kruskal 重构树
把每条边的边权记为两个端点编号的 $\max$，按边权从小到大建立一棵 Kruskal 重构树，这样可以求出从一个点出发、只经过编号不超过 $R$ 的点所能到达的点集；再把每条边的边权记为两个端点编号的 $\min$，按边权从大到小建立另一棵重构树，这样可以求出从一个点出发、只经过编号不少于 $L$ 的点所能到达的点集。
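注意到这两个点集分别是各自重构树中某个节点子树内的全部叶子，也就是各自叶子序（DFS 序）上的一段区间。问题有解等价于两个点集有交：把每个点 $v$ 看作平面上的点 $(\mathrm{pos}_1(v), \mathrm{pos}_2(v))$（即它在两棵树叶子序中的位置），就是询问一个矩形内是否有点，于是变成了一个二维数点问题。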
暴力线段树合并过不去,别问我是怎么知道的
一开始交的时候边数开错了
// IOI 2018 "Werewolf": two Kruskal reconstruction trees plus an offline
// two-dimensional point count.
//
// * Wolf tree: every edge (u, v) is weighted max(u, v) and merged in ascending
//   order, so the component of E using only vertices <= R is the leaf set of
//   the highest ancestor of E whose weight is <= R.
// * Human tree: every edge is weighted -min(u, v) (i.e. merged by descending
//   min), so the component of S using only vertices >= L is the leaf set of the
//   highest ancestor of S whose weight is <= -L.
//
// Each component is a contiguous range of its tree's leaf order.  A query is
// feasible iff some vertex lies in both ranges, i.e. iff the rectangle
// [wolf range] x [human range] contains a point (wolfPos(v), humanPos(v));
// this is answered offline with a sweep and a Fenwick tree.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Include the task header when it is available so the compiler checks this
// definition against the official prototype; the guard also lets the file be
// compiled on its own (e.g. for local unit tests).
#ifdef __has_include
#if __has_include("werewolf.h")
#include "werewolf.h"
#endif
#else
#include "werewolf.h"
#endif

namespace {

// An undirected edge together with the weight it carries in one tree.
struct WeightedEdge {
    int weight;
    int u;
    int v;
};

// Union-find with path compression.  find() is iterative so long chains
// cannot exhaust the call stack.
class DisjointSets {
public:
    explicit DisjointSets(int size) : parent_(static_cast<std::size_t>(size)) {
        std::iota(parent_.begin(), parent_.end(), 0);
    }

    int find(int x) {
        int root = x;
        while (parent_[root] != root) root = parent_[root];
        while (parent_[x] != root) {  // second pass: point the path at the root
            const int next = parent_[x];
            parent_[x] = root;
            x = next;
        }
        return root;
    }

    // Makes `child` (which must currently be a root) point at `parent`.
    void attach(int child, int parent) { parent_[child] = parent; }

private:
    std::vector<int> parent_;
};

// Fenwick tree counting marked positions in [0, size).
class Fenwick {
public:
    explicit Fenwick(int size) : tree_(static_cast<std::size_t>(size) + 1, 0) {}

    void mark(int index) {
        const int size = static_cast<int>(tree_.size());
        for (int i = index + 1; i < size; i += i & -i) ++tree_[i];
    }

    // Number of marked positions in [0, index]; 0 when index < 0.
    int prefix(int index) const {
        int total = 0;
        for (int i = index + 1; i > 0; i -= i & -i) total += tree_[i];
        return total;
    }

    // Number of marked positions in [lo, hi].
    int count(int lo, int hi) const { return prefix(hi) - prefix(lo - 1); }

private:
    std::vector<int> tree_;
};

// Result of one reconstruction tree.
struct LeafIntervals {
    std::vector<int> position;  // position[v]: leaf-order position of vertex v
    std::vector<int> lo;        // lo[i], hi[i]: inclusive leaf-order interval
    std::vector<int> hi;        // reachable from start[i] for query i
};

// Builds the Kruskal reconstruction tree of `edges` over vertices [0, n) and,
// for every query i, finds the vertices reachable from start[i] using only
// edges whose weight is <= limit[i].  That set is the leaf set of a single
// tree node, so it is returned as an interval of leaf order.  Disconnected
// graphs are handled: every component becomes its own tree.
//
// Preconditions: n >= 1; edge endpoints and start[i] lie in [0, n);
// limit.size() >= start.size().
LeafIntervals reachableIntervals(int n, std::vector<WeightedEdge> edges,
                                 const std::vector<int>& start,
                                 const std::vector<int>& limit) {
    std::sort(edges.begin(), edges.end(),
              [](const WeightedEdge& a, const WeightedEdge& b) { return a.weight < b.weight; });

    // Leaves are the vertices 0..n-1; every successful merge creates one new
    // node with the next id, so at most n - 1 internal nodes follow.
    const int capacity = 2 * n;
    std::vector<int> parent(capacity, -1);
    std::vector<int> left(capacity, -1);
    std::vector<int> right(capacity, -1);
    std::vector<int> weight(capacity, 0);  // only meaningful for internal nodes
    DisjointSets components(capacity);

    int nodes = n;
    for (const WeightedEdge& edge : edges) {
        if (nodes == 2 * n - 1) break;  // a spanning tree is already complete
        const int a = components.find(edge.u);
        const int b = components.find(edge.v);
        if (a == b) continue;
        const int merged = nodes++;
        weight[merged] = edge.weight;
        left[merged] = a;
        right[merged] = b;
        parent[a] = merged;
        parent[b] = merged;
        components.attach(a, merged);
        components.attach(b, merged);
    }

    // Iterative pre-order walk from every root, so the leaves below any node
    // occupy a contiguous block of positions.
    LeafIntervals result;
    result.position.assign(static_cast<std::size_t>(n), 0);
    int nextPosition = 0;
    std::vector<int> pending;
    for (int root = 0; root < nodes; ++root) {
        if (parent[root] != -1) continue;
        pending.push_back(root);
        while (!pending.empty()) {
            const int x = pending.back();
            pending.pop_back();
            if (x < n) {
                result.position[x] = nextPosition++;
            } else {
                pending.push_back(right[x]);
                pending.push_back(left[x]);
            }
        }
    }

    // Children always have smaller ids than their parent, so one ascending
    // pass derives every node's leaf interval from its children's.
    std::vector<int> first(nodes);
    std::vector<int> last(nodes);
    for (int x = 0; x < nodes; ++x) {
        if (x < n) {
            first[x] = last[x] = result.position[x];
        } else {
            first[x] = std::min(first[left[x]], first[right[x]]);
            last[x] = std::max(last[left[x]], last[right[x]]);
        }
    }

    // up[k][x] is the 2^k-th ancestor of x; roots point at themselves.
    int levels = 1;
    while ((1 << levels) < nodes) ++levels;
    std::vector<std::vector<int>> up(levels, std::vector<int>(nodes));
    for (int x = 0; x < nodes; ++x) up[0][x] = parent[x] == -1 ? x : parent[x];
    for (int k = 1; k < levels; ++k) {
        for (int x = 0; x < nodes; ++x) up[k][x] = up[k - 1][up[k - 1][x]];
    }

    // Weights never decrease towards the root (edges were merged in ascending
    // order), so the highest ancestor within the limit is found greedily.
    const std::size_t queryCount = start.size();
    result.lo.resize(queryCount);
    result.hi.resize(queryCount);
    for (std::size_t i = 0; i < queryCount; ++i) {
        int x = start[i];
        for (int k = levels - 1; k >= 0; --k) {
            const int y = up[k][x];
            if (weight[y] <= limit[i]) x = y;
        }
        result.lo[i] = first[x];
        result.hi[i] = last[x];
    }
    return result;
}

// A prefix-count request in the sweep: once every wolf-order position up to
// the event's position has been marked, add `sign` times the number of marked
// human-order positions inside the query's human interval.
struct SweepEvent {
    int query;
    int sign;
};

}  // namespace

// Answers every query of IOI 2018 "Werewolf".
//
// answer[i] is 1 iff some vertex is both reachable from S[i] through vertices
// >= L[i] and reachable from E[i] through vertices <= R[i]; otherwise 0.
// Queries whose own endpoint is forbidden (S[i] < L[i] or E[i] > R[i]) get 0.
//
// Preconditions: X and Y have equal length with endpoints in [0, N); S, E, L
// and R have equal length with S[i], E[i] in [0, N).  The graph need not be
// connected.  No state survives between calls, so repeated calls are safe.
std::vector<int> check_validity(int N, std::vector<int> X, std::vector<int> Y,
                                std::vector<int> S, std::vector<int> E,
                                std::vector<int> L, std::vector<int> R) {
    const std::size_t queryCount = S.size();
    std::vector<int> answer(queryCount, 0);
    if (N <= 0 || queryCount == 0) return answer;

    std::vector<WeightedEdge> wolfEdges;
    std::vector<WeightedEdge> humanEdges;
    wolfEdges.reserve(X.size());
    humanEdges.reserve(X.size());
    for (std::size_t j = 0; j < X.size(); ++j) {
        const int u = X[j];
        const int v = Y[j];
        // Usable by the wolf iff both ends are <= R  <=>  max(u, v) <= R.
        wolfEdges.push_back({std::max(u, v), u, v});
        // Usable by the human iff both ends are >= L  <=>  -min(u, v) <= -L.
        humanEdges.push_back({-std::min(u, v), u, v});
    }
    std::vector<int> negatedL(queryCount);
    for (std::size_t i = 0; i < queryCount; ++i) negatedL[i] = -L[i];

    const LeafIntervals wolf = reachableIntervals(N, std::move(wolfEdges), E, R);
    const LeafIntervals human = reachableIntervals(N, std::move(humanEdges), S, negatedL);

    // vertexAtWolf[p] is the vertex whose wolf-order position is p.
    std::vector<int> vertexAtWolf(N);
    for (int v = 0; v < N; ++v) vertexAtWolf[wolf.position[v]] = v;

    // Rectangle count = prefix(wolf.hi) - prefix(wolf.lo - 1) along the wolf axis.
    std::vector<std::vector<SweepEvent>> events(N);
    for (std::size_t i = 0; i < queryCount; ++i) {
        if (S[i] < L[i] || E[i] > R[i]) continue;  // start or end vertex itself forbidden
        const int id = static_cast<int>(i);
        events[wolf.hi[i]].push_back({id, +1});
        if (wolf.lo[i] > 0) events[wolf.lo[i] - 1].push_back({id, -1});
    }

    Fenwick marked(N);
    std::vector<int> hits(queryCount, 0);
    for (int p = 0; p < N; ++p) {
        marked.mark(human.position[vertexAtWolf[p]]);
        for (const SweepEvent& event : events[p]) {
            hits[event.query] +=
                event.sign * marked.count(human.lo[event.query], human.hi[event.query]);
        }
    }
    for (std::size_t i = 0; i < queryCount; ++i) answer[i] = hits[i] > 0 ? 1 : 0;
    return answer;
}