BZOJ 3697: 采药人的路径
把权值为 \(0\) 的边权值设为 \(-1\)
那么阴阳平衡的路径就是权值和为 \(0\) 的路径
考虑点分治统计路径数
如果没有中间点的限制,那么只需要统计每种路径权值出现的个数,相加之和为 \(0\) 的路径数用乘法原理统计即可
现在有了中间点（休息站）的限制：对于从分治中心到节点 \(x\) 的路径，如果路径上存在一个中间节点 \(y\)（不是路径端点），使得中心到 \(y\) 的权值和等于中心到 \(x\) 的权值和，那么 \(y\) 到 \(x\) 这一段的权值和为 \(0\)，\(y\) 就可以作为休息站。
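那么就把路径分成两种：前缀上出现过该权值的，和前缀上没有出现过该权值的。
合并两条来自不同子树、权值和互为相反数的路径时，至少有一条需要属于前一种；特别地，当两条路径的权值和都为 \(0\) 时，分治中心本身就可以作为休息站，任意组合都合法。以分治中心为端点的路径，只有属于前一种且权值和为 \(0\) 时才合法。按这几种情况分开统计即可。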
#include <bits/stdc++.h>
// Convenience shorthands from a competitive-programming template.
// NOTE(review): none of them are referenced by the code in this file.
#define pb push_back
#define fi first
#define se second
// Expands to an unqualified `pair<...>`; it only compiles where `pair` is
// visible (this file has no `using namespace std`).
#define pii pair<int, int>
// Uses `ll`, which is typedef'd further down; that is fine because a macro is
// expanded where it is used, not where it is defined.
#define pli pair<ll, int>
// Container size as a signed int, avoiding signed/unsigned comparisons.
#define SZ(x) ((int)(x).size())
// Child indices p*2 and p*2+1 of node p in a 1-based heap-ordered array.
// Not parenthesized: `lp + 1` expands to `p << 1 + 1`, i.e. `p << 2`.
#define lp p << 1
#define rp p << 1 | 1
// Midpoint of [l, r]; binds to whatever `l` and `r` are in scope at the use
// site, and replaces every identifier spelled `mid` after this line.
#define mid ((l + r) / 2)
// Lowest set bit of i (two's-complement trick), e.g. lowbit(12) == 4.
#define lowbit(i) ((i) & (-i))
// Short integer / floating-point aliases used throughout the file.
using ll = long long;
using ull = unsigned long long;
using db = double;
// Half-open loops: rep visits i = a, a+1, ..., b-1; per visits the same range
// in reverse order (b-1 down to a).
#define rep(i,a,b) for(int i=(a);i<(b);i++)
#define per(i,a,b) for(int i=((b)-1);i>=(a);i--)
// Adjacency list as a linked "forward star": head[u] is the most recently
// added edge out of u, ne[e] the next edge in u's list and to[e] its target.
// ccnt starts at 1, so the first edge gets index 2 and add() stores the two
// directions of an undirected edge at indices (2k, 2k+1). Index 0 ends every
// list. Edg declares an unweighted graph; Edgc additionally stores a weight
// c[e]. Both need constants N (vertex bound) and E (edge-slot bound) in scope
// where the macro is expanded.
#define Edg int ccnt=1,head[N],to[E],ne[E];void addd(int u,int v){to[++ccnt]=v;ne[ccnt]=head[u];head[u]=ccnt;}void add(int u,int v){addd(u,v);addd(v,u);}
#define Edgc int ccnt=1,head[N],to[E],ne[E],c[E];void addd(int u,int v,int w){to[++ccnt]=v;ne[ccnt]=head[u];c[ccnt]=w;head[u]=ccnt;}void add(int u,int v,int w){addd(u,v,w);addd(v,u,w);}
// Walks the edges leaving u: i is the edge index and v its target. On the last
// step v is read from to[0] just before the loop exits; that value is unused.
#define es(u,i,v) for(int i=head[u],v=to[i];i;i=ne[i],v=to[i])
// Default modulus for the modular helpers below (1e9 + 7, a prime) and
// "infinity" sentinels. Every byte of INF / inf is 0x3f, presumably so that
// memset(arr, 0x3f, ...) produces the same value.
const int MOD = 1000000007, INF = 0x3f3f3f3f;
const long long inf = 0x3f3f3f3f3f3f3f3f;

// Folds x back into [0, MOD) with a single correction step. Assumes x is
// already in [-MOD, 2 * MOD), e.g. the sum or difference of two residues.
void M(int &x) {
    if (x >= MOD) x -= MOD;
    if (x < 0) x += MOD;
}

// a^b mod `mod` by binary exponentiation. With the defaults (b = MOD - 2,
// mod = MOD) this is the modular inverse of a by Fermat's little theorem,
// valid because MOD is prime and provided a is not a multiple of MOD.
// Preconditions: b >= 0 and mod >= 1 (a negative b would previously loop
// forever, since b >>= 1 never reaches 0 for negative b).
// The base is first reduced into [0, mod), so a negative or oversized a still
// yields a canonical residue in [0, mod).
int qp(int a, int b = MOD - 2, int mod = MOD) {
    assert(b >= 0 && mod >= 1);
    long long base = a % mod;
    if (base < 0) base += mod;
    long long ans = 1 % mod;  // 1 % mod handles mod == 1 (everything is 0)
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = ans * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(ans);
}

// Greatest common divisor by Euclid's algorithm; gcd(x, 0) == x and
// gcd(0, 0) == 0. Intended for non-negative arguments.
template<class T>T gcd(T a, T b) { while (b) { a %= b; std::swap(a, b); } return a; }

// If b improves on a (smaller / larger respectively), stores b into a and
// returns true; otherwise leaves a unchanged and returns false.
template<class T>bool chkmin(T &a, T b) { return a > b ? a = b, 1 : 0; }
template<class T>bool chkmax(T &a, T b) { return a < b ? a = b, 1 : 0; }
// Fast input: stdin is read in 2 MiB chunks into buf; [p1, p2) is the unread
// part of the current chunk.
char buf[1 << 21], *p1 = buf, *p2 = buf;

// Returns the next input byte as an unsigned char value (0..255), or EOF once
// stdin is exhausted. Returning int rather than char keeps EOF distinct from a
// 0xFF byte and from every valid byte on platforms where plain char is
// unsigned.
inline int getc() {
    if (p1 == p2) {
        p1 = buf;
        p2 = buf + fread(buf, 1, sizeof buf, stdin);
        if (p1 == p2) return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}

// Reads the next integer from stdin. Characters before the first digit are
// skipped, and any '-' seen while skipping makes the result negative; then the
// run of digits is consumed (plus the one character after it).
// Returns 0 if input ends before any digit appears, instead of spinning
// forever on EOF.
// NOTE(review): identifiers starting with '_' are reserved at global scope;
// the name is kept for compatibility with existing callers.
inline int _() {
    int x = 0, f = 1;
    int ch = getc();
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getc();
    }
    while (ch >= '0' && ch <= '9') {
        // Multiply in 64 bits as before, so oversized input wraps instead of
        // hitting signed-int overflow in the intermediate product.
        x = static_cast<int>(x * 10LL + (ch - '0'));
        ch = getc();
    }
    return x * f;
}
// N: vertex bound with slack (assumes n <= 1e5 — confirm against the problem
// limits). E: twice N, because add() stores every undirected edge in two
// directed slots; E is also large enough for indices d + n with d in [-n, n].
const int N = 1e5 + 77, E = N * 2;
// Weighted adjacency list: head/to/ne/c arrays plus addd()/add().
Edgc
// vis[u]: u has already served as a centroid, so its edges no longer connect
// the remaining components.
bool vis[N];
// sz[u]: subtree size in the current rooted traversal; maxsz[u]: size of the
// largest piece left if u were removed; totsz: size of the component being
// searched; root: best centroid found so far (0 is a sentinel); n: vertex count.
int sz[N], maxsz[N], totsz, root, n;
void getroot(int u, int f) {
sz[u] = 1; maxsz[u] = 0;
es (u, i, v) {
if (v == f || vis[v]) continue;
getroot(v, u);
sz[u] += sz[v];
chkmax(maxsz[u], sz[v]);
}
chkmax(maxsz[u], totsz - sz[u]);
if (maxsz[u] < maxsz[root]) root = u;
}
void recalc(int u, int f) {
sz[u] = 1;
es (u, i, v) {
if (v == f || vis[v]) continue;
recalc(v, u);
sz[u] += sz[v];
}
}
// Endpoint counters indexed by (path weight sum from the centroid) + n.
// B/BB cover the child subtree currently being scanned; A/AA accumulate the
// subtrees of the same centroid that were already scanned. An endpoint goes
// into BB/AA when some earlier vertex on its path (below the centroid) has the
// same sum, so a balanced rest point exists on that side; otherwise it goes
// into B/A. have[s] counts the vertices on the current DFS path, excluding the
// centroid and the vertex being classified, whose sum index is s.
int B[E], BB[E], A[E], AA[E], have[E];
// d[u]: weight sum of the path from the current centroid to u.
int d[N];
// Running total of valid paths.
ll ans;
void dfs(int u, int f) {
int x = d[u] + n;
if (have[x]) BB[x]++;
else B[x]++;
have[x]++;
es (u, i, v) {
if (v == f || vis[v]) continue;
d[v] = d[u] + c[i];
dfs(v, u);
}
have[x]--;
}
void solve(int u) {
vis[u] = 1;
recalc(u, 0);
es (u, i, v) {
if (vis[v]) continue;
rep (i, 0, sz[v] + 1) B[n - i] = B[n + i] = BB[n - i] = BB[n + i] = 0;
d[v] = c[i];
dfs(v, u);
ans += BB[n] + (ll)BB[n] * A[n] + (ll)B[n] * AA[n] + (ll)BB[n] * AA[n] + (ll)B[n] * A[n];
rep (i, 1, sz[v] + 1) {
ans += (ll)BB[n - i] * A[n + i] + (ll)B[n - i] * AA[n + i] + (ll)BB[n - i] * AA[n + i];
ans += (ll)BB[n + i] * A[n - i] + (ll)B[n + i] * AA[n - i] + (ll)BB[n + i] * AA[n - i];
}
A[n] += B[n]; AA[n] += BB[n];
rep (i, 1, sz[v] + 1) {
A[n - i] += B[n - i];
AA[n - i] += BB[n - i];
AA[n + i] += BB[n + i];
A[n + i] += B[n + i];
}
}
rep (i, 0, sz[u] + 1) A[n - i] = A[n + i] = AA[n - i] = AA[n + i] = 0;
es (u, i, v) {
if (vis[v]) continue;
totsz = maxsz[root = 0] = sz[v];
getroot(v, u);
solve(root);
}
}
int main() {
#ifdef LOCAL
freopen("ans.out", "w", stdout);
#endif
n = _();
rep (i, 1, n) {
int u = _(), v = _(), c = _();
if (c == 0) c = -1;
add(u, v, c);
}
totsz = maxsz[root = 0] = n;
getroot(1, 0);
solve(root);
printf("%lld\n", ans);
#ifdef LOCAL
printf("%.10f\n", (db)clock() / CLOCKS_PER_SEC);
#endif
return 0;
}