CF1007D. Ants(树链剖分+线段树+2-SAT及前缀优化建图)
题目链接
https://codeforces.com/problemset/problem/1007/D
题解
这道题本身并不难,这里只是记录一下 2-SAT 的前缀优化建图的相关内容。
由于问题的本质是给定许多二元集合,判断是否能从每一个二元集合中选出一个元素,使得所有选出的元素合法,因此考虑使用 2-SAT 解决该问题。
不难发现,使用 2-SAT 解决该问题的复杂度瓶颈在于建图。
我们为每一种颜色 \(i\) 对应的两条路径赋上编号。首先,我们需要为每一条树边记录包含该条边的所有路径的编号。可以将原树树链剖分之后,按结点的 dfs 序建出线段树,将路径的编号添加到线段树对应的结点上。这样,包含一条树边的所有路径编号储存在该边在线段树上对应的叶子结点以及该叶子结点的各级祖先上。因此,我们只需要通过建图来保证线段树的每一个叶子结点及其各级祖先包含的所有编号中,最多只能选择一个编号即可。
考虑使用前缀优化建图。
简单来说,前缀优化建图常用来处理某个命题集合中最多有一个命题成立或不成立的情况(不失一般性地,接下来只分析最多有一个命题成立的情况,在本题中,命题成立即为选择对应编号)。假设这些命题的编号为 \(1 \sim x\),那么我们新建 \(x\) 个结点,用这些结点来表示集合的某个前缀的所有命题是否均不成立,即:这些结点中,第 \(i\) 个结点若为真,则命题 \(1 \sim i\) 均不成立,否则命题 \(1 \sim i\) 中存在一个命题成立。
定义一次建边 \((u \rightarrow v)\) 为:建 \((u\) 为真 \(\rightarrow\) \(v\) 为真\()\) 与 \((v\) 为假 \(\rightarrow u\) 为假\()\) 两条对称边。那么我们只需要建如下三类边即可:
- \(1 \sim i\) 的所有命题均不成立 \(\rightarrow\) \(1 \sim i - 1\) 的所有命题均不成立
- 命题 \(i\) 成立 \(\rightarrow\) \(1 \sim i\) 的所有命题存在成立
- 命题 \(i\) 成立 \(\rightarrow\) \(1 \sim i - 1\) 的所有命题均不成立
本题的建图略有不同,由于在线段树上,所有的前缀本身就构成了一个树形结构,因此相同的前缀可以共用结点。不难发现,最后建出的总结点数(及边数)和线段树上存储的编号总数同阶,为 \(O(m \log^2 n)\)。由于初始在线段树上添加标记的时间复杂度也为 \(O(m \log^2 n)\),因此总时间复杂度为 \(O(m \log^2 n)\)。
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, maxnode = N * 60;
void cmin(int& x, int y) {
if (x > y) {
x = y;
}
}
int n, m, pa[N], dep[N], size[N], heavy[N], dfn[maxnode], dfn_cnt, top[N], node_cnt, low[maxnode], sccno[maxnode], scc;
vector<int> graph[N], graph_t[maxnode], nodes[N << 2], stack_t;
void dfs1(int u, int pa) {
size[u] = 1;
for (auto v : graph[u]) {
if (v != pa) {
::pa[v] = u;
dep[v] = dep[u] + 1;
dfs1(v, u);
size[u] += size[v];
if (size[v] > size[heavy[u]]) {
heavy[u] = v;
}
}
}
}
void dfs2(int u, int t) {
top[u] = t;
dfn[u] = ++dfn_cnt;
if (heavy[u]) {
dfs2(heavy[u], t);
for (auto v : graph[u]) {
if (v != pa[u] && v != heavy[u]) {
dfs2(v, v);
}
}
}
}
void add_edge(int u, int v) {
graph_t[u].push_back(v);
graph_t[v ^ 1].push_back(u ^ 1);
}
#define lo (o<<1)
#define ro (o<<1|1)
void modify(int l, int r, int o, int ql, int qr, int id) {
if (ql <= l && r <= qr) {
nodes[o].push_back(id);
} else {
int mid = l + r >> 1;
if (ql <= mid) {
modify(l, mid, lo, ql, qr, id);
} if (qr > mid) {
modify(mid + 1, r, ro, ql, qr, id);
}
}
}
void build(int l, int r, int o, int lastid) {
int ql = ++node_cnt;
int qr = (node_cnt += nodes[o].size());
if (qr > ql) {
add_edge(qr << 1 | 1, qr - 1 << 1 | 1);
} else if (lastid) {
add_edge(ql << 1 | 1, lastid << 1 | 1);
}
for (int i = 0; i < nodes[o].size(); ++i) {
int id = nodes[o][i];
if (i > 0) {
add_edge(ql + i << 1 | 1, ql + i - 1 << 1 | 1);
} else if (lastid) {
add_edge(ql << 1 | 1, lastid << 1 | 1);
}
add_edge(id, ql + i << 1);
if (i > 0) {
add_edge(ql + i - 1 << 1, id ^ 1);
} else if (lastid) {
add_edge(lastid << 1, id ^ 1);
}
}
if (l < r) {
int mid = l + r >> 1;
build(l, mid, lo, qr);
build(mid + 1, r, ro, qr);
}
}
void add_tag(int u, int v, int id) {
for (; top[u] != top[v]; u = pa[top[u]]) {
if (dep[top[u]] < dep[top[v]]) {
swap(u, v);
}
modify(2, n, 1, dfn[top[u]], dfn[u], id);
}
if (dep[u] > dep[v]) {
swap(u, v);
}
if (dfn[u] < dfn[v]) {
modify(2, n, 1, dfn[u] + 1, dfn[v], id);
}
}
void tarjan(int u) {
stack_t.push_back(u);
dfn[u] = low[u] = ++dfn_cnt;
for (auto v : graph_t[u]) {
if (!dfn[v]) {
tarjan(v);
cmin(low[u], low[v]);
} else if (!sccno[v]) {
cmin(low[u], dfn[v]);
}
}
if (dfn[u] == low[u]) {
++scc;
while (1) {
int x = stack_t.back();
stack_t.pop_back();
sccno[x] = scc;
if (x == u) {
break;
}
}
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 1);
scanf("%d", &m);
node_cnt = m;
for (int i = 1; i <= m; ++i) {
int a, b, c, d;
scanf("%d%d%d%d", &a, &b, &c, &d);
add_tag(a, b, i << 1);
add_tag(c, d, i << 1 | 1);
}
build(2, n, 1, 0);
memset(dfn, 0, sizeof dfn);
dfn_cnt = 0;
for (int i = 2; i <= (node_cnt << 1 | 1); ++i) {
if (!dfn[i]) {
tarjan(i);
}
}
for (int i = 1; i <= m; ++i) {
if (sccno[i << 1] == sccno[i << 1 | 1]) {
return puts("NO"), 0;
}
}
puts("YES");
for (int i = 1; i <= m; ++i) {
printf("%d\n", sccno[i << 1] < sccno[i << 1 | 1] ? 1 : 2);
}
return 0;
}