LuoguP4178 Tree 解题报告 (点分治)

LuoguP4178 Tree

题意

给定一个 \(n\) 个点的带权树 (边权为正整数), 求树中距离小于等于 \(k\) 的点对数量.

\(1 \le n \le 4\times10^4, k \le 2\times10^4\)

思路

点分治.

对树上的点到当前根节点的距离建一个桶, 并在这个桶上建树状数组.

统计点对数量时在树状数组上查找 \(k-dis[u]\) 的前缀和即可.

时间复杂度 \(O(n\log^2 n)\).

其实不用树状数组也可以, 因为边权为正, 所以越往下, \(dis\) 越大,

那么只需要维护一个指针 \(t=k-dis[u]\), 指针往左移动的时候, 减去指针所指位置的节点数量, 剩下的就是满足 \(dis[v]+dis[u] \le k\) 的点.

时间复杂度 $O(n\log n) $.

但是在洛谷上跑出来反而更慢了....

代码

\(O(n\log^2 n)\)

#include<bits/stdc++.h>
using namespace std;
const int _=2e4+7;
const int __=4e4+7;
const int ___=8e4+7;
const int inf=0x3f3f3f3f;
int n,k,dis[__],sz[__],rt,minx=inf,ans,q[__],top;
int lst[__],nxt[___],to[___],len[___],tot;
int c[_];
bool vis[__];
void add(int x,int y,int w){ nxt[++tot]=lst[x]; to[tot]=y; len[tot]=w; lst[x]=tot; }
void pre(int u,int fa){
  sz[u]=1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    pre(v,u);
    sz[u]+=sz[v];
  }
}
void g_rt(int u,int fa,int sum){
  int maxn=sum-sz[u];
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    g_rt(v,u,sum);
    maxn=max(maxn,sz[v]);
  }
  if(maxn<minx){ minx=maxn; rt=u; }
}
void modify(int x,int v){
  for(int i=x;i<=k;i+=i&(-i)){
    c[i]+=v;
  }
}
int query(int x){
  int res=0;
  for(int i=x;i;i-=i&(-i))
    res+=c[i];
  return res;
}
void cnt(int u,int fa){
  if(dis[u]>k) return;
  ans+=query(k-dis[u])+1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    dis[v]=dis[u]+len[i];
    cnt(v,u);
  }
}
void mrk(int u,int fa){
  if(dis[u]>k) return;
  modify(dis[u],1);
  q[++top]=dis[u];
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    mrk(v,u);
  }
}
void calc(int u){
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(vis[v]) continue;
    dis[v]=len[i];
    cnt(v,0);
    mrk(v,0);
  }
  for(int i=1;i<=top;i++) modify(q[i],-1);
  top=0;
}
void run(int u){
  pre(u,0);
  minx=inf;
  g_rt(u,0,sz[u]);
  u=rt;
  vis[u]=1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(vis[v]) continue;
    run(v);
  }
  calc(u);
  vis[u]=0;
}
int main(){
  //freopen("x.in","r",stdin);
  //freopen("x.out","w",stdout);
  cin>>n; int x,y,w;
  for(int i=1;i<n;i++){
    scanf("%d%d%d",&x,&y,&w);
    add(x,y,w);
    add(y,x,w);
  }
  cin>>k;
  run(1);
  printf("%d\n",ans);   // 一对点只会计算到一次, 所以不用 /2
  return 0;
}

\(O(n\log n)\)

#include<bits/stdc++.h>
using namespace std;
const int _=2e4+7;
const int __=4e4+7;
const int ___=8e4+7;
const int inf=0x3f3f3f3f;
int n,k,dis[__],sz[__],rt,minx=inf,ans,q[__],top,all;
int lst[__],nxt[___],to[___],len[___],tot;
int c[_];
bool vis[__];
void add(int x,int y,int w){ nxt[++tot]=lst[x]; to[tot]=y; len[tot]=w; lst[x]=tot; }
void pre(int u,int fa){
  sz[u]=1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    pre(v,u);
    sz[u]+=sz[v];
  }
}
void g_rt(int u,int fa,int sum){
  int maxn=sum-sz[u];
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    g_rt(v,u,sum);
    maxn=max(maxn,sz[v]);
  }
  if(maxn<minx){ minx=maxn; rt=u; }
}
void cnt(int u,int fa,int t,int res){
  if(dis[u]>k) return;
  while(t>k-dis[u]) res-=c[t--];
  ans+=res;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    dis[v]=dis[u]+len[i];
    cnt(v,u,t,res);
  }
}
void mrk(int u,int fa){
  if(dis[u]>k) return;
  c[dis[u]]++; all++;
  q[++top]=dis[u];
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(v==fa||vis[v]) continue;
    mrk(v,u);
  }
}
void calc(int u){
  all=1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(vis[v]) continue;
    dis[v]=len[i];
    cnt(v,0,k,all);
    mrk(v,0);
  }
  for(int i=1;i<=top;i++) c[q[i]]=0;
  top=all=0;
}
void run(int u){
  pre(u,0);
  minx=inf;
  g_rt(u,0,sz[u]);
  u=rt;
  vis[u]=1;
  for(int i=lst[u];i;i=nxt[i]){
    int v=to[i];
    if(vis[v]) continue;
    run(v);
  }
  calc(u);
  vis[u]=0;
}
int main(){
  //freopen("x.in","r",stdin);
  //freopen("x.out","w",stdout);
  cin>>n; int x,y,w;
  for(int i=1;i<n;i++){
    scanf("%d%d%d",&x,&y,&w);
    add(x,y,w);
    add(y,x,w);
  }
  cin>>k;
  run(1);
  printf("%d\n",ans);   // 一对点只会计算到一次, 所以不用 /2
  return 0;
}

posted @ 2019-12-30 09:33  BruceW  阅读(103)  评论(0编辑  收藏  举报