NOIP2016 天天爱跑步 TarjanLCA+树上差分

题目描述
题目

这题的差分和一般的树上差分写法差好远,参考了dalao的题解还磨了好久才写出来

主要要注意的有以下几点:
1.起点s和终点t千万不要弄错(被它卡了半天的我QAQ)
2.记深度为d的起点的总数为cnt[d]:对于一条向上走的路,在起点处cnt[d]++,搜到终点的时候cnt[d]–;向下走的路,终点处cnt[d]++,起点处cnt[d]–

给这道题的细节处理跪了ORZ,磨了三天才终于A了
代码

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
const int N=300010, M=N<<1;
int n, m, w[N], rt;

int ne, he[N], nq, hq[N];
struct E {int to, next;} e[M];
void build(int u, int v) {e[ne]=(E){v,he[u]}; he[u]=ne++; e[ne]=(E){u,he[v]}; he[v]=ne++;}
struct Q {int to, next, flag;} q[M];
void add(int u, int v) {q[nq]=(Q){v,hq[u],0}; hq[u]=nq++; q[nq]=(Q){u,hq[v],0}; hq[v]=nq++;}

vector< int > upS[N],upT[N],downS[N],downT[N];
//upS表示向上走的路的起点
int cntr,rlen[M],prelen[M];
int f[N],vis[N],dep[N],lca[M],S[M],T[M];
int find(int v) {return v == f[v] ? v : f[v]=find(f[v]);}
void tarjan(int u,int fa)
{
    dep[u]=dep[fa]+1; vis[u]=1; f[u]=u; int v;
    for(int i=he[u]; i != -1; i=e[i].next)
    {
        if((v=e[i].to) == fa) continue;
        tarjan(v,u); f[v]=u;
    }
    for(int i=hq[u]; i != -1; i=q[i].next)
    {
        if(!vis[v=q[i].to] || q[i].flag) continue;
        q[i].flag=q[i^1].flag=1;cntr++;
        int m=find(v), s, t;
        if(i&1) s=v,t=u;else s=u,t=v;
        if(m == s)
        {
            S[cntr]=s;T[cntr]=t;rlen[cntr]=dep[t]-dep[s];
            downS[s].push_back(cntr);downT[t].push_back(cntr);
        }
        else if(m == t)
        {
            S[cntr]=s;T[cntr]=t;rlen[cntr]=dep[s]-dep[t];
            upS[s].push_back(cntr);upT[t].push_back(cntr);
        }
        else
        {
            lca[cntr]=m;
            S[cntr]=s;T[cntr]=m;rlen[cntr]=dep[s]-dep[m];
            upS[s].push_back(cntr);upT[m].push_back(cntr);
            prelen[++cntr]=dep[s]-dep[m];
            S[cntr]=m;T[cntr]=t;rlen[cntr]=dep[t]-dep[m];
            downS[m].push_back(cntr);downT[t].push_back(cntr);
        }
    }
}

int ans[N],cnt1[M],cnt2[M];

void pushup(int u,int fa)
{
    int dep1=dep[u]+w[u]+N,ori1=cnt1[dep1],dep2=dep[u]-w[u]+N,ori2=cnt2[dep2],now,v;
    for(unsigned int i=0; i < upS[u].size(); i++) 
        now=upS[u][i],cnt1[dep[S[now]]+N]++;
    for(unsigned int i=0; i < downT[u].size(); i++)
        now=downT[u][i],cnt2[dep[T[now]]-rlen[now]-prelen[now]+N]++;
    for(int i=he[u]; i != -1; i=e[i].next)
        if((v=e[i].to) != fa) 
            pushup(v,u);

    ans[u]=cnt1[dep1]-ori1+cnt2[dep2]-ori2;

    for(unsigned int i=0; i < upT[u].size(); i++)
    {
        now=upT[u][i];
        cnt1[dep[S[now]]+N]--;
        if(lca[now] == u && dep[S[now]]+N == dep1) ans[u]--;
    }
    for(unsigned int i=0; i < downS[u].size(); i++)
        now=downS[u][i],cnt2[dep[T[now]]-rlen[now]-prelen[now]+N]--;
}

int siz[N], mind=N;
void dfs(int u,int fa)
{
    int v, minn=N, maxn=-N;siz[u]=1;
    for(int i=he[u]; i != -1; i=e[i].next)
    {
        if((v=e[i].to) == fa) continue;
        dfs(v,u); siz[u]+=siz[v];
        if(minn > siz[v]) minn=siz[v];
    }
    if(minn == N) return ;
    if(n-siz[u] < minn && fa) minn=n-siz[u];
    if(maxn < n-siz[u]) maxn=n-siz[u];
    if(mind > maxn-minn) mind=maxn-minn,rt=u;
}

void solve()
{
    dfs(1,0);
    tarjan(rt,0);
    pushup(rt,0);
    for(int i=1;i<=n;i++) printf("%d ",ans[i]);
}

int read(){
    int out=0; char c=getchar(); while(c < '0' || c > '9') c=getchar();
    while(c >= '0' && c <= '9') out=(out<<1)+(out<<3)+c-'0',c=getchar(); return out;
} 

void init()
{
    memset(he, -1, sizeof(he)); memset(hq, -1, sizeof(hq));
    n=read(), m=read(); int u, v;
    for(int i=1;i<n;i++) u=read(), v=read(), build(u,v);
    for(int i=1;i<=n;i++) w[i]=read();
    for(int i=1;i<=m;i++) u=read(), v=read(), add(u,v);
}

int main()
{
    init();solve();
    return 0;
}
posted @ 2017-12-15 21:09  zerolt  阅读(130)  评论(0编辑  收藏  举报