【题解】[POI2014]HOT-Hotels 加强版
\(\text{Solution:}\)
我们先试着观察一下答案的形式,最初我猜是全部长成一个点在上两个点在子树内的,但这样显然漏掉了一些情况:有可能它不是祖先但是到达另外两个点的距离是相同的。
那么应该怎么处理这种问题呢?我们发现即使是这样的情况也一定可以作为一个点子树内的情况来做。所以我们联想到能不能对它 \(dp.\)
首先一个简单的 \(dp:\) 设 \(f_{i,j}\) 是子树 \(i\) 中距离 \(i\) 为 \(j\) 的节点个数。为了方便就直接让 \(1\) 当做根就好了。那么就会有显然的转移:
考虑接下去怎么做:对于一棵子树,我们一个常见的 \(dp\) 模型是和树形背包一样每次合并一棵子树。在这里可以吗?
可以。
设 \(g_{i,j}\) 表示 \(i\) 子树中,\((x,y),\) 满足它们到它们最近公共祖先的距离为 \(dis\) , 最近公共祖先到 \(i\) 的距离是 \(dis-j\) 的无序二元组对数。
我们发现这样的 \(dp\) 具有好的性质:从 \(g\) 转移只需要从之前 \(dp\) 过的子树中利用 \(f\) 找到合适的就行了。
那么它如何转移?
一种是直接继承其子节点的:
另一种是从合并后的树中新凑出来的:
发现这样不会算重。因为继承之前孩子信息的缘故,我们只需要维护在上一层不会被记录到的部分就行了。观察到 \(g\) 数组关于 \(j\) 实际上是倒序的,这也可以提示我们上一层 \(dp\) 不到的其实就是使得两个点的 \(lca\) 恰好是 \(i\) 的部分。
于是它们的值就被顺利递推出来了,那么,答案如何计算?
第一种情况是:\(i\) 本身作为三元组中的一个出现,这也对应着数组 \(g_{i,0}.\)
另外的情况:一种是从之前 \(dp\) 过的部分凑出两个点作为二元组,另一种是从新加入的部分凑出两个点作为二元组。
这样,我们又发现,这东西 \(dp\) 的复杂度是 \(O(n^2)\) 太大了,怎么办?
观察到数组的下标都是和深度有关的,所以我们可以上 长链剖分。
具体地:对于每一条长链分配一个共用的内存,因为 \(dp\) 数组中有一个重要的特性:继承孩子信息的时候实际上有一个整体位移的操作。
那么这个操作我们就可以用指针的技巧做到 \(O(1)\) 实现。
进行长链剖分之后,我们把轻链的部分暴力统计到长链上,观察到每个点最多在链顶端被暴力统计一次,所以均摊复杂度 \(O(n).\)
关于细节:这题中 \(g\) 数组内存恰好是反向的。
有些问题:笔者对于两个 \(dp\) 数组分别开了两个内存池却答案错误,共用一个却没有问题,是一个神奇的问题 还有待解决。
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAXN=1e6+10;
int head[MAXN],tot,n;
int pa[MAXN],len[MAXN],son[MAXN];
int *g[MAXN];
int *f[MAXN],tmp[MAXN],*pos=tmp;
int ans;
inline int read(){
int s=0,w=1;
char ch=getchar();
while(!isdigit(ch)){
if(ch=='-')w=-1;
ch=getchar();
}
while(isdigit(ch)){
s=s*10-48+ch;
ch=getchar();
}
return s*w;
}
struct E{int nxt,to;}e[MAXN];
inline void add(int x,int y){e[++tot]=(E){head[x],y};head[x]=tot;}
void dfs1(int x,int faa){
pa[x]=faa;
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].to;
if(j==faa)continue;
dfs1(j,x);
if(len[j]>len[son[x]])son[x]=j;
}
len[x]=len[son[x]]+1;
}
void dp(int x){
f[x][0]=1;
if(son[x]){
f[son[x]]=f[x]+1;
g[son[x]]=g[x]-1;
dp(son[x]);
}
ans+=g[x][0];
for(int i=head[x];i;i=e[i].nxt){
int v=e[i].to;
if(v==son[x]||v==pa[x])continue;
f[v]=pos;pos+=(len[v]<<1);
g[v]=pos;pos+=(len[v]<<1);
dp(v);
for(int j=0;j<len[v];++j){
if(j)ans+=f[x][j-1]*g[v][j];
ans+=f[v][j]*g[x][j+1];
}
for(int j=0;j<len[v];++j){
g[x][j+1]+=f[x][j+1]*f[v][j];
if(j)g[x][j-1]+=g[v][j];
f[x][j+1]+=f[v][j];
}
}
}
signed main(){
n=read();
for(int i=1;i<n;++i){
int u=read();
int v=read();
add(u,v);add(v,u);
}
dfs1(1,0);
f[1]=pos;pos+=(len[1]<<1);
g[1]=pos;pos+=(len[1]<<1);
dp(1);printf("%lld\n",ans);
return 0;
}