CF1009F Dominant Indices——长链剖分优化DP
原题链接
\(EDU\)出一道长链剖分优化\(dp\)裸题?
简化版题意
问你每个点的子树中与它距离为多少的点的数量最多,如果有多解,最小化距离
思路
方法1.
用\(dsu\ on\ tree\)做到\(O(nlogn)\)
方法2.
考虑\(dp\),也就是设\(f[u][d]\)表示以\(u\)为根的子树中有多少个点与它的距离为\(j\),则转移如下:
\(f[u][0]=1\),\(f[u][d]+=f[v][d-1]\)
发现可以直接通过把数组右移直接把一个儿子的信息继承过来,又因为转移是跟深度相关的,那么我们直接把长儿子的信息继承过来就好了,然后暴力合并短儿子的信息
这样的时间复杂度都是\(O(n)\)的,怎么证明?直接继承长儿子的信息通过指针可以做到\(O(1)\),然后每条长链只会在顶端被合并,而长链的长度和是\(O(n)\),于是总复杂度就\(O(n)\)啦
空间复杂度的证明同理
代码如下
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <cmath>
#include <ctime>
#include <queue>
#include <map>
#include <set>
using namespace std;
#define IINF 0x3f3f3f3f3f3f3f3fLL
#define ull unsigned long long
#define pii pair<int, int>
#define uint unsigned int
#define mii map<int, int>
#define lbd lower_bound
#define ubd upper_bound
#define INF 0x3f3f3f3f
#define vi vector<int>
#define ll long long
#define mp make_pair
#define pb push_back
#define N 1000000
struct Edge {
int next, to;
}e[2*N+5];
int n;
int head[N+5], eid, len[N+5], longson[N+5];
int memory[N+5], ans[N+5];
void addEdge(int from, int to) {
e[++eid].next = head[from];
e[eid].to = to;
head[from] = eid;
}
void dfs1(int u, int fa) {
for(int i = head[u], v; i; i = e[i].next) {
v = e[i].to;
if(v == fa) continue;
dfs1(v, u);
if(len[longson[u]] < len[v]) longson[u] = v;
}
len[u] = len[longson[u]]+1;
}
void dp(int u, int fa, int *f) {
ans[u] = 0;
f[0] = 1;
int *g;
if(longson[u]) {
g = f+1;
dp(longson[u], u, g);
if(g[ans[longson[u]]] > f[ans[u]] || (g[ans[longson[u]]] == f[ans[u]] && ans[longson[u]] < ans[u]))
ans[u] = ans[longson[u]]+1;
}
g = f+len[u];
for(int i = head[u], v; i; i = e[i].next) {
v = e[i].to;
if(v == fa || v == longson[u]) continue;
dp(v, u, g);
for(int j = 1; j <= len[v]; ++j) {
f[j] += g[j-1];
if(f[j] > f[ans[u]] || (f[j] == f[ans[u]] && j < ans[u]))
ans[u] = j;
}
}
}
int main() {
scanf("%d", &n);
for(int i = 1, x, y; i < n; ++i) {
scanf("%d%d", &x, &y);
addEdge(x, y), addEdge(y, x);
}
dfs1(1, 0);
dp(1, 0, memory);
for(int i = 1; i <= n; ++i) printf("%d\n", ans[i]);
return 0;
}