【UOJ】树上gcd
点分治
这道题还有很多种其它写法,什么长链剖分啦,启发式合并啦等等。
首先,我们可以把点对\((u,v)\)分成两类:
1.u到v的路径是一条链
2.u到v的路径不是一条链(废话)
对于第一类,显然\(f(u,v)\)就是链的长度,可以单独统计
对于第二类,就要在点分治上搞了
我们可以先计算出为d的倍数的点对数,最后容斥一下即可
在点分治中,我们取出当前子树的重心root,统计路径经过root的点对,那么又可以分成两类:
A.u和v都在root的子树内
B.u和v一个在root的子树内,另一个不在
对于A类:
\(num[x]\)表示以root的儿子为根的子树内深度(相对于root的深度,下同)为x的点个数
\(Tnum[x]\)表示以root为根的子树内深度为x的点个数
\(sum[x]\)表示以root的儿子为根的子树内深度为x的倍数的点个数
\(Tsum[x]\)表示以root为根的子树内深度为x的倍数的点个数
那么我们暴力遍历root的儿子:
void DFS_Deep(int x,int fa,int dep,int &max_dep){//max_dep表示当前子树的最大深度,方便更新和清空
max_dep=max(max_dep,dep);
num[dep]++;
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t]||t==fa)continue;
DFS_Deep(t,x,dep+1,max_dep);
}
}
然后我们求出\(sum[x]\),并把\(num[x]\)和\(sum[x]\)加到\(Tnum[x]\)和\(Tsum[x]\),同时更新答案。
这部分整体代码:
void Update_Ans1(int x,int fa){
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t]||t==fa)continue;
int tmp=0;//tmp表示以当前儿子为根的子树的最大深度
DFS_Deep(t,x,1,tmp);//遍历
Maxd=max(Maxd,tmp);//更新总的最大深度
for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];//求出sum
for(int j=1;j<=Maxd;j++)ans1[j]+=1LL*sum[j]*Tsum[j];//更新答案
for(int j=1;j<=tmp;j++)Tsum[j]+=sum[j],Tnum[j]+=num[j];//加进去
for(int j=1;j<=tmp;j++)sum[j]=num[j]=0;//清空
}
}
对于B类:
这就有点小麻烦了。。。
我们设当前子树的根节点为g,那我们遍历\(pre[root]\)到g的路径,对于每个点i,我们可以向上面一样求出以i为根的子树内\(num[x]\)和\(sum[x]\)的值:
int tmp=0;
for(int i=0;i<G[x].size();i++){//我拿x来代指i的
int t=G[x][i];
if(vis[t]||t==pre[x]||t==la)continue;//la表示这条路径是由la上去的,这也不能遍历下去
DFS_Deep(t,x,1,tmp);
Maxd=max(Maxd,tmp);
}
for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
然后,我们枚举d,我们需要在root的子树中找深度间隔为d的点个数和,由于root到i还是有距离的,所以并不是从root点直接开始找。不过这样只有d种情况,我们可以记忆化,但是记忆化空间是开不下的,所以对于\(d\le \sqrt n\)我们就记忆化处理,而对于\(d>\sqrt n\)我们就直接找(难以口胡,实在不行就看代码吧)
这部分整体代码:
void Update_Ans2(int x,int la,int count){//count表示root到x的距离(x代指i)
int tmp=0;
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t]||t==pre[x]||t==la)continue;
DFS_Deep(t,x,1,tmp);
Maxd=max(Maxd,tmp);
}
for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
int limit=min(tmp,Up);//Up表示sqrt(n)
for(int j=1;j<=limit;j++){//小于sqrt(n)记忆化
if(dp[j][count%j]==-1){//dp用来记忆化
dp[j][count%j]=0;
for(int k=(j-count%j)%j;k<=Maxd;k+=j)dp[j][count%j]+=Tnum[k];
}
ans1[j]+=1LL*dp[j][count%j]*sum[j];//更新答案
}
for(int j=limit+1;j<=tmp;j++){//大于sqrt(n)直接枚举
ll res=0;
for(int k=(j-count%j)%j;k<=Maxd;k+=j)res+=Tnum[k];
ans1[j]+=1LL*res*sum[j];//更新答案
}
for(int i=1;i<=tmp;i++)sum[i]=num[i]=0;//清空
}
全部代码:
#include<bits/stdc++.h>
#define ll long long
#define MAXN 200010
using namespace std;
int n,pre[MAXN],size[MAXN],W[MAXN],root,num[MAXN],Tnum[MAXN],sum[MAXN],Tsum[MAXN],Maxd,dp[1010][1010],Up,SIZE,deep[MAXN];
ll ans1[MAXN],ans2[MAXN];
vector<int> G[MAXN];
bool vis[MAXN];
void Getroot(int x,int fa){
size[x]=1,W[x]=0;
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(t==fa||vis[t])continue;
Getroot(t,x);
size[x]+=size[t];
W[x]=max(W[x],size[t]);
}
W[x]=max(W[x],SIZE-size[x]);
if(W[x]<W[root])root=x;
}
void DFS_Deep(int x,int fa,int dep,int &max_dep){
max_dep=max(max_dep,dep);
num[dep]++;
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t]||t==fa)continue;
DFS_Deep(t,x,dep+1,max_dep);
}
}
void Update_Ans1(int x,int fa){
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t]||t==fa)continue;
int tmp=0;
DFS_Deep(t,x,1,tmp);
Maxd=max(Maxd,tmp);
for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
for(int j=1;j<=Maxd;j++)ans1[j]+=1LL*sum[j]*Tsum[j];
for(int j=1;j<=tmp;j++)Tsum[j]+=sum[j],Tnum[j]+=num[j];
for(int j=1;j<=tmp;j++)sum[j]=num[j]=0;
}
}
void Update_Ans2(int x,int la,int count){
int tmp=0;
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t]||t==pre[x]||t==la)continue;
DFS_Deep(t,x,1,tmp);
Maxd=max(Maxd,tmp);
}
for(int j=1;j<=tmp;j++)for(int k=j;k<=tmp;k+=j)sum[j]+=num[k];
int limit=min(tmp,Up);
for(int j=1;j<=limit;j++){
if(dp[j][count%j]==-1){
dp[j][count%j]=0;
for(int k=(j-count%j)%j;k<=Maxd;k+=j)dp[j][count%j]+=Tnum[k];
}
ans1[j]+=1LL*dp[j][count%j]*sum[j];
}
for(int j=limit+1;j<=tmp;j++){
ll res=0;
for(int k=(j-count%j)%j;k<=Maxd;k+=j)res+=Tnum[k];
ans1[j]+=1LL*res*sum[j];
}
for(int i=1;i<=tmp;i++)sum[i]=num[i]=0;
}
void DFS_Point(int x){
Maxd=0;
vis[x]=true;
Update_Ans1(x,pre[x]);
Tnum[0]=1;
int count=0;
for(int i=x;;i=pre[i]){
if(vis[pre[i]]||pre[i]==0)break;
count++;
Update_Ans2(pre[i],i,count);
}
for(int i=0;i<=Maxd;i++)Tnum[i]=Tsum[i]=0;
int limit=min(Maxd,Up);
for(int i=1;i<=limit;i++){
for(int j=0;j<=i-1;j++)dp[i][j]=-1;
}
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(vis[t])continue;
root=0,SIZE=size[t];
Getroot(t,0);
DFS_Point(root);
}
}
int main(){
W[0]=2e9+7;
memset(dp,-1,sizeof(dp));
scanf("%d",&n);
SIZE=n;
Up=sqrt(n);
for(int i=2;i<=n;i++)scanf("%d",&pre[i]),G[pre[i]].push_back(i),G[i].push_back(pre[i]);
Getroot(1,0);
DFS_Point(root);
//统计路径是一条链的点对
for(int i=1;i<=n;i++)deep[i]=deep[pre[i]]+1,++ans2[deep[i]-1];
for(int i=n-1;i>=1;i--)ans2[i]+=ans2[i+1];
for(int i=n-1;i>=1;i--){//容斥
for(int j=i+i;j<=n-1;j+=i)ans1[i]-=ans1[j];
}
for(int i=1;i<=n-1;i++)printf("%lld\n",ans1[i]+ans2[i]);
return 0;
}