BZOJ4543/BZOJ3522 [POI2014]Hotel加强版(长链剖分)
题目好神仙……这个叫长链剖分的玩意儿更神仙……
考虑dp,设\(f[i][j]\)表示以\(i\)为根的子树中到\(i\)的距离为\(j\)的点的个数,\(g[i][j]\)表示\(i\)的子树中有\(g[i][j]\)对点深度相同,他们到LCA的距离为\(d\),且他们的LCA到\(i\)的距离为\(d-j\)。或者换句话来说就是以\(i\)为根的子树中有这么多个点对,而且没有第三个点去和这些点对匹配,第三个点不在\(i\)的子树中且到\(i\)的距离为\(j\),\(g[i][j]\)表示这些点对的个数
设\(u\)为当前点,\(v\)为某一子树,那么转移方程如下
\[f[u][i]+=f[v][i+1]
\]
\[g[u][i-1]+=g[v][i]
\]
\[g[u][i+1]+=f[u][i+1]*f[v][i]
\]
\[ans+=f[u][i-1]*g[v][i]+g[u][i+1]*f[v][i]
\]
如果是原题的\(n\leq 5000\)已经足够了,然而当\(n\leq 100000\)的时候很明显gg了
发现状态数组的第二维实际上跟这个节点的深度有关,于是考虑用长链剖分优化。(不知道什么是长链剖分的可以看看蒟蒻的笔记)简单来说记每一个节点深度最大的儿子为它的重儿子。因为第一次转移的时候有\(f[u][i]=f[v][i-1],g[u][i]=g[v][i+1]\),于是可以类似于dsu on tree的思想,对于每个重儿子的信息直接继承,轻儿子暴力跑一遍。重儿子的信息可以直接用指针来达到\(O(1)\)的转移
这个时间复杂度大概是\(O(n)\)的,对于每个点转移的复杂度为\(\sum dep[v]-dep[son[u]]=\sum dep[v]-dep[u]+1\),然后所有点的加起来除了叶子结点都互相抵消,于是总的复杂度为\(O(n)\)
空间复杂度也是\(O(n)\),因为非叶节点的空间都是由它所在重链的儿子转移来的,所以对每个叶节点开正比于此重链长度的空间即可
//minamoto
#include<bits/stdc++.h>
#define ll long long
using namespace std;
#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
int read(){
int res,f=1;char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
const int N=1e5+5,M=1005;
int head[N],Next[N<<1],ver[N<<1],tot;
inline void add(int u,int v){ver[++tot]=v,Next[tot]=head[u],head[u]=tot;}
ll memp[N*5],*f[N],*g[N],*to=memp+5,ans;
int n,dep[N],mx[N];
void dfs(int u,int fa){
mx[u]=u;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=fa){
dep[v]=dep[u]+1,dfs(v,u);
if(dep[mx[v]]>dep[mx[u]])mx[u]=mx[v];
}
}
for(int i=head[u];i;i=Next[i]){
int v=ver[i];
if(v!=fa&&(mx[v]!=mx[u]||u==1)){
v=mx[v],to+=dep[v]-dep[u]+1;
f[v]=to,g[v]=(to+=1),to+=(dep[v]-dep[u])*2+1;
}
}
}
void dp(int u,int fa){
for(int i=head[u];i;i=Next[i]){
int v=ver[i];if(v==fa)continue;dp(v,u);
if(mx[v]==mx[u])f[u]=f[v]-1,g[u]=g[v]+1;
}
ans+=g[u][0],f[u][0]=1;
for(int i=head[u];i;i=Next[i]){
int v=ver[i];if(v==fa||mx[v]==mx[u])continue;
for(int j=0;j<=dep[mx[v]]-dep[u];++j)
ans+=f[u][j-1]*g[v][j]+g[u][j+1]*f[v][j];
for(int j=0;j<=dep[mx[v]]-dep[u];++j){
g[u][j-1]+=g[v][j];
g[u][j+1]+=f[u][j+1]*f[v][j];
f[u][j+1]+=f[v][j];
}
}
}
int main(){
// freopen("testdata.in","r",stdin);
n=read();
for(int i=1,u,v;i<n;++i)u=read(),v=read(),add(u,v),add(v,u);
while(to!=memp)*to=0,--to;*to=0,++to;
dep[1]=1;dfs(1,0),dp(1,0);
printf("%lld\n",ans);return 0;
}
深深地明白自己的弱小