P3085 [USACO13OPEN] Yin and Yang G
题目大意
给出一个边权为 \(1\) 或 \(-1\) 的树,求除出有多少条 \(s\rightarrow x\rightarrow t\) 的路径使得 \(s\rightarrow x\) 的边权和与 \(x\rightarrow t\) 的边权和均为 \(0\)。
\(n\leq 10^5\)
思路
考虑点分治。首先设 \([a,b]\) 表示 \(a\rightarrow b\) 的边权和。若 \([s,t]=0\),且路径上有 \([s,x]=0\) 或 \([x,t]=0\) 则路径一定合法。
维护点分治。考虑到根 \(u\),按子树外和子树内考虑。分别记录 \([v,u]=d\) 的路径上是否有点 \(x\) 使 \([v,x]=0\),然后开两个数组记录有和没有的数量。记 \(o_{d,0}\) 表示子树外没有 \([v,x]=0\),\(o_{d,1}\) 表示有。 \(f_{d,0}\) 表示子树内没有 \([v,x]=0\),\(f_{d,1}\) 表示有。
之后考虑子树内,则枚举每一个可能出现的 \(d\),然后这几种情况均合法:
- \(f_{d,0}\times o_{d,1}\)
- \(f_{d,1}\times o_{d,0}\)
- \(f_{d,1}\times o_{d,1}\)
- \(f_{0,0}\times o_{0,0}\)
- \(f_{0,1}\)
累加即可。
本题错点:在点分治记录重心时的减法是当前能访问的节点个数大小,而不是 \(n\)。
代码
#include <bits/stdc++.h>
#define endl "\n"
#define f(x,y) ff[x+200000][y]
#define o(x,y) oo[x+200000][y]
#define vis(x) visvis[x+200000]
using namespace std;
typedef long long ll;
const ll MAXN=2e6+5;
ll ff[MAXN*2][2],oo[MAXN*2][2],visvis[MAXN*2];
ll n;
struct edge{
ll v,w;
};
vector<edge>adj[MAXN];
ll sz[MAXN],core;
bool block[MAXN];
ll ans=0,mx[MAXN];
void gc(ll u,ll fa,ll Sz){
sz[u]=1;
mx[u]=0;
for(auto e:adj[u]){
ll v=e.v;
if(block[v]||v==fa){
continue;
}
gc(v,u,Sz);
sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],Sz-sz[u]);
if(mx[u]<mx[core]){
core=u;
}
}
ll mv;
void dfs(ll u,ll fa,ll dis){
mv=max(mv,abs(dis));
if(vis(dis)){
f(dis,1)++;
}else{
f(dis,0)++;
}
vis(dis)++;
for(auto e:adj[u]){
ll v=e.v,w=e.w;
if(block[v]||v==fa){
continue;
}
dfs(v,u,dis+w);
}
vis(dis)--;
}
void solve(ll u,ll fa,ll Sz){
if(adj[u].size()==1&&block[adj[u][0].v]){
return;
}
core=0;
gc(u,fa,Sz);
u=core;
ll Mv=0;
for(auto e:adj[u]){
ll v=e.v,w=e.w;
if(block[v]){
continue;
}
mv=0;
dfs(v,u,w);
Mv=max(Mv,mv);
ans+=f(0,1);
ans+=f(0,0)*o(0,0);
for(int V=-mv;V<=mv;++V){
ans+=f(V,1)*o(-V,0);
ans+=f(V,0)*o(-V,1);
ans+=f(V,1)*o(-V,1);
}
for(int V=-mv;V<=mv;++V){
o(V,0)+=f(V,0);
o(V,1)+=f(V,1);
f(V,0)=f(V,1)=0;
}
}
for(int v=-Mv;v<=Mv;++v){
o(v,0)=o(v,1)=0;
}
block[u]=true;
for(auto e:adj[u]){
ll v=e.v,w=e.w;
if(block[v]){
continue;
}
solve(v,u,sz[v]);
}
}
signed main(){
mx[0]=1e18;
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
cin>>n;
for(int i=1;i<n;++i){
ll u,v,w;
cin>>u>>v>>w;
if(w==0){
w=-1;
}
adj[u].push_back({v,w});
adj[v].push_back({u,w});
}
solve(1,0,n);
cout<<ans<<endl;
return 0;
}