「HEOI2014」大工程
知识点:虚树,DP
我个人十分痛恨这种多合一的题目。
*这简直野蛮至极*
简述
给定一棵 \(n\) 个节点的树,边权均为 1。
给定 \(m\) 次询问,每次给定 \(k\) 个关键点,求 \(k\) 个点对之间的路径长度和、最短路径长度、最长路径长度。
\(1\le n\le 10^6\),\(1\le m\le 5\times 10^4\),\(\sum k \le 2\times n\)。
2S,256MB。
分析
先建立虚树,维护各点的深度,之后简单 DP。
第 2、3 问简单维护子树内关键点到根的最长链/最短链即可,考虑如何做第 1 问。
设 \(f_u\) 表示以 \(u\) 为根的子树内关键点对的路径长度之和,\(g_u\) 表示以 \(u\) 为根的子树内关键节点到 \(u\) 的距离之和,\(\operatorname{size}_u\) 表示以 \(u\) 为根的子树内关键节点的个数。
转移时分路径在子树内/跨越根节点讨论,则有显然的转移方程:
\[\begin{aligned}
f_u &= \sum_{v\in son'_u} f_v + (g_v + \operatorname{size}_v\times \operatorname{dis}(u, v))\times (\operatorname{size}_u - \operatorname{size}_v)\\
g_u &= \sum_{v\in son'_u} g_v + \operatorname{size}_v \times \operatorname{dis}(u,v)\\
\operatorname{size}_u &= [u\text{ is a key node}] + \sum_{v\in son'_u} \operatorname{size}_v
\end{aligned}\]
其中 \(\operatorname{dis}(u,v) = \operatorname{dep}_v - \operatorname{dep}_u\)。
代码实现中使用了树链剖分,总复杂度 \(O(\sum k\log n)\) 级别。
细节比较多。
代码
//知识点:虚树
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 1e6 + 10;
const LL kInf = 1e15 + 2077;
//=============================================================
int n, q, k;
int e_num, head[kN], v[kN << 1], ne[kN << 1];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(LL &fir, LL sec) {
if (sec > fir) fir = sec;
}
void Chkmin(LL &fir, LL sec) {
if (sec < fir) fir = sec;
}
void Add(int u_, int v_) {
v[++ e_num] = v_, ne[e_num] = head[u_], head[u_] = e_num;
}
namespace Cut {
const int kMaxNode = kN;
int fa[kMaxNode], dep[kMaxNode], siz[kMaxNode];
int dfn_num, dfn[kN], son[kMaxNode], top[kMaxNode];
void Dfs1(int u_, int fa_) {
fa[u_] = fa_, dfn[u_] = ++ dfn_num, siz[u_] = 1, dep[u_] = dep[fa_] + 1;
for (int i = head[u_]; i; i = ne[i]) {
int v_ = v[i];
if (v_ == fa_) continue;
Dfs1(v_, u_);
if (siz[v_] > siz[son[u_]]) son[u_] = v_;
siz[u_] += siz[v_];
}
}
void Dfs2(int u_, int top_) {
top[u_] = top_;
if (son[u_]) Dfs2(son[u_], top_);
for (int i = head[u_]; i; i = ne[i]) {
int v_ = v[i];
if (v_ != son[u_] && v_ != fa[u_]) Dfs2(v_, v_);
}
}
int Lca(int u_, int v_) {
for (; top[u_] != top[v_]; u_ = fa[top[u_]]) {
if (dep[top[u_]] < dep[top[v_]]) std::swap(u_, v_);
}
return dep[u_] < dep[v_] ? u_ : v_;
}
}
namespace VT { //Virtual Tree
#define dep Cut::dep
const int kMaxNode = kN;
int top, node[kMaxNode], st[kMaxNode], tag[kMaxNode];
LL f1[kMaxNode], f2[kMaxNode], f3[kMaxNode];
LL sumdis[kMaxNode], maxdis[kMaxNode], mindis[kMaxNode], siz[kMaxNode];
std::vector <int> newv[kMaxNode];
bool CMP(int fir_, int sec_) {
return Cut::dfn[fir_] < Cut::dfn[sec_];
}
void Push(int u_) {
int lca = Cut::Lca(u_, st[top]);
for (; dep[st[top - 1]] > dep[lca]; -- top) {
newv[st[top - 1]].push_back(st[top]);
}
if (lca != st[top]) {
newv[lca].push_back(st[top]); -- top;
if (lca != st[top]) st[++ top] = lca;
}
if (st[top] != u_) st[++ top] = u_;
}
void Build(int siz_) {
for (int i = 1; i <= siz_; ++ i) {
node[i] = read();
tag[node[i]] = 1;
}
std::sort(node + 1, node + siz_ + 1, CMP);
st[top = 0] = 1;
for (int i = 1; i <= siz_; ++ i) Push(node[i]);
for (; top; -- top) newv[st[top - 1]].push_back(st[top]);
}
void Dfs(int u_) {
f1[u_] = f3[u_] = 0, f2[u_] = kInf;
sumdis[u_] = 0, maxdis[u_] = tag[u_] ? 0 : -kInf, mindis[u_] = tag[u_] ? 0 : kInf;
siz[u_] = tag[u_];
for (int i = 0, lim = newv[u_].size(); i < lim; ++ i) {
int v_ = newv[u_][i];
LL dis = dep[v_] - dep[u_];
Dfs(v_);
siz[u_] += siz[v_];
sumdis[u_] += sumdis[v_] + siz[v_] * dis;
Chkmin(mindis[u_], mindis[v_] + dis);
Chkmax(maxdis[u_], maxdis[v_] + dis);
Chkmin(f2[u_], f2[v_]);
Chkmax(f3[u_], f3[v_]);
}
LL maxv = -1, maxvv = -1, minv = kInf, minvv = kInf;
if (tag[u_]) maxv = minv = 0;
for (int i = 0, lim = newv[u_].size(); i < lim; ++ i) {
int v_ = newv[u_][i];
LL dis = dep[v_] - dep[u_];
f1[u_] += f1[v_] + (sumdis[v_] + siz[v_] * dis) * (siz[u_] - siz[v_]);
if (maxdis[v_] + dis >= maxv) maxvv = maxv, maxv = maxdis[v_] + dis;
else if (maxdis[v_] + dis > maxvv) maxvv = maxdis[v_] + dis;
if (mindis[v_] + dis <= minv) minvv = minv, minv = mindis[v_] + dis;
else if (mindis[v_] + dis < minvv) minvv = mindis[v_] + dis;
}
if (minv != kInf && minvv != kInf) Chkmin(f2[u_], minv + minvv);
if (maxv != -1 && maxvv != -1) Chkmax(f3[u_], maxv + maxvv);
tag[u_] = 0;
newv[u_].clear();
}
void Solve(int siz_) {
Build(siz_);
Dfs(1);
printf("%lld %lld %lld\n", f1[1], f2[1], f3[1]);
}
}
//=============================================================
int main() {
n = read();
for (int i = 1; i < n; ++ i) {
int u_ = read(), v_ = read();
Add(u_, v_), Add(v_, u_);
}
Cut::Dfs1(1, 0), Cut::Dfs2(1, 1);
int q = read();
while (q --) {
k = read();
VT::Solve(k);
}
return 0;
}
作者@Luckyblock,转载请声明出处。