洛谷 P4178 Tree

给你一棵树,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K

点分治

我们如果开桶记录路径长度的数,那么需要维护一个单点加和前缀和,用树状数组维护就行了

不过还有种排序双指针的方法,复杂度一样,懒得写了qwq

复杂度\(O(nlog^2n)\)

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
const int N = 4e4;
const int INF = 4e5;
using namespace std;
struct node
{
    int to,cost;
};
int n,k,rt,maxp[N + 5],size[N + 5],su,vis[N + 5],ans,c[INF + 5],s[N + 5],num,cnt,now[N + 5];
vector <node> d[N + 5];
int lowbit(int x)
{
    return x & (-x);
}
void add(int x,int s)
{
    for (;x <= k;x += lowbit(x))
        c[x] += s;
}
int query(int x)
{
    int ans = 0;
    for (;x;x -= lowbit(x))
        ans += c[x];
    return ans;
}
void get_rt(int u,int fa)
{
    maxp[u] = 0;
    size[u] = 1;
    vector <node>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it).to;
        if (v == fa || vis[v])
            continue;
        get_rt(v,u);
        size[u] += size[v];
        maxp[u] = max(maxp[u],size[v]);
    }
    maxp[u] = max(maxp[u],su - size[u]);
    if (maxp[u] < maxp[rt])
        rt = u;
}
void get_dis(int u,int fa,int s)
{
    if (s > k)
        return;
    now[++cnt] = s;
    vector <node>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it).to,w = (*it).cost;
        if (vis[v] || v == fa)
            continue;
        get_dis(v,u,s + w);
    }
}
void calc(int u)
{
    num = 0;
    vector <node>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it).to,w = (*it).cost;
        if (vis[v])
            continue;
        cnt = 0;
        get_dis(v,u,w);
        for (int i = 1;i <= cnt;i++)
            ans += query(k - now[i]) + 1;
        for (int i = 1;i <= cnt;i++)
        {
            s[++num] = now[i];
            add(now[i],1);
        }
    }
    for (int i = 1;i <= num;i++)
        add(s[i],-1);
}
void solve(int u)
{
    vis[u] = 1;
    calc(u);
    vector <node>::iterator it;
    for (it = d[u].begin();it != d[u].end();it++)
    {
        int v = (*it).to;
        if (vis[v]) 
            continue;
        su = size[v];
        rt = 0;
        maxp[0] = INF;
        get_rt(v,0);
        solve(rt);
    }
}
int main()
{
    scanf("%d",&n);
    int u,v,w;
    for (int i = 1;i < n;i++)
    {
        scanf("%d%d%d",&u,&v,&w);
        d[u].push_back((node){v,w});
        d[v].push_back((node){u,w});
    }
    scanf("%d",&k);
    maxp[0] = INF;
    su = n;
    get_rt(1,0);
    solve(rt);
    printf("%d\n",ans);
    return 0;
}
posted @ 2020-06-08 21:36  eee_hoho  阅读(71)  评论(0编辑  收藏  举报