[模板]点分治

洛谷 P3806

code
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int maxn = 1e5 + 7;   // vertex capacity; the edge array holds 2 * maxn directed edges
const int inf = 1e7;        // distance capacity of judge[]; queries presumably satisfy k <= 1e7 (TODO confirm vs. problem limits)
const int maxq = 1010;      // query capacity: queries are stored 1-based, so m must be < maxq

int n, m;                   // vertex count, query count
int mx[maxn];               // findroot: size of the largest piece left if the vertex were removed
int siz[maxn];              // findroot: subtree size in the most recent findroot DFS
int dis[maxn];              // distance from the current centroid (filled by calc/getdis)
int rem[maxn];              // rem[0] = count, rem[1..rem[0]] = distances collected by getdis
int qu[maxq];               // query distances, 1-based
int sum, root, ans;         // component size for findroot, best centroid so far; ans is unused in the visible code
int q[maxn];                // distances marked in judge[] by calc, remembered so they can be cleared
bool vis[maxn];             // true once the vertex has been used as a centroid
bool test[maxq];            // test[k]: a path of length qu[k] was found (indexed by query number only)
bool judge[inf + 1];        // judge[d]: some vertex at distance d from the current centroid has been recorded

// Reads the next decimal integer from stdin.
// Every non-digit character before the first digit is skipped; a '-' seen while
// skipping makes the result negative. Reading stops at the first non-digit after
// the number (that character is consumed).
// Returns 0 if end of input is reached before any digit, instead of looping forever.
inline int read()
{
    int x = 0, f = 1;
    int ch = getchar();   // int, not char: must be able to hold EOF distinctly from any byte
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-')
        {
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// One directed edge in a forward-star adjacency list: edges leaving the same
// vertex are chained through `next`, starting from head[source]; index 0 ends a chain.
struct node
{
    int next;   // index of the previously added edge from the same source (0 = end of chain)
    int to;     // destination vertex
    int w;      // edge weight
};
node a[maxn << 1];      // edge pool; slot 0 is unused so that 0 can act as "no edge"
int head[maxn], len;    // head[v]: most recently added edge out of v; len: edges used so far

// Appends the directed edge x -> y with weight w and makes it the new head of x's chain.
void add(int x, int y, int w)
{
    ++len;
    a[len] = node{head[x], y, w};
    head[x] = len;
}


// DFS over the unvisited component containing x (entered from fa), computing
// siz[] for every reached vertex and mx[] = size of the largest piece that would
// remain if that vertex were removed (children's subtrees, or the sum - siz[x]
// part "above" it). The vertex with the smallest mx[] found so far is kept in root.
void findroot(int x, int fa)
{
    siz[x] = 1;
    int heaviest = 0;
    for (int e = head[x]; e != 0; e = a[e].next)
    {
        const int to = a[e].to;
        if (to != fa && !vis[to])
        {
            findroot(to, x);
            siz[x] += siz[to];
            if (siz[to] > heaviest) heaviest = siz[to];
        }
    }
    const int rest = sum - siz[x];
    mx[x] = (heaviest > rest) ? heaviest : rest;
    if (mx[x] < mx[root]) root = x;
}

// DFS from u (entered from fa) over vertices not yet used as centroids.
// Appends dis[u] to rem[1..rem[0]] and propagates dis[child] = dis[u] + edge weight.
void getdis(int u, int fa)
{
    const int slot = ++rem[0];
    rem[slot] = dis[u];
    for (int e = head[u]; e != 0; e = a[e].next)
    {
        const int child = a[e].to;
        if (child != fa && !vis[child])
        {
            dis[child] = dis[u] + a[e].w;
            getdis(child, u);
        }
    }
}

// Answers queries for paths that pass through the centroid x.
// For each unvisited child subtree: getdis collects the distances from x to its
// vertices into rem[]. Each distance is first paired with distances recorded in
// judge[] by earlier subtrees (judge[0] stands for x itself and is set by solve),
// and only then recorded. Finally, every judge[] entry set here is cleared again.
void calc(int x)
{
    // Real number of slots in judge[]. Distances are bounded only by the sum of
    // edge weights, so any index at or beyond cap would read or write out of bounds.
    const int cap = static_cast<int>(sizeof(judge) / sizeof(judge[0]));
    int p = 0;
    for(int i=head[x]; i; i=a[i].next)
    {
        int v = a[i].to;
        if(vis[v]) continue;
        rem[0] = 0; dis[v] = a[i].w;
        getdis(v, x);
        // Match this subtree's distances against those of previously processed subtrees.
        for(int j=rem[0]; j; j--)
        {
            for(int k=1; k<=m; k++)
            {
                const int need = qu[k] - rem[j];
                if(need >= 0 && need < cap)
                {
                    test[k] |= judge[need];
                }
            }
        }
        // Record only after querying, so two vertices of the same subtree (whose
        // path does not pass through x) are never paired.
        // NOTE(review): distances >= cap cannot be stored. This loses nothing as long as
        // every query is < cap and edge weights are positive (presumably true for P3806).
        for(int j=rem[0]; j; j--)
        {
            if(rem[j] < cap)
            {
                q[++p] = rem[j]; judge[rem[j]] = 1;
            }
        }
    }
    for(int i=1; i<=p; i++) judge[q[i]] = 0;
}

// Centroid-decomposition driver: marks x as used, answers the paths through x
// (calc), then, for each remaining component next to x, locates that component's
// centroid with findroot and recurses into it.
void solve(int x)
{
    judge[0] = 1;   // distance 0: a path may end at the centroid itself
    vis[x] = 1;
    calc(x);
    for (int e = head[x]; e != 0; e = a[e].next)
    {
        const int child = a[e].to;
        if (!vis[child])
        {
            // siz[child] comes from the last findroot DFS and serves as the component size.
            sum = siz[child];
            root = 0;
            findroot(child, 0);
            solve(root);
        }
    }
}

// Input: n and m, then n - 1 undirected weighted edges "x y w", then m query
// distances. For each query, prints AYE if some path has exactly that length, else NAY.
int main()
{
    n = read();
    m = read();
    for (int edges = n - 1; edges > 0; --edges)
    {
        const int x = read();
        const int y = read();
        const int w = read();
        add(x, y, w);   // store both directions of the undirected edge
        add(y, x, w);
    }
    for (int k = 1; k <= m; ++k)
        qu[k] = read();

    // Vertex 0 is the sentinel "no root yet": every real vertex has mx < n, so the
    // first vertex findroot finishes replaces it.
    sum = n;
    mx[0] = n;
    findroot(1, 0);
    solve(root);

    for (int k = 1; k <= m; ++k)
        puts(test[k] ? "AYE" : "NAY");

    return 0;
}

 

posted @ 2022-10-11 17:47  Catherine_leah  阅读(19)  评论(1)  编辑  收藏  举报
/* */