虚树DP
用于解决一类题目给出多个询问,每次询问只涉及一部分点的树上DP。
很明显如果直接DP那么每次询问都要做一遍DP,复杂度是不能接受的。
虚树
顾名思义就是一棵虚拟创建的树,这棵树上只会保留与答案有关的关键点以及关键点的 LCA,
这样整棵树的规模不会超过原树的两倍。
虚树的建立
预处理出建立虚树所需要的东西,LCA以及一个栈(栈里维护的是根节点到栈顶节点的一条链)
对于每个关键点,按照 dfs 序从小到大排序。
同 a 表示当前需要加入的点,p 表示当前栈顶维护的点,这时会有两种情况:
1.a 与 p 在同一条链中,直接入栈即可
2.a 与 p 在不同子树中,这时就说明 p 所在这棵子树已经遍历完毕了,对它构建虚树
设栈顶元素为 x, 栈顶下一个元素为 y
1.若 dfn[y] > dfn[lca],将 y -> x 连边,x出栈。
2.若 dfn[y] <= dfn[lca],即 y = lca 或 lca 在 x,y之间,将 lca -> x 连边 lca,a 入栈,结束。
不断重复上述过程,就可以构建出一个虚树了。
例题
先建立虚树,然后在虚树中跑树上DP。
#include<iostream>
#include<algorithm>
#include<cstdio>
using namespace std;
typedef long long ll;
const int N = 5e5 + 5;
const ll INF = 1e18;
ll mn[N], f[N];
int stk[N], dep[N], p[N];
int fa[N][21], dfn[N], tot, top;
struct Edge
{
int head[N], ver[N], net[N], edge[N], idx;
void add(int a, int b, int c)
{
net[++idx] = head[a], ver[idx] = b, edge[idx] = c, head[a] = idx;
}
} e1, e2;
bool cmp(int a, int b)
{
return dfn[a] < dfn[b];
}
void dfs1(int u, int f)
{
fa[u][0] = f;
dfn[u] = ++tot, dep[u] = dep[f] + 1;
for (int i = 1; i <= 20; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = e1.head[u]; i; i = e1.net[i])
{
int v = e1.ver[i];
if (v == f)
continue;
mn[v] = min(mn[u], 1LL * e1.edge[i]);
dfs1(v, u);
}
}
int lca(int x, int y)
{
if (dep[x] < dep[y])
swap(x, y);
for (int i = 20; i >= 0; i--)
if (dep[fa[x][i]] >= dep[y])
x = fa[x][i];
if (x == y)
return x;
for (int i = 20; i >= 0; i--)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
void insert(int u)
{
if (top == 1)
{
if (u != 1)
stk[++top] = u;
return;
}
int l = lca(u, stk[top]);
if (l == stk[top])
return;
while (top > 1 && dfn[stk[top - 1]] >= dfn[l])
e2.add(stk[top - 1], stk[top], 0), top--;
if (stk[top] != l)
{
e2.add(l, stk[top], 0);
stk[top] = l;
}
stk[++top] = u;
}
void dfs2(int u)
{
if (!e2.head[u])
f[u] = mn[u];
else
{
f[u] = 0;
for (int i = e2.head[u]; i; i = e2.net[i])
{
int v = e2.ver[i];
dfs2(v);
f[u] += f[v];
}
f[u] = min(f[u], mn[u]);
e2.head[u] = 0;
}
}
int main()
{
int n, m;
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
e1.add(u, v, w), e1.add(v, u, w);
}
mn[1] = INF;
dfs1(1, 0);
scanf("%d", &m);
while (m--)
{
int k;
scanf("%d", &k);
for (int i = 1; i <= k; i++)
scanf("%d", &p[i]);
sort(p + 1, p + 1 + k, cmp);
e2.idx = 0;
stk[top = 1] = 1;
for (int i = 1; i <= k; i++)
insert(p[i]);
while (top > 1)
e2.add(stk[top - 1], stk[top], 0), top--;
dfs2(1);
printf("%lld\n", f[1]);
}
return 0;
}
CF1594E2 Rubik's Cube Coloring (hard version)
比较特殊的虚树DP,虽然没有询问,但原树节点很多,而关键点较少,可以考虑建立虚树
#include<iostream>
#include<cstdio>
#include<unordered_map>
using namespace std;
const int MOD = 1e9 + 7;
const int N = 2e5 + 5;
typedef long long ll;
ll v[2005], f[N][7];
int k, n, c[N];
int head[N], ver[N], net[N], idx, tot;
char s[2005][21];
unordered_map<ll, int> id;
void add(int a, int b)
{
for (int i = head[a]; i; i = net[i])
if (ver[i] == b)
return;
net[++idx] = head[a], ver[idx] = b, head[a] = idx;
}
void build()
{
for (int i = 1; i <= n; i++)
{
ll now = v[i];
if (!id[now])
id[now] = ++tot;
if (s[i][0] == 'w')
c[id[now]] = 1;
else if (s[i][0] == 'y')
c[id[now]] = 2;
else if (s[i][0] == 'g')
c[id[now]] = 3;
else if (s[i][0] == 'b')
c[id[now]] = 4;
else if (s[i][0] == 'r')
c[id[now]] = 5;
else
c[id[now]] = 6;
while (now >> 1)
add(id[now >> 1] ? id[now >> 1] : id[now >> 1] = ++tot, id[now]), now >>= 1;
}
}
void dfs(int u)
{
for (int j = head[u]; j; j = net[j])
dfs(ver[j]);
for (int i = 1; i <= 6; i++)
{
if (!c[u])
f[u][i] = 1;
else if (c[u] == i)
f[u][i] = 1;
// printf("%d %d %lld\n", u, c[u], f[u][i]);
for (int j = head[u]; j; j = net[j])
{
int v = ver[j];
ll tmp = f[v][0];
if (i <= 2)
tmp -= f[v][1] + f[v][2];
else if (i <= 4)
tmp -= f[v][3] + f[v][4];
else
tmp -= f[v][5] + f[v][6];
// printf("%d %d %d %lld %lld\n", i, u, v, f[v][0], tmp);
f[u][i] = f[u][i] * tmp % MOD;
}
// printf("%d %lld %lld %d\n", u, f[u][0], f[u][i], c[u]);
f[u][0] = (f[u][0] + f[u][i]) % MOD;
}
}
ll qmi(ll a, ll b)
{
ll res = 1;
while (b)
{
if (b & 1)
res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
int main()
{
scanf("%d%d", &k, &n);
for (int i = 1; i <= n; i++)
scanf("%lld%s", &v[i], s[i]);
build();
dfs(id[1]);
// printf("%lld\n", f[id[1]][0]);
ll ans = f[id[1]][0] * qmi(4, ((1LL << k) - 1 - tot)) % MOD;
printf("%lld", (ans + MOD) % MOD);
return 0;
}