算法笔记--支配树
1.DAG
按照拓扑序从小到大处理,对于每个节点,求出所有有边指向它的点(即它的直接前驱)在当前已建好的支配树上的lca(可用倍增维护),它在支配树上的父亲就是这个lca;入度为0的点直接挂在虚拟根下。
2.一般图(Lengauer-Tarjan算法)
模板:
/// Adjacency lists: g = original graph, rg = reverse graph,
/// tg = semidominator buckets, G = children lists of the dominator tree.
vector<int> g[N], rg[N];
vector<int> tg[N];
vector<int> G[N];
/// in = in-degree, dfn = DFS number (0 = not visited), rak = vertex owning a
/// given DFS number, fa = DFS-tree parent.
int in[N], dfn[N], rak[N], fa[N];
/// sdom = semidominator (stored as a DFS number), idom = immediate dominator
/// vertex, ufs/val = union-find parent and best-sdom witness used by Find.
int sdom[N], idom[N];
int val[N], ufs[N];
/// DFS counter.
int cnt;
/// Dominator-subtree sizes.
LL sz[N];
/// Preorder DFS from u: gives each newly reached vertex its DFS number
/// (dfn / rak), records its DFS-tree parent in fa, and seeds sdom with
/// the vertex's own DFS number.
void dfs1(int u) {
    ++cnt;
    dfn[u] = cnt;
    sdom[u] = cnt;
    rak[cnt] = u;
    for (const int &to : g[u]) {
        if (dfn[to]) continue; /// already numbered
        fa[to] = u;
        dfs1(to);
    }
}
/// Post-order walk of the dominator tree G: sz[u] becomes the number of
/// vertices in u's dominator subtree (u itself included).
void dfs2(int u) {
    sz[u] = 1;
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int child = G[u][k];
        dfs2(child);
        sz[u] += sz[child];
    }
}
/// Union-find lookup with path compression for Lengauer-Tarjan.
/// While compressing, val[x] is lowered to val[p] for the ancestor p on the
/// path whenever that entry has a strictly smaller sdom. Returns the root.
int Find(int x) {
    if (ufs[x] != x) {
        const int up = ufs[x];
        ufs[x] = Find(up);
        if (sdom[val[up]] < sdom[val[x]]) val[x] = val[up];
    }
    return ufs[x];
}
/// Lengauer-Tarjan dominator-tree template (vertices 1..n, virtual root 0).
/// Before use, fill in the input section (edges into g / rg, and in[] when
/// option B is used) and keep exactly one source option enabled.
/// Prints, for every vertex i, the size of its subtree in the dominator tree,
/// or 0 if i is unreachable from the virtual root.
void lengauer_tarjan(int n) {
    /// Reset per-run state for vertices 0..n.
    cnt = 0;
    for (int i = 0; i <= n; ++i) {
        sz[i] = dfn[i] = in[i] = 0;
        ufs[i] = val[i] = i;
        g[i].clear();  /// original graph
        rg[i].clear(); /// reverse graph
        tg[i].clear(); /// semidominator buckets
        G[i].clear();  /// dominator tree
    }
    /// Input goes here: for each edge u -> v do g[u].pb(v), rg[v].pb(u), in[v]++.

    /// Source selection -- enable exactly ONE option. Enabling both also
    /// hangs every in-degree-0 vertex under 0, which changes the dominator
    /// tree; if in[] is left unfilled it hangs EVERY vertex under 0 and the
    /// tree degenerates to a star.
    /// Option A: a single source n.
    g[0].pb(n), rg[n].pb(0);
    /// Option B: every in-degree-0 vertex is a source (requires in[] filled above).
    // for (int i = 1; i <= n; ++i) if(!in[i]) g[0].pb(i), rg[i].pb(0);

    dfs1(0);
    /// Main pass in decreasing DFS order (DFS number 1 is the virtual root).
    for (int i = cnt; i >= 2; --i) {
        int u = rak[i];
        /// sdom[u] = min over reachable predecessors v of the smallest sdom
        /// Find records on v's path in the already-linked forest.
        for (int v : rg[u]) {
            if(!dfn[v]) continue; /// predecessor not reached by the DFS
            Find(v);
            sdom[u] = min(sdom[u], sdom[val[v]]);
        }
        ufs[u] = fa[u];         /// link u under its DFS parent
        tg[rak[sdom[u]]].pb(u); /// file u in its semidominator's bucket
        /// Resolve fa[u]'s bucket: either idom is the semidominator itself, or
        /// it is deferred to the fix-up pass by recording val[v].
        for (int v : tg[fa[u]]) {
            Find(v);
            if(sdom[val[v]] == sdom[v]) idom[v] = rak[sdom[v]];
            else idom[v] = val[v];
        }
        tg[fa[u]].clear();
    }
    /// Fix-up pass in increasing DFS order: deferred entries copy the idom of
    /// the recorded vertex, which has a smaller DFS number and is already final.
    for (int i = 1; i <= cnt; ++i) {
        int u = rak[i];
        if(idom[u] != rak[sdom[u]]) idom[u] = idom[idom[u]];
    }
    /// Build the dominator tree over reachable vertices and compute subtree sizes.
    for (int i = 1; i <= n; ++i) if(dfn[i]) G[idom[i]].pb(i);
    dfs2(0);
    /// sz is long long, so it must be printed with %lld (%d is undefined behavior).
    for (int i = 1; i <= n; ++i) printf("%lld%c", dfn[i] ? sz[i] : 0LL, " \n"[i==n]);
}
例题1:P2597 [ZJOI2012]灾难
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head
/// Capacity: vertex ids must stay below N (0 is the virtual root).
const int N = 1e5 + 5;
/// g = original graph, rg = reverse graph, tg = semidominator buckets,
/// G = children lists of the dominator tree.
std::vector<int> g[N], rg[N];
std::vector<int> tg[N], G[N];
/// in = in-degree, dfn = DFS number (0 = not visited), rak = vertex owning a
/// given DFS number, fa = DFS-tree parent.
int in[N], dfn[N], rak[N], fa[N];
/// sdom = semidominator (as a DFS number), idom = immediate dominator vertex,
/// ufs/val = union-find parent and best-sdom witness used by Find.
int sdom[N], idom[N], val[N], ufs[N];
/// sz = dominator-subtree sizes; cnt = DFS counter.
int sz[N], cnt;
/// Preorder DFS from u: gives each newly reached vertex its DFS number
/// (dfn / rak), records its DFS-tree parent in fa, and seeds sdom with
/// the vertex's own DFS number.
void dfs1(int u) {
    ++cnt;
    dfn[u] = cnt;
    sdom[u] = cnt;
    rak[cnt] = u;
    for (const int &to : g[u]) {
        if (dfn[to]) continue; /// already numbered
        fa[to] = u;
        dfs1(to);
    }
}
/// Post-order walk of the dominator tree G: sz[u] becomes the number of
/// vertices in u's dominator subtree (u itself included).
void dfs2(int u) {
    sz[u] = 1;
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int child = G[u][k];
        dfs2(child);
        sz[u] += sz[child];
    }
}
/// Union-find lookup with path compression for Lengauer-Tarjan.
/// While compressing, val[x] is lowered to val[p] for the ancestor p on the
/// path whenever that entry has a strictly smaller sdom. Returns the root.
int Find(int x) {
    if (ufs[x] != x) {
        const int up = ufs[x];
        ufs[x] = Find(up);
        if (sdom[val[up]] < sdom[val[x]]) val[x] = val[up];
    }
    return ufs[x];
}
/// ZJOI2012 "Catastrophe" (Luogu P2597). For each vertex i the input lists ids x
/// (terminated by 0) and an edge x -> i is added. Every in-degree-0 vertex hangs
/// under the virtual root 0. Prints, for each i, the number of vertices strictly
/// dominated by i (dominator-subtree size minus one), or 0 if i is unreachable.
void lengauer_tarjan(int n) {
    /// Reset state for vertices 0..n.
    cnt = 0;
    for (int i = 0; i <= n; ++i) {
        sz[i] = dfn[i] = in[i] = 0;
        ufs[i] = val[i] = i;
        g[i].clear();
        rg[i].clear();
        tg[i].clear();
        G[i].clear();
    }
    /// Input: for vertex i, read ids x until 0 and add edge x -> i.
    /// Requiring exactly one converted value stops on EOF and on a malformed
    /// token (a ~scanf test does not stop when scanf returns 0, and x would
    /// then be read uninitialized).
    for (int i = 1; i <= n; ++i) {
        int x;
        while (scanf("%d", &x) == 1 && x) {
            g[x].pb(i);
            in[i]++;
            rg[i].pb(x);
        }
    }
    /// Every in-degree-0 vertex is a source hanging under the virtual root 0.
    for (int i = 1; i <= n; ++i) if (!in[i]) g[0].pb(i), rg[i].pb(0);
    dfs1(0);
    /// Loops are bounded by cnt (vertices actually reached) and unreachable
    /// predecessors are skipped, so a vertex the DFS never reaches (only
    /// possible if the input contains a cycle) cannot inject stale rak[] /
    /// sdom[] entries into the computation.
    for (int i = cnt; i >= 2; --i) {
        int u = rak[i];
        /// Semidominator: best candidate over all reachable predecessors.
        for (int v : rg[u]) {
            if (!dfn[v]) continue;
            Find(v);
            sdom[u] = min(sdom[u], sdom[val[v]]);
        }
        ufs[u] = fa[u];         /// link u under its DFS parent
        tg[rak[sdom[u]]].pb(u); /// file u in its semidominator's bucket
        /// Resolve fa[u]'s bucket now, or defer via val[v] to the fix-up pass.
        for (int v : tg[fa[u]]) {
            Find(v);
            if (sdom[val[v]] == sdom[v]) idom[v] = rak[sdom[v]];
            else idom[v] = val[v];
        }
        tg[fa[u]].clear();
    }
    /// Fix-up pass in increasing DFS order for deferred idom entries.
    for (int i = 1; i <= cnt; ++i) {
        int u = rak[i];
        if (idom[u] != rak[sdom[u]]) idom[u] = idom[idom[u]];
    }
    for (int i = 1; i <= n; ++i) if (dfn[i]) G[idom[i]].pb(i);
    dfs2(0);
    for (int i = 1; i <= n; ++i) printf("%d\n", dfn[i] ? sz[i] - 1 : 0);
}
/// Vertex count, read from the start of the input.
int n;
int main() {
    /// Read n; lengauer_tarjan reads the rest of the input and prints the answers.
    scanf("%d", &n);
    lengauer_tarjan(n);
}
例题2:P5180 【模板】支配树
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head
/// Capacity: vertex ids must stay below N (0 is the virtual root).
const int N = 2e5 + 5;
/// g = original graph, rg = reverse graph, tg = semidominator buckets,
/// G = children lists of the dominator tree.
std::vector<int> g[N], rg[N];
std::vector<int> tg[N], G[N];
/// in = in-degree, dfn = DFS number (0 = not visited), rak = vertex owning a
/// given DFS number, fa = DFS-tree parent.
int in[N], dfn[N], rak[N], fa[N];
/// sdom = semidominator (as a DFS number), idom = immediate dominator vertex,
/// ufs/val = union-find parent and best-sdom witness used by Find.
int sdom[N], idom[N], val[N], ufs[N];
/// sz = dominator-subtree sizes; cnt = DFS counter.
int sz[N], cnt;
/// Preorder DFS from u: gives each newly reached vertex its DFS number
/// (dfn / rak), records its DFS-tree parent in fa, and seeds sdom with
/// the vertex's own DFS number.
void dfs1(int u) {
    ++cnt;
    dfn[u] = cnt;
    sdom[u] = cnt;
    rak[cnt] = u;
    for (const int &to : g[u]) {
        if (dfn[to]) continue; /// already numbered
        fa[to] = u;
        dfs1(to);
    }
}
/// Post-order walk of the dominator tree G: sz[u] becomes the number of
/// vertices in u's dominator subtree (u itself included).
void dfs2(int u) {
    sz[u] = 1;
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int child = G[u][k];
        dfs2(child);
        sz[u] += sz[child];
    }
}
/// Union-find lookup with path compression for Lengauer-Tarjan.
/// While compressing, val[x] is lowered to val[p] for the ancestor p on the
/// path whenever that entry has a strictly smaller sdom. Returns the root.
int Find(int x) {
    if (ufs[x] != x) {
        const int up = ufs[x];
        ufs[x] = Find(up);
        if (sdom[val[up]] < sdom[val[x]]) val[x] = val[up];
    }
    return ufs[x];
}
/// Luogu P5180 (dominator tree template). Reads m directed edges and builds the
/// dominator tree rooted at vertex 1, which hangs under the virtual root 0.
/// Prints, for every vertex, the number of vertices it dominates (itself
/// included), or 0 if the vertex is unreachable from 1.
void lengauer_tarjan(int n) {
    /// Reset state for vertices 0..n.
    cnt = 0;
    for (int i = 0; i <= n; ++i) {
        sz[i] = dfn[i] = in[i] = 0;
        ufs[i] = val[i] = i;
        g[i].clear();
        rg[i].clear();
        tg[i].clear();
        G[i].clear();
    }
    /// Input: m edges x -> y. Reads are checked so that truncated input never
    /// leaves m, x or y uninitialized. The names x/y also avoid being shadowed
    /// by the loop-local u/v below.
    int m = 0, x, y;
    if (scanf("%d", &m) != 1) m = 0;
    for (int i = 1; i <= m; ++i) {
        if (scanf("%d %d", &x, &y) != 2) break; /// keep the edges read so far
        g[x].pb(y);
        in[y]++;
        rg[y].pb(x);
    }
    /// Source: vertex 1, attached to the virtual root 0.
    g[0].pb(1), rg[1].pb(0);
    dfs1(0);
    /// Loops are bounded by cnt (vertices actually reached) instead of n+1, and
    /// unreachable predecessors are skipped, matching the general template.
    for (int i = cnt; i >= 2; --i) {
        int u = rak[i];
        /// Semidominator: best candidate over all reachable predecessors.
        for (int v : rg[u]) {
            if (!dfn[v]) continue;
            Find(v);
            sdom[u] = min(sdom[u], sdom[val[v]]);
        }
        ufs[u] = fa[u];         /// link u under its DFS parent
        tg[rak[sdom[u]]].pb(u); /// file u in its semidominator's bucket
        /// Resolve fa[u]'s bucket now, or defer via val[v] to the fix-up pass.
        for (int v : tg[fa[u]]) {
            Find(v);
            if (sdom[val[v]] == sdom[v]) idom[v] = rak[sdom[v]];
            else idom[v] = val[v];
        }
        tg[fa[u]].clear();
    }
    /// Fix-up pass in increasing DFS order for deferred idom entries.
    for (int i = 1; i <= cnt; ++i) {
        int u = rak[i];
        if (idom[u] != rak[sdom[u]]) idom[u] = idom[idom[u]];
    }
    for (int i = 1; i <= n; ++i) if (dfn[i]) G[idom[i]].pb(i);
    dfs2(0);
    for (int i = 1; i <= n; ++i) printf("%d%c", dfn[i] ? sz[i] : 0, " \n"[i==n]);
}
/// Vertex count, read from the start of the input.
int n;
int main() {
    /// Read n; lengauer_tarjan reads the edges and prints the answers.
    scanf("%d", &n);
    lengauer_tarjan(n);
}
例题3:HDU 4694
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head
/// Capacity: vertex ids must stay below N (0 is the virtual root).
const int N = 5e4 + 5;
/// g = original graph, rg = reverse graph, tg = semidominator buckets,
/// G = children lists of the dominator tree.
vector<int> g[N], rg[N];
vector<int> tg[N], G[N];
/// in = in-degree, dfn = DFS number (0 = not visited), rak = vertex owning a
/// given DFS number, fa = DFS-tree parent.
int in[N], dfn[N], rak[N], fa[N];
/// sdom = semidominator (as a DFS number), idom = immediate dominator vertex,
/// ufs/val = union-find parent and best-sdom witness used by Find.
int sdom[N], idom[N], val[N], ufs[N];
/// DFS counter.
int cnt;
/// Per-vertex value filled by dfs2: sum of vertex ids on the dominator-tree
/// path from the root down to the vertex.
long long sz[N];
/// Preorder DFS from u: gives each newly reached vertex its DFS number
/// (dfn / rak), records its DFS-tree parent in fa, and seeds sdom with
/// the vertex's own DFS number.
void dfs1(int u) {
    ++cnt;
    dfn[u] = cnt;
    sdom[u] = cnt;
    rak[cnt] = u;
    for (const int &to : g[u]) {
        if (dfn[to]) continue; /// already numbered
        fa[to] = u;
        dfs1(to);
    }
}
/// Pre-order walk of the dominator tree G: sz[u] receives d, the running sum
/// carried down from the root, and each child adds its own id before recursing.
void dfs2(int u, LL d) {
    sz[u] = d;
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int child = G[u][k];
        dfs2(child, d + child);
    }
}
/// Union-find lookup with path compression for Lengauer-Tarjan.
/// While compressing, val[x] is lowered to val[p] for the ancestor p on the
/// path whenever that entry has a strictly smaller sdom. Returns the root.
int Find(int x) {
    if (ufs[x] != x) {
        const int up = ufs[x];
        ufs[x] = Find(up);
        if (sdom[val[up]] < sdom[val[x]]) val[x] = val[up];
    }
    return ufs[x];
}
/// Builds the dominator tree of the input graph rooted at vertex n (hung under
/// the virtual root 0). For every vertex i it prints the sum of the ids on the
/// dominator-tree path from n down to i inclusive, or 0 if i is unreachable
/// from n. main calls this once per test case.
void lengauer_tarjan(int n) {
/// Reset per-test-case state for vertices 0..n.
cnt = 0;
for (int i = 0; i <= n; ++i) {
sz[i] = dfn[i] = in[i] = 0;
ufs[i] = val[i] = i;
g[i].clear();
rg[i].clear();
tg[i].clear();
G[i].clear();
}
/// Input: m directed edges u -> v (g = forward graph, rg = reverse graph).
int m, u, v;
scanf("%d", &m);
for (int i = 1; i <= m; ++i) {
scanf("%d %d", &u, &v);
g[u].pb(v);
in[v]++;
rg[v].pb(u);
}
/// Attach the source n to the virtual root 0, then number every vertex
/// reachable from it (dfs1 fills dfn / rak / fa / sdom).
g[0].pb(n), rg[n].pb(0);
dfs1(0);
/// Main pass in decreasing DFS order (DFS number 1 is the virtual root).
/// NOTE(review): the loop-local u and v shadow the input variables above.
for (int i = cnt; i >= 2; --i) {
int u = rak[i];
/// sdom[u] (a DFS number) = min over reachable predecessors v of the
/// smallest sdom that Find records on v's path in the linked forest.
for (int v : rg[u]) {
if(!dfn[v]) continue;
Find(v);
sdom[u] = min(sdom[u], sdom[val[v]]);
}
/// Link u under its DFS parent and file it in its semidominator's bucket.
ufs[u] = fa[u];
tg[rak[sdom[u]]].pb(u);
/// Resolve fa[u]'s bucket: either idom is the semidominator itself, or it
/// is deferred to the fix-up pass below by recording val[v].
for (int v : tg[fa[u]]) {
Find(v);
if(sdom[val[v]] == sdom[v]) idom[v] = rak[sdom[v]];
else idom[v] = val[v];
}
tg[fa[u]].clear();
}
/// Fix-up pass in increasing DFS order: deferred entries copy the idom of the
/// recorded vertex, which has a smaller DFS number and is already final.
for (int i = 1; i <= cnt; ++i) {
int u = rak[i];
if(idom[u] != rak[sdom[u]]) idom[u] = idom[idom[u]];
}
/// Build the dominator tree over reachable vertices, then accumulate path sums
/// (dfs2 starts at the virtual root with sum 0).
for (int i = 1; i <= n; ++i) if(dfn[i]) G[idom[i]].pb(i);
dfs2(0, 0);
for (int i = 1; i <= n; ++i) printf("%lld%c", dfn[i]?sz[i]:0, " \n"[i==n]);
}
/// Vertex count of the current test case.
int n;
int main() {
    /// Solve test cases until the input runs out. Requiring exactly one
    /// converted value also stops on a malformed token: scanf then returns 0,
    /// which a ~scanf test treats as "continue", looping forever.
    while (scanf("%d", &n) == 1) lengauer_tarjan(n);
    return 0;
}
例题4:Codechef Counting on a directed graph
代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
//#define mp make_pair
#define pb emplace_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head
/// Capacity: vertex ids must stay below N (0 is the virtual root).
const int N = 1e5 + 5;
/// g = original graph, rg = reverse graph, tg = semidominator buckets,
/// G = children lists of the dominator tree.
vector<int> g[N], rg[N];
vector<int> tg[N], G[N];
/// in = in-degree, dfn = DFS number (0 = not visited), rak = vertex owning a
/// given DFS number, fa = DFS-tree parent.
int in[N], dfn[N], rak[N], fa[N];
/// sdom = semidominator (as a DFS number), idom = immediate dominator vertex,
/// ufs/val = union-find parent and best-sdom witness used by Find.
int sdom[N], idom[N], val[N], ufs[N];
/// DFS counter.
int cnt;
/// Dominator-subtree sizes.
long long sz[N];
/// Preorder DFS from u: gives each newly reached vertex its DFS number
/// (dfn / rak), records its DFS-tree parent in fa, and seeds sdom with
/// the vertex's own DFS number.
void dfs1(int u) {
    ++cnt;
    dfn[u] = cnt;
    sdom[u] = cnt;
    rak[cnt] = u;
    for (const int &to : g[u]) {
        if (dfn[to]) continue; /// already numbered
        fa[to] = u;
        dfs1(to);
    }
}
/// Post-order walk of the dominator tree G: sz[u] becomes the number of
/// vertices in u's dominator subtree (u itself included).
void dfs2(int u) {
    sz[u] = 1;
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int child = G[u][k];
        dfs2(child);
        sz[u] += sz[child];
    }
}
/// Union-find lookup with path compression for Lengauer-Tarjan.
/// While compressing, val[x] is lowered to val[p] for the ancestor p on the
/// path whenever that entry has a strictly smaller sdom. Returns the root.
int Find(int x) {
    if (ufs[x] != x) {
        const int up = ufs[x];
        ufs[x] = Find(up);
        if (sdom[val[up]] < sdom[val[x]]) val[x] = val[up];
    }
    return ufs[x];
}
/// Codechef "Counting on a directed graph". Builds the dominator tree rooted at
/// vertex 1 (hung under the virtual root 0) and prints the number of unordered
/// pairs of reachable vertices whose dominator-tree LCA is vertex 1.
/// main calls this once per test case.
void lengauer_tarjan(int n) {
/// Reset per-test-case state for vertices 0..n.
cnt = 0;
for (int i = 0; i <= n; ++i) {
sz[i] = dfn[i] = in[i] = 0;
ufs[i] = val[i] = i;
g[i].clear(); /// original graph
rg[i].clear();/// reverse graph
tg[i].clear();
G[i].clear();
}
/// Input: m directed edges u -> v.
int m, u, v;
scanf("%d", &m);
for (int i = 1; i <= m; ++i) {
scanf("%d %d", &u, &v);
g[u].pb(v);
in[v]++;
rg[v].pb(u);
}
/// Start from vertex 1: attach it to the virtual root 0.
g[0].pb(1), rg[1].pb(0);
/// Number every vertex reachable from 1 (no in-degree-0 sources are attached here).
dfs1(0);
/// Main pass in decreasing DFS order (DFS number 1 is the virtual root).
/// NOTE(review): the loop-local u and v shadow the input variables above.
for (int i = cnt; i >= 2; --i) {
int u = rak[i];
/// sdom[u] (a DFS number) = min over reachable predecessors v of the
/// smallest sdom that Find records on v's path in the linked forest.
for (int v : rg[u]) {
if(!dfn[v]) continue;
Find(v);
sdom[u] = min(sdom[u], sdom[val[v]]);
}
/// Link u under its DFS parent and file it in its semidominator's bucket.
ufs[u] = fa[u];
tg[rak[sdom[u]]].pb(u);
/// Resolve fa[u]'s bucket: either idom is the semidominator itself, or it
/// is deferred to the fix-up pass below by recording val[v].
for (int v : tg[fa[u]]) {
Find(v);
if(sdom[val[v]] == sdom[v]) idom[v] = rak[sdom[v]];
else idom[v] = val[v];
}
tg[fa[u]].clear();
}
/// Fix-up pass in increasing DFS order: deferred entries copy the idom of the
/// recorded vertex, which has a smaller DFS number and is already final.
for (int i = 1; i <= cnt; ++i) {
int u = rak[i];
if(idom[u] != rak[sdom[u]]) idom[u] = idom[idom[u]];
}
/// Build the dominator tree over reachable vertices and compute subtree sizes.
for (int i = 1; i <= n; ++i) if(dfn[i]) G[idom[i]].pb(i);
dfs2(0);
/// Pairs with LCA 1: each vertex in a child subtree of 1 pairs with vertex 1
/// itself (tot starts at 1) and with every vertex of the earlier child subtrees.
LL ans = 0, tot = 1;
for (int v : G[1]) {
ans += sz[v]*tot;
tot += sz[v];
}
printf("%lld\n", ans);
}
/// Vertex count of the current test case.
int n;
int main() {
    /// Solve test cases until the input runs out. Requiring exactly one
    /// converted value also stops on a malformed token: scanf then returns 0,
    /// which a ~scanf test treats as "continue", looping forever.
    while (scanf("%d", &n) == 1) lengauer_tarjan(n);
    return 0;
}