thr [树链剖分+dp]
题面
思路
首先,可以有一个$dp$的思路
不难发现本题中,三个点如果互相距离相同,那么一定有一个“中心点”到三个点的距离都相同
那么我们可以把本题转化计算以每个点为根的情况下,从三个不同的子树中选择深度相同的三个点的方案数
进一步,我们选定1号点为根,这样定义我们的$dp$方程:
$f[u][dis]$表示以$u$为根的子树中,和$u$的距离为$dis$的点的数量
$g[u][dis]$表示以$u$为根的子树中已经选择了不属于同一个$u$的儿子子树的两个点,并且距离为dis的方案数
那么转移方程如下:
$f[u][i]=\sum_{v=son(u)}f[v][i-1]$
$g[u][i]=\sum_{v=son(u)}g[v][i+1]+\sum_{p=son(u)}\sum_{q=son(u),q\neq p} f[p][i-1]*f[q][i-1]$
这个过程非常方便$dp$,但是时间复杂度和空间复杂度都是$O(n^2)$的
然后我们惊喜地发现一件事情:
这里的$f,g$都可以先继承一个儿子的内容,然后再和别的儿子合并!
考虑我们继承的这个过程,也就是$u$从确定的继承者$v$那里获得信息的过程:
$f[u][i]=f[v][i-1]$
$g[u][i]=g[v][i+1]$
我们又惊喜地发现,如果我们一整条链都用这样的继承方式的话,那么长度为$n$的链只需要$O(n)$的空间!
具体而言,我们对这$n$个点开两个数组,$F,G$
此时,最下层的叶子最深,它的$f,g$数组的第二位只有0
那么$f[leaf][0]=F[1],g[leaf][0]=G[n]$
然后$f[fa[leaf]][0]=F[2],f[fa[leaf]][1]=F[1]$,$g$也类似
就这样一直下去
可以看到实际上$F[leaf]$保存了$n$个$dp$值,因为这$n$个是相等的
让后我们又又惊喜地发现,现在我们的$dp$,可以逐个子树做合并,每一次合并的复杂度是$maxdep(v)$
这样我们就得到了一个时间复杂度$O(n)$,空间复杂度$O(n)$的算法
PS.本题中其实什么剖分方法都是一样的qwq,都是$O(n)$
Summary
个人认为这道题是一道好题
其实,原本的$dp$思路并不难想,同时后面的做优化的原因其实都已经很好地植入在上一层的方法里面了
也就是说,这道题如果你的思维开阔,是很容易顺着思路一路就从$O(n2)+O(n2)$优化到$O(n)+O(n)$的
本题中体现的主要思想,在于观察已有方法中的一些性质,并代入一些非常规手法(启发式合并和一个位置保存多个$dp$值)
这和之前我曾经遇到的一些题目不同
大部分题目是在旧方法里面找时间复杂度高的部分,然后以这些部分为主要攻关点来进行突破、优化
然而当我们遇到影响复杂度的地方并不集中的算法时,我们就需要打开思路,从算法本身性质入手,套用更优秀的方法来解题
这道题并不是一道“套路题”,不是只要做过类似的东西就能一下子秒掉的
现在的OI考场上,我们遇到的这类题目也一定会越来越多——这实际上意味着OI比赛题目质量的整体上升
因此我们一定要锻炼自己打开思路、广撒网的能力,不要老是一条路走到黑(就像我以前想题一样)
Code
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define fpos DEEP_DARK_FANTASY_1
#define gpos DEEP_DARK_FANTASY_2
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-') flag=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
int n,first[50010],cnte;
struct edge{
int to,next;
}a[100010];
inline void add(int u,int v){
a[++cnte]=(edge){v,first[u]};first[u]=cnte;
a[++cnte]=(edge){u,first[v]};first[v]=cnte;
}
int dep[50010],maxdep[50010],fpos[50010],gpos[50010],cntf,cntg,son[50010],siz[50010],top[50010];
ll f[200010],g[200010];
int getf(int u,int i){
return fpos[top[u]]+dep[u]-dep[top[u]]+i;
}
int getg(int u,int i){
return gpos[top[u]]-dep[u]+dep[top[u]]+i;
}
int getlen(int u){
return maxdep[u]-dep[u]+1;
}
int fa[50010];
void dfs1(int u,int f){
int i,v;fa[u]=f;
dep[u]=dep[f]+1;maxdep[u]=dep[u];
siz[u]=1;son[u]=0;
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(v==f) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
maxdep[u]=max(maxdep[u],maxdep[v]+1);
}
}
void dfs2(int u,int t){
int i,v;
top[u]=t;
if(u==t){
int len=getlen(u);
fpos[u]=cntf;//这里记录的是链头,每个链头都分配一段数组,长度等于这条链的长度
gpos[u]=cntg+len;//后来发现本题中什么剖法都是O(n),所以我写了重链剖分
cntf+=len;
cntg+=len<<1;
}
if(!son[u]) return;
dfs2(son[u],t);
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(v==son[u]||v==fa[u]) continue;
dfs2(v,v);
}
}
ll ans=0;
void dfs(int u){
f[getf(u,0)]=1;
if(son[u]){
dfs(son[u]);
ans+=g[getg(u,0)];
}
int i,v,j,lim;
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(v==fa[u]||v==son[u]) continue;
dfs(v);lim=getlen(v);
for(j=1;j<=lim;j++) ans+=g[getg(u,j)]*f[getf(v,j-1)];
for(j=1;j<=lim-1;j++) ans+=g[getg(v,j)]*f[getf(u,j-1)];
for(j=1;j<=lim;j++) g[getg(u,j)]+=f[getf(u,j)]*f[getf(v,j-1)];
for(j=0;j<=lim-2;j++) g[getg(u,j)]+=g[getg(v,j+1)];
for(j=1;j<=lim;j++) f[getf(u,j)]+=f[getf(v,j-1)];
}
}
void init(){
memset(first,-1,sizeof(first));
cnte=ans=0;cntf=cntg=1;
memset(f,0,sizeof(f));memset(g,0,sizeof(g));
memset(dep,0,sizeof(dep));memset(maxdep,0,sizeof(maxdep));
}
int main(){
int i,t1,t2;
while((n=read())){
init();
for(i=1;i<n;i++){
t1=read();t2=read();
add(t1,t2);
}
dfs1(1,0);dfs2(1,1);
dfs(1);
printf("%lld\n",ans);
}
}