P4178 Tree——点分治 容斥

Tree

题目描述

给定一棵 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量。

输入格式

第一行输入一个整数 \(n\),表示节点个数。

第二行到第 \(n\) 行每行输入三个整数 \(u,v,w\) ,表示 \(u\)\(v\) 有一条边,边权是 \(w\)

\(n+1\) 行一个整数 \(k\)

输出格式

一行一个整数,表示答案。

样例 #1

样例输入 #1

7
1 6 13 
6 3 9 
3 5 7 
4 1 3 
2 4 20 
4 7 2 
10

样例输出 #1

5

提示

数据规模与约定

对于全部的测试点,保证:

  • \(1\leq n\leq 4\times 10^4\)
  • \(1\leq u,v\leq n\)
  • \(0\leq w\leq 10^3\)
  • \(0\leq k\leq 2\times 10^4\)

codes

#include<bits/stdc++.h>
using namespace std;
const int N=4e4+100;
const long long M=1e9+100;
struct edge{int y,x,n,z;}e[N<<1];
int lowbit(int x){return x&(-x);}
int n,m,cnt,head[N],q[N],top;
int siz[N],mxa[N],root,all,dis[N],c[N];
int ans=0;bool vis[N];
void upd(int x,int z){if(x<=0)return ;while(x<=m)c[x]+=z,x+=lowbit(x);}
int que(int x){if(x<=0)return 0;int sum=0;while(x)sum+=c[x],x-=lowbit(x);return sum;}
int po(int x){return que(x)-que(x-1);}
void ad(int x,int y,int z)
{
    e[++cnt].n=head[x];
    e[cnt].y=y;
    e[cnt].x=x;
    e[cnt].z=z;
    head[x]=cnt;
}
void init()
{
    scanf("%d",&n);
    for(int i=1,x,y,z;i<n;++i)
    {
        scanf("%d%d%d",&x,&y,&z);
        ad(x,y,z);ad(y,x,z);
    }
    scanf("%d",&m);
}
void getrt(int u,int fa)
{
    siz[u]=1;
    mxa[u]=0;
    for(int i=head[u];i;i=e[i].n)
    {
        int v=e[i].y;
        if(v==fa || vis[v])continue;
        getrt(v,u);
        siz[u]+=siz[v];
        mxa[u]=max(mxa[u],siz[v]);
    }
    mxa[u]=max(mxa[u],all-siz[u]);
    if(mxa[u]<mxa[root])root=u;
}

void getdis(int u,int fa)
{
    if(dis[u]<=m)
    {
        ++ans;
        ans+=que(m-dis[u]);
    }
    for(int i=head[u];i;i=e[i].n)
    {
        int v=e[i].y;
        if(vis[v] || v==fa)continue;
        dis[v]=dis[u]+e[i].z;
        getdis(v,u);
    }
}
void ch(int u,int fa,int z)
{
    if(dis[u]<=m)
        upd(dis[u],z);
    for(int i=head[u];i;i=e[i].n)
    {
        int v=e[i].y;
        if(vis[v] || v==fa)continue;
        ch(v,u,z);
    }
}
void calc(int u)
{
    int num=0;

    for(int i=head[u];i;i=e[i].n)
    {
        int v=e[i].y;
        if(vis[v])continue;
        dis[v]=e[i].z;
        getdis(v,0);
        ch(v,u,1);
    }

    ch(u,0,-1);
}

void solve(int nw)
{
    vis[nw]=1;
    dis[nw]=0;
    calc(nw);
    for(int i=head[nw];i;i=e[i].n)
    {
        int v=e[i].y;
        if(vis[v])continue;
        all=siz[v];
        root=0;
        mxa[root]=M;
        getrt(v,0);
        solve(root);
    }
}

void work()
{
    all=n;
    mxa[root]=M;
    getrt(1,0);
    solve(root);
    cout<<ans;
}

int main()
{

    init();
    work();
    return 0;
}








posted @ 2024-11-05 21:08  Glowingfire  阅读(4)  评论(3编辑  收藏  举报