【UOJ】树上gcd

点分治

这道题还有很多种其它写法,什么长链剖分啦,启发式合并啦等等。

首先,我们可以把点对\((u,v)\)分成两类:
1.u到v的路径是一条链
2.u到v的路径不是一条链(废话)
对于第一类,显然\(f(u,v)\)就是链的长度,可以单独统计

对于第二类,就要在点分治上搞了
我们可以先计算出为d的倍数的点对数,最后容斥一下即可
在点分治中,我们取出当前子树的重心root,统计路径经过root的点对,那么又可以分成两类:
A.u和v都在root的子树内
B.u和v一个在root的子树内,另一个不在

对于A类:
\(num[x]\)表示以root的儿子为根的子树内深度(相对于root的深度,下同)为x的点个数
\(Tnum[x]\)表示以root为根的子树内深度为x的点个数
\(sum[x]\)表示以root的儿子为根的子树内深度为x的倍数的点个数
\(Tsum[x]\)表示以root为根的子树内深度为x的倍数的点个数
那么我们暴力遍历root的儿子:

void DFS_Deep(int x,int fa,int dep,int &max_dep){//max_dep表示当前子树的最大深度,方便更新和清空
    max_dep=max(max_dep,dep);
    num[dep]++;
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t]||t==fa)continue;
        DFS_Deep(t,x,dep+1,max_dep);
    }
}

然后我们求出\(sum[x]\),并把\(num[x]\)\(sum[x]\)加到\(Tnum[x]\)\(Tsum[x]\),同时更新答案。
这部分整体代码:

void Update_Ans1(int x,int fa){
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t]||t==fa)continue;
        int tmp=0;//tmp表示以当前儿子为根的子树的最大深度
        DFS_Deep(t,x,1,tmp);//遍历
        Maxd=max(Maxd,tmp);//更新总的最大深度
        for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];//求出sum
        for(int j=1;j<=Maxd;j++)ans1[j]+=1LL*sum[j]*Tsum[j];//更新答案
        for(int j=1;j<=tmp;j++)Tsum[j]+=sum[j],Tnum[j]+=num[j];//加进去
        for(int j=1;j<=tmp;j++)sum[j]=num[j]=0;//清空
    }
}

对于B类:
这就有点麻烦了。。。
我们设当前子树的根节点为g,那我们遍历\(pre[root]\)到g的路径,对于每个点i,我们可以向上面一样求出以i为根的子树内\(num[x]\)\(sum[x]\)的值:

int tmp=0;
for(int i=0;i<G[x].size();i++){//我拿x来代指i的
    int t=G[x][i];
    if(vis[t]||t==pre[x]||t==la)continue;//la表示这条路径是由la上去的,这也不能遍历下去
    DFS_Deep(t,x,1,tmp);
    Maxd=max(Maxd,tmp);
}
for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];

然后,我们枚举d,我们需要在root的子树中找深度间隔为d的点个数和,由于root到i还是有距离的,所以并不是从root点直接开始找。不过这样只有d种情况,我们可以记忆化,但是记忆化空间是开不下的,所以对于\(d\le \sqrt n\)我们就记忆化处理,而对于\(d>\sqrt n\)我们就直接找(难以口胡,实在不行就看代码吧)
这部分整体代码:

void Update_Ans2(int x,int la,int count){//count表示root到x的距离(x代指i)
    int tmp=0;
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t]||t==pre[x]||t==la)continue;
        DFS_Deep(t,x,1,tmp);
        Maxd=max(Maxd,tmp);
    }
    for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
    int limit=min(tmp,Up);//Up表示sqrt(n)
    for(int j=1;j<=limit;j++){//小于sqrt(n)记忆化
        if(dp[j][count%j]==-1){//dp用来记忆化
            dp[j][count%j]=0;
            for(int k=(j-count%j)%j;k<=Maxd;k+=j)dp[j][count%j]+=Tnum[k];
        }
        ans1[j]+=1LL*dp[j][count%j]*sum[j];//更新答案
    }
    for(int j=limit+1;j<=tmp;j++){//大于sqrt(n)直接枚举
        ll res=0;
        for(int k=(j-count%j)%j;k<=Maxd;k+=j)res+=Tnum[k];
        ans1[j]+=1LL*res*sum[j];//更新答案
    }
    for(int i=1;i<=tmp;i++)sum[i]=num[i]=0;//清空
}

