[BZOJ 4719] 天天爱跑步

Link:

BZOJ 4719 传送门

Solution:

感觉求LCA又有了新姿势啊:$Tarjan$离线$O(n+m)$

每次递归返回时将子树和父节点合并,如果询问节点已访问过则LCA就是已合并的最高节点

 

这题部分分提示非常多啊

首先要将路径拆为$(S,LCA),(LCA,T)$

发现如果$(S,LCA)$能对点$x$产生贡献要满足$w[x]+dep[x]=dep[S]$

而$(LCA,T)$能对点$x$产生贡献要满足$dep[x]-w[x]=dep[T]-len$

这样用$cnt$数组维护等式右边的$dep[S]$和$dep[T]-len$的值有多少个就能快速得出有几条路径满足条件

于是可以在路径起点加入该路径特征值并在路径末尾将其消除即可

 

注意:

1、$LCA$处可能算了两遍,最后要逐一判断

2、要在刚进入该点时记录当前$cnt[w[x]+dep[x]]$的值否则可能会将其它子树中未走完的路径计算在内

3、此题需要从下往上统计答案,因此路径起点都要设置为深度较大的,否则不好消除不经过该点路径的贡献

Code:

#include <bits/stdc++.h>

using namespace std;
#define X first
#define Y second
#define pb push_back
typedef double db;
typedef long long ll;
typedef pair<int,int> P;
const int MAXN=1e6+10,ADD=3e5;
int vis[MAXN],f[MAXN],st[MAXN];
int n,q,x,y,w[MAXN],res[MAXN],cnt[MAXN],head[MAXN],dep[MAXN],tot;

vector<P> par[MAXN];
vector<int> in[MAXN],out[MAXN];
struct edge{int nxt,to;}e[MAXN<<2];
struct Query{int x,y,lca;}qry[MAXN];

void add(int x,int y)
{e[++tot]=(edge){head[x],y};head[x]=tot;}
int find(int x)
{return f[x]==x?x:f[x]=find(f[x]);}
void tarjan(int x,int anc)
{
    vis[x]=1;f[x]=x;
    for(int i=0;i<par[x].size();i++)
        if(vis[par[x][i].X]) qry[par[x][i].Y].lca=find(par[x][i].X);
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=anc)
        {
            dep[e[i].to]=dep[x]+1;
            tarjan(e[i].to,x);f[e[i].to]=x;
        }
}
void dfs1(int x,int anc)
{
    int cur=cnt[w[x]+dep[x]];
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=anc) dfs1(e[i].to,x);
    cnt[dep[x]]+=st[x];
    res[x]+=cnt[w[x]+dep[x]]-cur;
    for(int i=0;i<out[x].size();i++) cnt[out[x][i]]--;
}
void dfs2(int x,int anc)
{
    int cur=cnt[ADD-w[x]+dep[x]];
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=anc) dfs2(e[i].to,x);
    //都要看成从底向上的路径 
    for(int i=0;i<in[x].size();i++) cnt[in[x][i]]++;
    res[x]+=cnt[ADD-w[x]+dep[x]]-cur;
    for(int i=0;i<out[x].size();i++) cnt[out[x][i]]--;
}

int main()
{
    scanf("%d%d",&n,&q);
    for(int i=1;i<n;i++)
        scanf("%d%d",&x,&y),add(x,y),add(y,x);
    for(int i=1;i<=n;i++) scanf("%d",&w[i]);
    for(int i=1;i<=q;i++) 
    {
        scanf("%d%d",&x,&y);
        qry[i].x=x,qry[i].y=y;
        par[x].pb(P(y,i));par[y].pb(P(x,i));
    }
    tarjan(1,0);
    
    for(int i=1;i<=q;i++)
        out[qry[i].lca].pb(dep[qry[i].x]),st[qry[i].x]++;
    dfs1(1,0);
    memset(cnt,0,sizeof(cnt));
    for(int i=1;i<=n;i++) out[i].clear();
    for(int i=1;i<=q;i++)
    {
        int len=dep[qry[i].x]+dep[qry[i].y]-2*dep[qry[i].lca];
        in[qry[i].y].pb(ADD+dep[qry[i].y]-len);
        out[qry[i].lca].pb(ADD+dep[qry[i].y]-len);
    }
    dfs2(1,0);
    for(int i=1;i<=q;i++)
        if(dep[qry[i].x]-dep[qry[i].lca]==w[qry[i].lca]) 
            res[qry[i].lca]--;
    for(int i=1;i<=n;i++)
        printf("%d ",res[i]);
    return 0;
}

 

posted @ 2018-09-28 14:09  NewErA  阅读(140)  评论(0编辑  收藏  举报