[BZOJ3697]采药人的路径

[BZOJ3697]采药人的路径

试题描述

采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。他想知道他一共可以选择多少种不同的路径。

输入

第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。

输出

输出符合采药人要求的路径数目。

输入示例

7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1

输出示例

1

数据规模及约定

对于100%的数据,N ≤ 100,000。

题解

很明显这是一个求点对数的问题,所以想到点分治。那么对于跨重心的合法链我们怎么求呢?

首先考虑没有休息点的情况,那么显然我们可以把权值 0 看成 -1 统计一下每条从重心往下搜的链的权值和,那么对于一个子树 i,我们设权值和为 x 的链有 f[x] 条,那么我们只需要找到之前所有子树中权值和为 -x 的链的条数(设它为 g[-x])那么权值 x 对答案的贡献为 f[x] * g[-x],因为权值种数总是与子树大小相关的,所以直接暴力累加对于所有的 x 的贡献即可。

那么现在考虑上有休息点的情况。显然休息点可以在重心、重心左边或是重心右边。所以现在问题的关键在于如何得到 f[0~1][x],f[0][x] 表示起点为子树中的一个节点,终点为重心,权值和为 x 且没有休息点的链的条数,f[1][x] 表示起点为子树中的一个节点,终点为重心,权值和为 x 且有休息点的链的条数。(请慢慢理解。。。)仔细想想发现若是有休息点,那么子树中那个节点到休息点的权值和一定等于 0,意味着休息点到重心的权值和等于 x,那么这个事情就很好统计了。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxn 100010
#define maxm 200010
#define LL long long
int n, m, head[maxn], next[maxm], to[maxm], dist[maxm];
LL ans;

void AddEdge(int a, int b, int c) {
	to[++m] = b; dist[m] = c; next[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; dist[m] = c; next[m] = head[a]; head[a] = m;
	return ;
}

bool vis[maxn];
int root, size, f[maxn], siz[maxn];
void getroot(int u, int fa) {
	siz[u] = 1; f[u] = 0;
	for(int e = head[u]; e; e = next[e]) if(to[e] != fa && !vis[to[e]]) {
		getroot(to[e], u);
		siz[u] += siz[to[e]];
		f[u] = max(f[u], siz[to[e]]);
	}
	f[u] = max(f[u], size - siz[u]);
	if(f[root] > f[u]) root = u;
	return ;
}
int has[maxn<<1], A[2][maxn<<1], B[2][maxn<<1], mxd, mnd;
void dfs(int u, int fa, int d) {
//	printf("(d)%d(%d) ", d, has[d+n]);
	mxd = max(mxd, d); mnd = min(mnd, d);
	A[has[d+n]?1:0][d+n]++;
	has[d+n]++;
	for(int e = head[u]; e; e = next[e]) if(to[e] != fa && !vis[to[e]])
		dfs(to[e], u, d + dist[e]);
	has[d+n]--;
	return ;
}
void solve(int u) {
//	printf("u: %d\n", u);
	vis[u] = 1;
	bool fir = 1;
	int Mxd = -n - 1, Mnd = n + 1;
	for(int e = head[u]; e; e = next[e]) if(!vis[to[e]]) {
		mxd = -n - 1; mnd = n + 1;
		dfs(to[e], u, dist[e]);
		Mxd = max(Mxd, mxd); Mnd = min(Mnd, mnd);
		if(fir) ;
		else {
			ans += (LL)A[0][n] * B[0][n];
//			printf("%d ", A[0][n] * B[0][n]);
			for(int i = n + mnd; i <= n + mxd; i++) {
				int d = i - n;
				ans += (LL)A[0][i] * B[1][n-d] + A[1][i] * B[0][n-d] + A[1][i] * B[1][n-d];
//				printf("(%d)%d ", d, A[0][i] * B[1][n-d] + A[1][i] * B[0][n-d] + A[1][i] * B[1][n-d]);
			}
		}
		ans += (LL)A[1][n];
		fir = 0;
		for(int i = n + mnd; i <= n + mxd; i++)
			B[0][i] += A[0][i], B[1][i] += A[1][i], A[0][i] = A[1][i] = 0, has[i] = 0;
//		putchar('\n');
	}
	for(int i = n + Mnd; i <= n + Mxd; i++) B[0][i] = B[1][i] = 0;
	for(int e = head[u]; e; e = next[e]) if(!vis[to[e]]) {
		root = 0; f[0] = n + 1; size = siz[u]; getroot(to[e], u);
		solve(root);
	}
	return ;
}

int main() {
	n = read();
	for(int i = 1; i < n; i++) {
		int a = read(), b = read(), c = read() ? 1 : -1;
		AddEdge(a, b, c);
	}
	
	root = 0; f[0] = n + 1; size = n; getroot(1, 0);
	solve(root);
	
	printf("%lld\n", ans);
	
	return 0;
}

 

posted @ 2016-09-18 22:15  xjr01  阅读(346)  评论(0编辑  收藏  举报