全部代码:

#include<bits/stdc++.h>
#define ll long long
#define MAXN 200010
using namespace std;
int n,pre[MAXN],size[MAXN],W[MAXN],root,num[MAXN],Tnum[MAXN],sum[MAXN],Tsum[MAXN],Maxd,dp[1010][1010],Up,SIZE,deep[MAXN];
ll ans1[MAXN],ans2[MAXN];
vector<int> G[MAXN];
bool vis[MAXN];
void Getroot(int x,int fa){
    size[x]=1,W[x]=0;
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(t==fa||vis[t])continue;
        Getroot(t,x);
        size[x]+=size[t];
        W[x]=max(W[x],size[t]);
    }
    W[x]=max(W[x],SIZE-size[x]);
    if(W[x]<W[root])root=x;
}
void DFS_Deep(int x,int fa,int dep,int &max_dep){
    max_dep=max(max_dep,dep);
    num[dep]++;
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t]||t==fa)continue;
        DFS_Deep(t,x,dep+1,max_dep);
    }
}
void Update_Ans1(int x,int fa){
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t]||t==fa)continue;
        int tmp=0;
        DFS_Deep(t,x,1,tmp);
        Maxd=max(Maxd,tmp);
        for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
        for(int j=1;j<=Maxd;j++)ans1[j]+=1LL*sum[j]*Tsum[j];
        for(int j=1;j<=tmp;j++)Tsum[j]+=sum[j],Tnum[j]+=num[j];
        for(int j=1;j<=tmp;j++)sum[j]=num[j]=0;
    }
}
void Update_Ans2(int x,int la,int count){
    int tmp=0;
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t]||t==pre[x]||t==la)continue;
        DFS_Deep(t,x,1,tmp);
        Maxd=max(Maxd,tmp);
    }
    for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
    int limit=min(tmp,Up);
    for(int j=1;j<=limit;j++){
        if(dp[j][count%j]==-1){
            dp[j][count%j]=0;
            for(int k=(j-count%j)%j;k<=Maxd;k+=j)dp[j][count%j]+=Tnum[k];
        }
        ans1[j]+=1LL*dp[j][count%j]*sum[j];
    }
    for(int j=limit+1;j<=tmp;j++){
        ll res=0;
        for(int k=(j-count%j)%j;k<=Maxd;k+=j)res+=Tnum[k];
        ans1[j]+=1LL*res*sum[j];
    }
    for(int i=1;i<=tmp;i++)sum[i]=num[i]=0;
}
void DFS_Point(int x){
    Maxd=0;
    vis[x]=true;
    Update_Ans1(x,pre[x]);
    Tnum[0]=1;
    int count=0;
    for(int i=x;;i=pre[i]){
        if(vis[pre[i]]||pre[i]==0)break;
        count++;
        Update_Ans2(pre[i],i,count);
    }
    for(int i=0;i<=Maxd;i++)Tnum[i]=Tsum[i]=0;
    int limit=min(Maxd,Up);
     for(int i=1;i<=limit;i++){
         for(int j=0;j<=i-1;j++)dp[i][j]=-1;
     }
    for(int i=0;i<G[x].size();i++){
        int t=G[x][i];
        if(vis[t])continue;
        root=0,SIZE=size[t];
        Getroot(t,0);
        DFS_Point(root);
    }
}
int main(){
    W[0]=2e9+7;
    memset(dp,-1,sizeof(dp));
    scanf("%d",&n);
    SIZE=n;
    Up=sqrt(n);
    for(int i=2;i<=n;i++)scanf("%d",&pre[i]),G[pre[i]].push_back(i),G[i].push_back(pre[i]);
    Getroot(1,0);
    DFS_Point(root);
    //统计路径是一条链的点对
    for(int i=1;i<=n;i++)deep[i]=deep[pre[i]]+1,++ans2[deep[i]-1];
    for(int i=n-1;i>=1;i--)ans2[i]+=ans2[i+1];
    for(int i=n-1;i>=1;i--){//容斥
        for(int j=i+i;j<=n-1;j+=i)ans1[i]-=ans1[j];
    }
    for(int i=1;i<=n-1;i++)printf("%lld\n",ans1[i]+ans2[i]);
    return 0;
}
posted @ 2019-08-09 18:46  TieT  阅读(369)  评论(0编辑  收藏  举报