【BZOJ2152】聪聪可可-树的点分治

测试地址:聪聪可可

做法:这题要求满足两点之间距离是3的倍数(也就是对3取模余数为0)的点对数目,我们知道求满足特定条件(一般和路径有关)的点对数目的问题一般可以用树分治算法解决。

和其他同类题目一样,这道题目的O(N^2)算法非常显然,然而N可达20000,O(N^2)无法通过全部数据。

怎么办呢?我们就从分治的角度来思考问题:以树的重心分治,路径情况分为过根和不过根两种,不过根的递归解决,这里只考虑过根的情况。因为我们要求的是距离为3的倍数且不属于同一棵子树的点对,问题可以转化成为:1.求距离为3的倍数的点对数目;2.求距离为3的倍数且属于同一棵子树的点对数目。问题1的答案减去问题2的答案就是原问题的答案。

我们发现问题1和问题2本质上都是相同的,因此我们只需思考这个问题:给定A,求使得A[i]+A[j]=0(mod 3)的(i,j)数目。因为一个数对3取模余数只有0,1,2三种,那么我们只要存储余数为这三种的点数目,记为sum[0...2],那么答案就是sum[0]*sum[0]+sum[1]*sum[2]+sum[2]*sum[1],问题解决的复杂度是O(N)。配合以树的重心分治,总复杂度大约是O(NlogN),可以通过全部数据。

以下是本人代码:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define inf 1000000000
using namespace std;
int n,tot=0,first[20010]={0};
int a[20010],p[20010],fa[20010],siz[20010],sum[3],dis[20010];
int ans=0,ans2;
bool vis[20010];
struct edge {int v,d,next;} e[40010];
 
int gcd(int a,int b)
{
  return (b==0)?a:gcd(b,a%b);
}
 
void insert(int a,int b,int d)
{
  e[++tot].v=b,e[tot].d=d,e[tot].next=first[a],first[a]=tot;
}
 
void dfs(int v)
{
  siz[v]=1;
  a[++a[0]]=dis[v]%3;
  p[a[0]]=v;
  for(int i=first[v];i;i=e[i].next)
    if (e[i].v!=fa[v]&&!vis[e[i].v])
    {
      fa[e[i].v]=v;
      dis[e[i].v]=dis[v]+e[i].d;
      dfs(e[i].v);
      siz[v]+=siz[e[i].v];
    }
}
 
int find(int v)
{
  a[0]=0;fa[v]=0;dfs(v);
  int s,mx=inf;
  for(int i=1;i<=siz[v];i++)
  {
    int maxsiz=0;
    for(int j=first[p[i]];j;j=e[j].next)
      if (!vis[e[j].v]) maxsiz=max(maxsiz,siz[e[j].v]);
    maxsiz=max(maxsiz,siz[v]-siz[p[i]]);
    if (maxsiz<mx) s=p[i],mx=maxsiz;
  }
  return s;
}
 
int work(int v,int start)
{
  a[0]=fa[v]=0;dis[v]=start;dfs(v);
  sum[0]=sum[1]=sum[2]=0;
  for(int i=1;i<=siz[v];i++)
    sum[a[i]]++;
  return sum[0]*sum[0]+2*sum[1]*sum[2];
}
 
void solve(int v)
{
  v=find(v);
  ans+=work(v,0);
  vis[v]=1;
  for(int i=first[v];i;i=e[i].next)
    if (!vis[e[i].v]) ans-=work(e[i].v,dis[e[i].v]);
  for(int i=first[v];i;i=e[i].next)
    if (!vis[e[i].v]) solve(e[i].v);
}
 
int main()
{
  scanf("%d",&n);
  for(int i=1;i<n;i++)
  {
    int x,y;long long w;
    scanf("%d%d%lld",&x,&y,&w);
    w%=3;
    insert(x,y,w),insert(y,x,w);
  }
   
  solve(1);
  ans2=n*n;
  int g=gcd(ans,ans2);
  ans/=g,ans2/=g;
  printf("%d/%d",ans,ans2);
   
  return 0;
}


posted @ 2017-04-27 17:18  Maxwei_wzj  阅读(86)  评论(0编辑  收藏  举报