【BZOJ2152】聪聪可可-树的点分治
测试地址:聪聪可可
做法:这题要求满足两点之间距离是3的倍数(也就是对3取模余数为0)的点对数目,我们知道求满足特定条件(一般和路径有关)的点对数目的问题一般可以用树分治算法解决。
和其他同类题目一样,这道题目的O(N^2)算法非常显然,然而N可达20000,O(N^2)无法通过全部数据。
怎么办呢?我们就从分治的角度来思考问题:以树的重心分治,路径情况分为过根和不过根两种,不过根的递归解决,这里只考虑过根的情况。因为我们要求的是距离为3的倍数且不属于同一棵子树的点对,问题可以转化成为:1.求距离为3的倍数的点对数目;2.求距离为3的倍数且属于同一棵子树的点对数目。问题1的答案减去问题2的答案就是原问题的答案。
我们发现问题1和问题2本质上都是相同的,因此我们只需思考这个问题:给定A,求使得A[i]+A[j]=0(mod 3)的(i,j)数目。因为一个数对3取模余数只有0,1,2三种,那么我们只要存储余数为这三种的点数目,记为sum[0...2],那么答案就是sum[0]*sum[0]+sum[1]*sum[2]+sum[2]*sum[1],问题解决的复杂度是O(N)。配合以树的重心分治,总复杂度大约是O(NlogN),可以通过全部数据。
以下是本人代码:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define inf 1000000000
using namespace std;
int n,tot=0,first[20010]={0};
int a[20010],p[20010],fa[20010],siz[20010],sum[3],dis[20010];
int ans=0,ans2;
bool vis[20010];
struct edge {int v,d,next;} e[40010];
int gcd(int a,int b)
{
return (b==0)?a:gcd(b,a%b);
}
void insert(int a,int b,int d)
{
e[++tot].v=b,e[tot].d=d,e[tot].next=first[a],first[a]=tot;
}
void dfs(int v)
{
siz[v]=1;
a[++a[0]]=dis[v]%3;
p[a[0]]=v;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v]&&!vis[e[i].v])
{
fa[e[i].v]=v;
dis[e[i].v]=dis[v]+e[i].d;
dfs(e[i].v);
siz[v]+=siz[e[i].v];
}
}
int find(int v)
{
a[0]=0;fa[v]=0;dfs(v);
int s,mx=inf;
for(int i=1;i<=siz[v];i++)
{
int maxsiz=0;
for(int j=first[p[i]];j;j=e[j].next)
if (!vis[e[j].v]) maxsiz=max(maxsiz,siz[e[j].v]);
maxsiz=max(maxsiz,siz[v]-siz[p[i]]);
if (maxsiz<mx) s=p[i],mx=maxsiz;
}
return s;
}
int work(int v,int start)
{
a[0]=fa[v]=0;dis[v]=start;dfs(v);
sum[0]=sum[1]=sum[2]=0;
for(int i=1;i<=siz[v];i++)
sum[a[i]]++;
return sum[0]*sum[0]+2*sum[1]*sum[2];
}
void solve(int v)
{
v=find(v);
ans+=work(v,0);
vis[v]=1;
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]) ans-=work(e[i].v,dis[e[i].v]);
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v]) solve(e[i].v);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y;long long w;
scanf("%d%d%lld",&x,&y,&w);
w%=3;
insert(x,y,w),insert(y,x,w);
}
solve(1);
ans2=n*n;
int g=gcd(ans,ans2);
ans/=g,ans2/=g;
printf("%d/%d",ans,ans2);
return 0;
}