luogu 3565 bzoj 3522 & bzoj 4543
hotel解题报告
1 方法1
-
我们可以用\(down[i][j]\)表示在\(i\)的子树里面距离为\(j\)的节点的个数,\(up[i][j]\)表示通过\(i\)的父亲走到的距离为\(j\)的点的个数。
\[down[i][j]=\sum_{all\_son}down[son][j-1] \]\[up[i][j]=up[fa][j-1]+down[fa][j-1]-down[i][j-2] \]对于每个\(x\),对于它的每个深度\(i\)我们可以对每个儿子\(son\)执行:
\[ans+=t*down[son][i-1] \]\[t+=sum*down[son][i-1] \]\[sum+=down[son][i-1] \]最后我们做:
\[ans+=t*up[x][i] \] -
时间复杂度\(o(n^2)\)。
-
#include <bits/stdc++.h> using namespace std; int const N = 5000 + 10; #define ll long long struct edge { int to, nt; } e[N << 1]; int h[N], cnt, n; short up[N][N], down[N][N]; ll ans; void add(int a, int b) { e[++cnt].to = b; e[cnt].nt = h[a]; h[a] = cnt; } void dfs(int x, int fa) { down[x][0] = 1; for (int i = h[x]; i; i = e[i].nt) { int v = e[i].to; if (v == fa) continue; dfs(v, x); for (int j = 0; j < n; j++) down[x][j + 1] += down[v][j]; } } void dfs2(int x, int fa) { up[x][0] = 1; if (fa) up[x][1] = 1; for (int i = 2; i <= n; i++) { up[x][i] = up[fa][i - 1]; if (fa) up[x][i] += down[fa][i - 1] - down[x][i - 2]; } for (int i = h[x]; i; i = e[i].nt) { int v = e[i].to; if (v == fa) continue; dfs2(v, x); } } void dfs3(int x, int fa) { for (int i = 1; i <= n; i++) { int sum = 0, t = 0; for (int j = h[x]; j; j = e[j].nt) { int v = e[j].to; if (v == fa) continue; ans += t * down[v][i - 1]; t += sum * down[v][i - 1]; sum += down[v][i - 1]; } ans += t * up[x][i]; } for (int i = h[x]; i; i = e[i].nt) { int v = e[i].to; if (v == fa) continue; dfs3(v, x); } } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int x, y; scanf("%d%d", &x, &y); add(x, y); add(y, x); } dfs(1, 0); dfs2(1, 0); dfs3(1, 0); cout << ans << endl; return 0; }
2 方法2
-
我们可以减少内存消耗,不用定义二维数组,我们每次只要考虑对于\(x\)来说的3个子孙就可以,不用考虑\(x\)的父亲的关系。我们枚举任何一个点都可以作为根节点。
-
这样的时间复杂度仍旧是\(O(n^2)\),但是空间复杂度可以优化到\(O(n)\)。
-
#include <bits/stdc++.h> using namespace std; int const N = 5000 + 10; #define ll long long struct edge { int to, nt; } e[N << 1]; int h[N], cnt, n, tot[N], tmp[N], dep[N], num[N]; ll ans; void add(int a, int b) { e[++cnt].to = b; e[cnt].nt = h[a]; h[a] = cnt; } void dfs(int x, int fa, int d) { dep[x] = d; for (int i = h[x]; i; i = e[i].nt) { int v = e[i].to; if (v == fa) continue; dfs(v, x, d + 1); dep[x] = max(dep[x], dep[v]); } } void dfs2(int x, int fa, int d) { tmp[d]++; for (int i = h[x]; i; i = e[i].nt) { int v = e[i].to; if (v == fa) continue; dfs2(v, x, d + 1); } } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int x, y; scanf("%d%d", &x, &y); add(x, y); add(y, x); } for (int i = 1; i <= n; i++) { dfs(i, 0, 0); memset(tot, 0, sizeof(tot)); memset(num, 0, sizeof(num)); for (int j = h[i]; j; j = e[j].nt) { int v = e[j].to; for (int k = 1; k <= dep[v]; k++) tmp[k] = 0; dfs2(v, i, 1); for (int k = 1; k <= dep[v]; k++) { ans += num[k] * tmp[k]; num[k] += tot[k] * tmp[k]; tot[k] += tmp[k]; } } } cout << ans << endl; return 0; }