[BZOJ 4543] [POI2014]Hotel加强版

[POI2014]Hotel加强版

参考博客

https://blog.bill.moe/bzoj4543-hotel/

BZOJ 4543

1

题目大意

给出 \(n\) 个点的树,求有多少个元素互不相同的无序三元组 \((a, b, c)\) 满足两两之间树上距离相等

数据范围

\(n\le100000\)

时空限制

10sec,128MB

分析

首先设 \(f(i,j)\)\(i\) 的子树中距离 \(i\)\(j\) 的子树个数,设 \(g(i,j)\)\(i\) 的子树中存在两点的 \(lca\) 与它们距离皆为 \(d\) ,且\(lca\) 距离 \(i\)\(d-j\) 的方案数,那么 \(f(i,j)\times g(i,j)\) 就是答案,转移复杂度 \(O(n^2)\) 因为答案只与深度有关,用长链剖分优化至 \(O(n)\)

Code

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
inline char nc() {
	static char buf[100000], *l = buf, *r = buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void read(T & x) {
	x = 0; int f = 1, ch = nc();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=nc();}
	while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=nc();}
	x *= f;
}
typedef long long ll;
const int maxn = 100000 + 5;
const int maxe = maxn * 2;
const int maxnode = maxn * 6;
int n; ll an;
int head[maxn], ecnt;
int len[maxn], son[maxn];
ll temp[maxnode], * f[maxn], * g[maxn], * now = temp;
struct edge {
	int to, nex;
	edge(int to=0, int nex=0) : to(to), nex(nex) {}
} G[maxe];
inline void addedge(int u, int v ) {
	G[ecnt] = edge(v, head[u]), head[u] = ecnt++;
	G[ecnt] = edge(u, head[v]), head[v] = ecnt++;
}
void dfs(int u, int fa) {
	son[u] = -1;
	for(int i = head[u]; ~ i; i = G[i].nex) {
		int v = G[i].to; if(v == fa) continue;
		dfs(v, u);
		len[u] = max(len[u], len[v] + 1);
		if(son[u] == -1 || len[v] > len[son[u]]) {
			son[u] = v;
		}
	}
}
int tim;
void dp(int u, int fa) {
	if(~ son[u]) {
		int v = son[u];
		f[v] = f[u] + 1;
		g[v] = g[u] - 1;
		dp(v, u);
	}
	f[u][0] = 1; an += f[u][0] * g[u][0];
	for(int i = head[u]; ~ i; i = G[i].nex) {
		int v = G[i].to; if(v == fa || v == son[u]) continue;
		f[v] = now, now += (len[v] + 1) << 1;
		g[v] = now, now += (len[v] + 1) << 1;
		dp(v, u);
		for(int j = 1; j <= len[v]; ++j) an += g[v][j] * f[u][j - 1];
		for(int j = 0; j <= len[v]; ++j) an += f[v][j] * g[u][j + 1];
		for(int j = 1; j <= len[v]; ++j) g[u][j - 1] += g[v][j];
		for(int j = 0; j <= len[v]; ++j) g[u][j + 1] += f[v][j] * f[u][j + 1];
		for(int j = 0; j <= len[v]; ++j) f[u][j + 1] += f[v][j];
	}
}
void solve() {
	dfs(1, 0);
	f[1] = now, now += (len[1] + 1) << 1;
	g[1] = now, now += (len[1] + 1) << 1;
	dp(1, 0);
	cout << an << endl;
}
int main() {
//	freopen("2.txt", "r", stdin);
	read(n);
	memset(head, -1, sizeof(head));
	for(int i = 1; i < n; ++i) {
		int u, v; read(u), read(v);
		addedge(u, v);
	}
	solve();
	return 0;
}
posted @ 2019-01-10 15:22  LJZ_C  阅读(145)  评论(0编辑  收藏  举报