树上路径问题 LCA+思维难题

P5588 小猪佩奇爬树:

P5588 小猪佩奇爬树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

v用来存储各个颜色的节点

一.v[i].size()=0时,不再赘述

二.v[i].size()=1时,

此时便是把 i 这个节点看成根节点,求他子节点所有两两之乘

显然把所有i进行dfs会超时O(n^2)

有一个巧妙的方法

void dfs(int x,int f)
{
    size[x]=1;
    dep[x]=dep[f]+1;
    fa[x][0]=f;
    for(int i=1;i<=lg[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=lin[x];i;i=e[i].next)
    {
        int v=e[i].y;
        if(v==f) continue;
        dfs(v,x);
        ans1[x]+=size[x]*size[v]; 这步还是很妙的
        size[x]+=size[v];
    }
    ans1[x]+=size[x]*(n-size[x]); 最后这步别忘了
}ans1[x]便是v[x].size()=1时的情况了

 

三.v[i].size>=2时

 此时,所有的点都应该在一条路径上,否则无解

这是便有两个问题:如何判断所有点在一条路径,又如何获得两个端点是哪两个

别急,此时有在一条路径上有两种情况: 1.一条链组成 2.两条链组成

1.所有的点都在一条链上
 

 

 此时 3,4,5这三个相同颜色的点在一条链上

那么如何判断这些点在所有链上:
首先我们把这些点存到v[i]上,接着按照深度从深到浅排序,那么此时v[i]={5,4,3};

那么链上的时候两个端点分别是最深的那个点和最浅的那个点

我们把最深的点拿出来,如果LCA(其他点,最深的点)=其他这个点,那么就能保证所有点在一条链上(这步还是很妙的)

int l=v[i][0],r,flag1=1,flag2=1;//l便是最深的点,flag1表示是不是链的状态,flag2表示其他在一条路上的状态
for(int j=1;j<v[i].size();j++)
        if(LCA(l,v[i][j])!=v[i][j]) { r=v[i][j];flag1=0;break;} //判断出不是链的状态

此外这题还有很细节的地方,比如统计时r时(另一个端点,在链中是链里最上面的点),

此时不是(n-size[r]+1)而是(n-size[Son(r)])

如图符合的点是1,2,3,9.不是1,2,3

那门如何求Son(r)呢,我们暴力一点,直接从 l最深的点开始往上推,直到fa[][0]== break掉

if(flag1)
{
    for(r=v[i][0];fa[r][0]!=v[i].back();r=fa[r][0]); //找最上面的点的儿子点
    ans=size[l]*(n-size[r]); //统计答案
}
2.所有的点在一条路径上,但不在一条链上

 

这种情况,这条路其实是两条链,而两个端点是两条链的最深的点

 我们先回来看判断链的代码


int l=v[i][0],r,flag1=1,flag2=1;//l便是最深的点,flag1表示是不是链的状态,flag2表示其他在一条路上的状态
for(int j=1;j<v[i].size();j++)
        if(LCA(l,v[i][j])!=v[i][j]) { r=v[i][j];flag1=0;break;} //判断出不是链的状态

如何不是一条链,我们用 r 记录第一个不在这个链的点

由于已经由深到浅排序,那么这个r点就是另一条链的最深的点

接下来就是判断是否是这种情况,由于其他点是在这两条链上的,我们还是用上面的LCA判断

else
{
    for(int j=1;j<v[i].size();j++)
      if(LCA(l,v[i][j])!=v[i][j]&&LCA(r,v[i][j])!=v[i][j])  {flag2=0;break; } //这个点都不在这两条链上,那么就不是一条路径
   if(flag2) ans=size[l]*size[r];} //这种情况的答案就很好统计了
}

那么这题就做完了

Code:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define mp make_pair
#define pb push_back   //vector函数
#define popb pop_back  //vector函数
#define fi first
#define se second
const int N=1e6+10;
//const int M=;
//const int inf=0x3f3f3f3f;     //一般为int赋最大值,不用于memset中
//const ll INF=0x3ffffffffffff; //一般为ll赋最大值,不用于memset中
int n,m,len=0,root,dep[N],lg[N],lin[N],size[N],num[N],fa[N][20];
int l[N],r1[N],r2[N];
ll ans1[N];
bool vis[N];
struct edge{ int next,y; }e[N<<1];
vector<int> v[N];
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
void insert(int xx,int yy)
{
    e[++len].next=lin[xx];
    lin[xx]=len;
    e[len].y=yy;
}
void dfs(int x,int f)
{
    size[x]=1;
    dep[x]=dep[f]+1;
    fa[x][0]=f;
    for(int i=1;i<=lg[dep[x]];i++) fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=lin[x];i;i=e[i].next)
    {
        int v=e[i].y;
        if(v==f) continue;
        dfs(v,x);
        ans1[x]+=size[x]*size[v];
        size[x]+=size[v];
    }
    ans1[x]+=size[x]*(n-size[x]);
}
int LCA(int x,int y)
{
    if(dep[x]<dep[y]) swap(x,y);
    while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]];
    if(x==y) return x;
    for(int i=lg[dep[x]];i>=0;i--)
        if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; 
    return fa[x][0];
}
int main()
{
//    freopen("","r",stdin);
//    freopen("","w",stdout);
    n=read();
    for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1; 
    for(int i=1;i<=n;i++)
    {
        int col=read();
        v[col].pb(i);
    }
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        insert(x,y);insert(y,x);
    }
    dfs(1,0);
    for(int i=1;i<=n;i++)
    {
        ll ans=0;
        if(!v[i].size()) ans=1LL*n*(n-1)/2;
        else if(v[i].size()==1) ans=ans1[v[i][0]];
        else
        {
            reverse(v[i].begin(),v[i].end());
            int l=v[i][0],r,flag1=1,flag2=1;
            for(int j=1;j<v[i].size();j++)
                if(LCA(l,v[i][j])!=v[i][j]) { r=v[i][j];flag1=0;break; }
            if(flag1)
            {
                for(r=v[i][0];fa[r][0]!=v[i].back();r=fa[r][0]);
                ans=size[l]*(n-size[r]);
            }
            else
            {
                for(int j=1;j<v[i].size();j++)
                    if(LCA(l,v[i][j])!=v[i][j]&&LCA(r,v[i][j])!=v[i][j]) { flag2=0;break; }
                if(flag2) ans=size[l]*size[r];
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}

 

posted @ 2023-01-09 14:06  QAQ啥也不会  阅读(23)  评论(0编辑  收藏  举报