虚树
虚树
概念
大概就是在要做多次树形 dp 的时候,发现并不是所有点都用得到。也就是只有“关键点”才用得到。考虑在每次 dp 的时候,重新建一棵树,只有这些关键点。
把关键点拿出来,按照 dfs 序排序,用单调栈维护一条链,然后每次加点,建树。
例题
T1 洛谷 P2495 [SDOI2011]消耗战
题目描述
在一场战争中,战场由n个岛屿和n-1个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为1的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他k个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。
侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到1号岛屿上)。不过侦查部门还发现了这台机器只能够使用m次,所以我们只需要把每次任务完成即可。
输入输出格式
输入格式:
第一行一个整数n,代表岛屿数量。
接下来n-1行,每行三个整数u,v,w,代表u号岛屿和v号岛屿由一条代价为c的桥梁直接相连,保证1<=u,v<=n且1<=c<=100000。
第n+1行,一个整数m,代表敌方机器能使用的次数。
接下来m行,每行一个整数ki,代表第i次后,有ki个岛屿资源丰富,接下来k个整数h1,h2,…hk,表示资源丰富岛屿的编号。
输出格式:
输出有m行,分别代表每次任务的最小代价。
输入输出样例
输入样例#1:
10
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
3
2 10 6
4 5 7 8 3
3 9 4 6
输出样例#1:
12
32
22
说明
【数据规模和约定】
对于10%的数据,2<=n<=10,1<=m<=5,1<=ki<=n-1
对于20%的数据,2<=n<=100,1<=m<=100,1<=ki<=min(10,n-1)
对于40%的数据,2<=n<=1000,m>=1,sigma(ki)<=500000,1<=ki<=min(15,n-1)
对于100%的数据,2<=n<=250000,m>=1,sigma(ki)<=500000,1<=ki<=n-1
思路:考虑只有一次询问的话是一个非常简单的 dp。但是询问次数很多,可询问的总点数与 n 同阶,那么利用虚树,就可以做到 \(O(klog_2k)\)
看起来非常的 naive,但我抄代码还写了好久
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
#define maxn 250010
#define maxk 500010
#define INF 1000000000000000
using namespace std;
int n, m;
inline ll min(ll a, ll b, ll c){return min(a, min(b, c));}
struct Edge{
int to, next; ll w;
}e[maxn * 2], E[maxn * 2]; int c1, head[maxn], c2, Head[maxn];
inline void add_edge(int u, int v, ll w){
e[c1].to = v; e[c1].w = w;
e[c1].next = head[u]; head[u] = c1++;
}
inline void Add_edge(int u, int v, ll w){
E[c2].to = v; E[c2].w = w;
E[c2].next = Head[u]; Head[u] = c2++;
}
int f[maxn][21], id[maxn], c3, dep[maxn]; ll mn[maxn][21];
void dfs(int u, int fa){
id[u] = ++c3;
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to; ll w = e[i].w; if(v == fa) continue;
dep[v] = dep[u] + 1; f[v][0] = u; mn[v][0] = w;
dfs(v, u);
}
}
void init_lca(){
for(int j = 1; j <= 20; ++j)
for(int i = 1; i <= n; ++i){
f[i][j] = f[f[i][j - 1]][j - 1];
mn[i][j] = min(mn[i][j - 1], mn[f[i][j - 1]][j - 1]);
}
}
int get_lca(int x, int y, ll &M){
M = INF;
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; --i)
if(dep[f[x][i]] >= dep[y]){
M = min(M, mn[x][i], mn[y][i]);
x = f[x][i];
}
if(x == y) return x;
for(int i = 20; i >= 0; --i)
if(f[x][i] != f[y][i]){
M = min(M, mn[x][i], mn[y][i]);
x = f[x][i], y = f[y][i];
}
M = min(M, min(mn[x][0], mn[y][0]));
return f[x][0];
}
bool vis[maxn]; ll dp[maxn];
void Dfs(int u, int fa){
dp[u] = 0;
for(int i = Head[u]; ~i; i = E[i].next){
int v = E[i].to; ll w = E[i].w; if(v == fa) continue;
Dfs(v, u);
if(vis[v]) dp[u] += w;
else dp[u] += min(w, dp[v]);
}
}
inline bool cmp(int x, int y){return id[x] < id[y];}
int st[maxn], top, a[maxn];
void solve(int m){
sort(a + 1, a + m + 1, cmp);
st[top = 1] = 1; Head[1] = -1; c2 = 0;;
for(int i = 1; i <= m; ++i){ ll t;
int l = get_lca(st[top], a[i], t);
if(l != st[top]){
while(id[l] < id[st[top - 1]]){
ll w; int u = st[top], v = st[top - 1];
get_lca(u, v, w); Add_edge(u, v, w); Add_edge(v, u, w);
--top;
}
if(id[l] > id[st[top - 1]]){
ll w; Head[l] = -1; get_lca(l, st[top], w);
Add_edge(l, st[top], w); Add_edge(st[top], l, w);
st[top] = l;
}
else{
ll w; get_lca(l, st[top], w);
Add_edge(l, st[top], w); Add_edge(st[top], l, w);
--top;
}
}
Head[a[i]] = -1; st[++top] = a[i];
}
for(int i = 1; i < top; ++i){
int u = st[i], v = st[i + 1]; ll w;
get_lca(u, v, w);
Add_edge(u, v, w); Add_edge(v, u, w);
}
Dfs(1, 0); printf("%lld\n", dp[1]);
}
int main(){ memset(head, -1, sizeof head); memset(Head, -1, sizeof Head);
scanf("%d", &n);
for(int i = 0; i <= n; ++i)
for(int j = 0; j <= 20; ++j) mn[i][j] = INF;
for(int i = 1; i < n; ++i){
int x, y, z; scanf("%d%d%d", &x, &y, &z);
add_edge(x, y, z); add_edge(y, x, z);
}
dep[1] = 1; dfs(1, 0); init_lca();
scanf("%d", &m);
for(int i = 1; i <= m; ++i){
int x; scanf("%d", &x);
for(int j = 1; j <= x; ++j) scanf("%d", &a[j]), vis[a[j]] = 1;
solve(x);
for(int j = 1; j <= x; ++j) vis[a[j]] = 0;
}
return 0;
}
T2 洛谷 P4103 [HEOI2014]大工程
题目描述
国家有一个大工程,要给一个非常大的交通网络里建一些新的通道。
我们这个国家位置非常特殊,可以看成是一个单位边权的树,城市位于顶点上。
在 2 个国家 a,b 之间建一条新通道需要的代价为树上 a,b 的最短路径。
现在国家有很多个计划,每个计划都是这样,我们选中了 k 个点,然后在它们两两之间 新建 C(k,2)条 新通道。现在对于每个计划,我们想知道: 1.这些新通道的代价和 2.这些新通道中代价最小的是多少 3.这些新通道中代价最大的是多少
输入输出格式
输入格式:
第一行 n 表示点数。
接下来 n-1 行,每行两个数 a,b 表示 a 和 b 之间有一条边。点从 1 开始标号。
接下来一行 q 表示计划数。对每个计划有 2 行,第一行 k 表示这个计划选中了几个点。
第二行用空格隔开的 k 个互不相同的数表示选了哪 k 个点。
输出格式:
输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。
输入输出样例
输入样例#1:
10
2 1
3 2
4 1
5 2
6 4
7 5
8 6
9 7
10 9
5
2
5 4
2
10 4
2
5 2
2
6 1
2
6 1
输出样例#1:
3 3 3
6 6 6
1 1 1
2 2 2
2 2 2
说明
对于第 1,2 个点: n<=10000
对于第 3,4,5 个点: n<=100000,交通网络构成一条链
对于第 6,7 个点: n<=100000
对于第 8,9,10 个点: n<=1000000
对于所有数据, q<=50000并且保证所有k之和<=2*n
思路:显然虚树。考虑建出虚树后。对于最大和最小,\(f[u]\) 表示 u 的子树中的最短链,\(g[u]\) 表示 u 的子树的中最长链,然后就可以维护,稍微需要特判下关键点和非关键点。对于总代价,只需要按照边来计算,即算每条边下面有多少点,上面有多少点,乘起来就是这条边的贡献。
#include<iostream>
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
#define maxn 1000010
#define INF 1000000
#define ll long long
#define gc() getchar()
using namespace std;
int n, m;
int read(){
int x = 0; char c = gc();
while(!isdigit(c)) c = gc();
while(isdigit(c)){x = x * 10 + c - '0'; c = gc();}
return x;
}
struct Edge{
int to, next, w;
}e[maxn * 2], E[maxn * 2]; int c1, head[maxn], c2, Head[maxn];
inline void add_edge(int u, int v, int w){
e[c1].to = v; e[c1].w = w;
e[c1].next = head[u]; head[u] = c1++;
e[c1].to = u; e[c1].w = w;
e[c1].next = head[v]; head[v] = c1++;
}
inline void Add_edge(int u, int v, int w){
E[c2].to = v; E[c2].w = w;
E[c2].next = Head[u]; Head[u] = c2++;
E[c2].to = u; E[c2].w = w;
E[c2].next = Head[v]; Head[v] = c2++;
}
int ou[maxn * 2], in[maxn], id[maxn], c3, dis[maxn], c4;
void Dfs(int u, int fa, int w){
ou[++c3] = u; in[u] = c3; dis[u] = w; id[u] = ++c4;
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to; if(v == fa) continue;
Dfs(v, u, w + 1); ou[++c3] = u;
}
}
inline int st_min(int l, int r){
return in[l] < in[r] ? l : r;
}
int Log[maxn * 2], St[maxn * 2][21];
void init_lca(){ Log[0] = -1;
for(int i = 1; i <= c3; ++i) Log[i] = Log[i >> 1] + 1;
for(int i = 1; i <= c3; ++i) St[i][0] = ou[i];
for(int j = 1; j <= 20; ++j)
for(int i = 1; i + (1 << j) - 1 <= c3; ++i)
St[i][j] = st_min(St[i][j - 1], St[i + (1 << j - 1)][j - 1]);
}
inline int get_lca(int l, int r){
l = in[l]; r = in[r]; if(l > r) swap(l, r);
int x = Log[r - l + 1];
return st_min(St[l][x], St[r - (1 << x) + 1][x]);
}
inline int D(int x, int y){
return dis[x] + dis[y] - 2 * dis[get_lca(x, y)];
}
int sz[maxn], f[maxn], g[maxn], N, du[maxn]; ll s; int s1, s2; bool vis[maxn];
void dfs1(int u, int fa){
sz[u] = vis[u];
for(int i = Head[u]; ~i; i = E[i].next){
int v = E[i].to, w = E[i].w; if(v == fa) continue;
dfs1(v, u); sz[u] += sz[v];
}
}
void dfs2(int u, int fa){
f[u] = vis[u] ? 0 : INF; g[u] = vis[u] ? 0 : -INF;
for(int i = Head[u]; ~i; i = E[i].next){
int v = E[i].to, w = E[i].w; if(v == fa) continue;
dfs2(v, u);
s1 = min(s1, f[u] + f[v] + w); s2 = max(s2, g[u] + g[v] + w);
f[u] = min(f[u], f[v] + w); g[u] = max(g[u], g[v] + w);
s += 1ll * w * (N - sz[v]) * sz[v];
}
}
inline bool cmp(int x, int y){return in[x] < in[y];}
int a[maxn], st[maxn], top;
void solve(int m){
sort(a + 1, a + m + 1, cmp);
st[top = 1] = 1; Head[1] = -1; s = 0; c2 = 0; s2 = -INF; s1 = INF;
for(int i = 1; i <= m; ++i){ if(a[i] == 1) continue;
int lca = get_lca(a[i], st[top]);
if(lca != st[top]){
while(id[lca] < id[st[top - 1]]){
Add_edge(st[top], st[top - 1], D(st[top], st[top - 1]));
--top;
}
if(id[lca] > id[st[top - 1]]){
Head[lca] = -1;
Add_edge(lca, st[top], D(lca, st[top]));
st[top] = lca;
}
else{
Add_edge(lca, st[top], D(lca, st[top]));
--top;
}
}
Head[a[i]] = -1; st[++top] = a[i];
}
for(int i = 1; i < top; ++i) Add_edge(st[i], st[i + 1], D(st[i], st[i + 1]));
dfs1(1, 0); N = sz[1]; dfs2(1, 0);
printf("%lld %d %d\n", s, s1, s2);
}
int main(){ memset(head, -1, sizeof head); memset(Head, -1, sizeof Head);
n = read();
for(int i = 1; i < n; ++i){
int x = read(), y = read();
add_edge(x, y, 1);
}
Dfs(1, 0, 0); init_lca();
m = read();
for(int i = 1; i <= m; ++i){
int x = read();
for(int j = 1; j <= x; ++j) vis[a[j] = read()] = 1;
solve(x);
for(int j = 1; j <= x; ++j) vis[a[j]] = 0;
}
return 0;
}
T3 洛谷 P3233 [HNOI2014]世界树
题目描述
世界树是一棵无比巨大的树,它伸出的枝干构成了整个世界。在这里,生存着各种各样的种族和生灵,他们共同信奉着绝对公正公平的女神艾莉森,在他们的信条里,公平是使世界树能够生生不息、持续运转的根本基石。
世界树的形态可以用一个数学模型来描述:世界树中有 nnn 个种族,种族的编号分别从 111 到 nnn ,分别生活在编号为 111 到 nnn 的聚居地上,种族的编号与其聚居地的编号相同。有的聚居地之间有双向的道路相连,道路的长度为 111 。保证连接的方式会形成一棵树结构,即所有的聚居地之间可以互相到达,并且不会出现环。定义两个聚居地之间的距离为连接他们的道路的长度;例如,若聚居地 aaa 和 bbb 之间有道路, bbb 和 ccc 之间有道路,因为每条道路长度为 111 而且又不可能出现环,所以 aaa 与 ccc 之间的距离为 222 。
于对公平的考虑,第 iii 年,世界树的国王需要授权 mim_imi 个种族的聚居地为临时议事处。对于某个种族 xxx ( xxx 为种族的编号),如果距离该种族最近的临时议事处为 yyy ( yyy 为议事处所在聚居地的编号),则种族 xxx 将接受 yyy 议事处的管辖(如果有多个临时议事处到该聚居地的距离一样,则 yyy 为其中编号最小的临时议事处)。
在国王想知道,在 qqq 年的时间里,每一年完成授权后,当年每个临时议事处将会管理多少个种族(议事处所在的聚居地也将接受该议事处管理)。 现在这个任务交给了以智慧著称的灵长类的你:程序猿。请帮国王完成这个任务吧。
输入输出格式
输入格式:
第一行为一个正整数n,表示世界树中种族的个数。接下来n-l行,每行两个正整数x,y,表示x聚居地与y聚居地之间有一条长度为1的双向道路。接下来一行为一个正整数q,表示国王询问的年数。接下来q块,每块两行:第i块的第一行为1个正整数m[i],表示第i年授权的临时议事处的个数。第i块的第二行为m[i]个正整数h[l]、h[2]、...、h[m[i]],表示被授权为临时议事处的聚居地编号(保证互不相同)。
输出格式:
输出包含q行,第i行为m[i]个整数,该行的第j(j=1,2...,,m[i])个数表示第i年被授权的聚居地h[j]的临时议事处管理的种族个数。
输入输出样例
输入样例#1:
10
2 1
3 2
4 3
5 4
6 1
7 3
8 3
9 4
10 1
5
2
6 1
5
2 7 3 6 9
1
8
4
8 7 10 3
5
2 9 3 5 8
输出样例#1:
1 9
3 1 4 1 1
10
1 1 3 5
4 1 3 1 1
说明
N<=300000, q<=300000,m[1]+m[2]+...+m[q]<=300000
思路:首先是虚树题。先考虑只有一次询问的情况,发现可以两次 dfs 求出里每个点最近的关键点。第一遍用儿子更新父亲,第二遍用父亲更新儿子。这是一次询问。考虑多次询问,建出虚树。考虑计算答案,对于虚树上的每条边统计答案,即这条边的端点分别能获得多少。考虑如果都是关键点,那么直接二分中间的分界点就好了,如果一个关键点一个非关键点的话,就二分离非关键点最近的关键点和这个关键点的路径,反正关键点一定在这个关键点和非关键点的路径上。嘴巴AC 。代码是看参考别人的。
#include<iostream>
#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
#define maxn 300010
#define gc() getchar();
using namespace std;
int n, m;
int read(){
int x = 0; char c = gc();
while(!isdigit(c)) c = gc();
while(isdigit(c)){x = x * 10 + c - '0'; c = gc();}
return x;
}
struct Edge{
int to, next;
}e[maxn * 2], E[maxn * 2]; int c1, head[maxn], c2, Head[maxn];
inline void add_edge(int u, int v){
e[c1].to = v; e[c1].next = head[u]; head[u] = c1++;
e[c1].to = u; e[c1].next = head[v]; head[v] = c1++;
}
inline void Add_edge(int u, int v){
E[c2].to = v; E[c2].next = Head[u]; Head[u] = c2++;
E[c2].to = u; E[c2].next = Head[v]; Head[v] = c2++;
}
int sz[maxn], f[maxn][21], dep[maxn], id[maxn], c3;
void dfs(int u, int fa){
sz[u] = 1; id[u] = ++c3;
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to; if(v == fa) continue;
dep[v] = dep[u] + 1; f[v][0] = u;
dfs(v, u); sz[u] += sz[v];
}
}
void init_lca(){
for(int j = 1; j <= 20; ++j)
for(int i = 1; i <= n; ++i) f[i][j] = f[f[i][j - 1]][j - 1];
}
int get_lca(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; --i)
if(dep[f[x][i]] >= dep[y]) x = f[x][i];
if(x == y) return x;
for(int i = 20; i >= 0; --i)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
inline int D(int x, int y){
return dep[x] + dep[y] - 2 * dep[get_lca(x, y)];
}
int bl[maxn], g[maxn];
void dfs1(int u, int fa){
g[u] = sz[u];
for(int i = Head[u]; ~i; i = E[i].next){
int v = E[i].to, d1, d2; if(v == fa) continue;
dfs1(v, u); if(!bl[v]) continue;
if(!bl[u]) bl[u] = bl[v];
else{
d1 = D(u, bl[u]); d2 = D(u, bl[v]);
if(d2 < d1 || (d1 == d2 && bl[v] < bl[u])) bl[u] = bl[v];
}
}
}
void dfs2(int u, int fa){
for(int i = Head[u]; ~i; i = E[i].next){
int v = E[i].to, d1, d2; if(v == fa) continue;
if(!bl[v]) bl[v] = bl[u];
else{
d1 = D(v, bl[v]); d2 = D(v, bl[u]);
if(d2 < d1 || (d1 == d2 && bl[u] < bl[v])) bl[v] = bl[u];
}
dfs2(v, u);
}
}
int p[maxn];
void work(int u, int v){
int son = v;
for(int i = 20; i >= 0; --i) if(dep[f[son][i]] > dep[u]) son = f[son][i];
g[u] -= sz[son];
if(bl[u] == bl[v]){p[bl[u]] += sz[son] - sz[v]; return ;}
int ans = v;
for(int i = 20; i >= 0; --i){
int x = f[ans][i]; if(dep[x] <= dep[u]) continue;
int d1 = D(x, bl[u]), d2 = D(x, bl[v]);
if(d1 > d2 || (d1 == d2 && bl[v] < bl[u])) ans = x;
}
p[bl[u]] += sz[son] - sz[ans];
p[bl[v]] += sz[ans] - sz[v];
}
void dfs3(int u, int fa){
for(int i = Head[u]; ~i; i = E[i].next){
int v = E[i].to; if(v == fa) continue;
work(u, v); dfs3(v, u);
}
}
inline bool cmp(int x, int y){return id[x] < id[y];}
int a[maxn], b[maxn], c4, st[maxn], top;
void solve(int m){
for(int i = 1; i <= m; ++i) bl[a[i]] = a[i], b[i] = a[i];
sort(a + 1, a + m + 1, cmp); c4 = m; Head[1] = -1; st[top = 1] = 1; if(a[1] != 1) b[++c4] = 1;
for(int i = 1; i <= m; ++i){ if(a[i] == 1) continue;
int lca = get_lca(a[i], st[top]);
if(lca != st[top]){
while(id[lca] < id[st[top - 1]]){
Add_edge(st[top], st[top - 1]);
--top;
}
if(id[lca] > id[st[top - 1]]){
Head[lca] = -1; b[++c4] = lca;
Add_edge(st[top], lca); st[top] = lca;
}
else Add_edge(st[top], lca), --top;
}
Head[a[i]] = -1; st[++top] = a[i];
}
for(int i = 1; i < top; ++i) Add_edge(st[i], st[i + 1]);
dfs1(1, 0); dfs2(1, 0);
dfs3(1, 0);
for(int i = 1; i <= c4; ++i) p[bl[b[i]]] += g[b[i]];
for(int i = 1; i <= m; ++i) printf("%d ", p[b[i]]); putchar('\n');
for(int i = 1; i <= c4; ++i) bl[b[i]] = p[b[i]] = g[b[i]] = 0;
}
int main(){
memset(head, -1, sizeof head); memset(Head, -1, sizeof Head);
n = read();
for(int i = 1; i < n; ++i){
int x = read(), y = read();
add_edge(x, y);
}
dep[1] = 1; dfs(1, 0); init_lca();
m = read();
for(int i = 1; i <= m; ++i){
int x = read();
for(int j = 1; j <= x; ++j) a[j] = read();
solve(x);
}
return 0;
}