[BZOJ4543]Hotel加强版

description

题面

给定一棵\(n\)个节点的树,求出不同三元组\((x,y,z)\)的个数,
其中\(dist(x,y)=dist(y,z)=dist(x,z)\)

solution

考虑暴力。

\(f[i][j]\)表示\(i\)的子树中深度为\(j\)的点数,
\(g[i][j]\)表示\(i\)子树中两点深度相同\((d)\)且其\(lca\)\(i\)的距离为\(d-j\)的点对数。

那么初值为\(f[i][0]=1\),在\(i\)加入一棵子树\(son\)时,我们有

\[ans+=f[i][j]\times g[son][j+1]+g[i][j]\times f[son][j-1] \]

\[g[i][j]+=g[son][j+1]+f[i][j]\times f[son][j-1] \]

\[f[i][j]+=f[son][j-1] \]

这样直接\(DP\)\(O(n^2)\)

但我们发现这样的\(DP\)的第二关键字是深度:第一棵子树在赋值给父亲时,会整体移位
因此我们引入长链剖分优化
每次非链首的部分使用指针可以在\(O(1)\)的时间内将\(DP\)数组移位给其父亲,
链首部分使用\(O(链长)\)的时间暴力转移到父亲
因为\(\sum链长=n\),所以这样做是\(O(n)\)

code

#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<fstream>
#include<cstdlib>
#include<iomanip>
#include<cstring>
#include<complex>
#include<clocale>
#include<cctype>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<ctime>
#include<cmath>
#include<queue>
#include<stack>
#include<list>
#include<map>
#include<set>
#define FILE "a"
#define mp make_pair
#define pb push_back
#define RG register
#define il inline
//#define RAND
using namespace std;
typedef unsigned long long ull;
typedef vector<int>VI;
typedef long long ll;
typedef double dd;
const int inf=2147483647;
const dd pi=acos(-1);
const dd eps=1e-10;
const int mod=1e9+7;
const int N=100010;
const ll INF=1e18+1;
il ll read(){
	RG ll data=0,w=1;RG char ch=getchar();
	while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
	if(ch=='-')w=-1,ch=getchar();
	while(ch<='9'&&ch>='0')data=data*10+ch-48,ch=getchar();
	return data*w;
}
il int make(int l,int r){return rand()%(r-l+1)+l;}
il void file(){	
#ifdef RAND
	freopen("seed.in","r",stdin);
	RG int seed=read();fclose(stdin);
	srand(time(NULL)+seed);
	freopen("seed.out","w",stdout);
	seed=rand();printf("%d\n",seed);
	fclose(stdout);	
	freopen(FILE".in","w",stdout);
#endif
#ifndef RAND
	freopen(FILE".in","r",stdin);
	freopen(FILE".out","w",stdout);
#endif
}
/*********************************************************************/

int n;
int head[N],nxt[N<<1],to[N<<1],cnt;
il void add(int u,int v){
	to[++cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
}

int fa[N],dep[N],len[N],son[N];
void dfs1(int u,int ff){
	fa[u]=ff;dep[u]=dep[ff]+1;len[u]=1;
	for(RG int i=head[u];i;i=nxt[i]){
		RG int v=to[i];if(v==ff)continue;
		dfs1(v,u);if(len[u]<len[v]+1)len[u]=len[v]+1,son[u]=v;
	}
}

ll tmp[6*N],*f[N],*g[N],*pos=tmp+1,ans;
void dfs2(int u){
	if(son[u]){f[son[u]]=f[u]+1;g[son[u]]=g[u]-1;dfs2(son[u]);}
	f[u][0]=1;ans+=g[u][0];
	for(RG int i=head[u];i;i=nxt[i]){
		RG int v=to[i];
		if(v==fa[u]||v==son[u])continue;
		f[v]=pos;pos+=(len[v]+1)<<1;
		g[v]=pos;pos+=len[v]+1;dfs2(v);
		for(RG int j=len[v];~j;j--){
			if(j)ans+=g[u][j]*f[v][j-1];
			if(j<len[v])ans+=f[u][j]*g[v][j+1];
			if(j<len[v])g[u][j]+=g[v][j+1];
			if(j)g[u][j]+=f[u][j]*f[v][j-1];
			if(j)f[u][j]+=f[v][j-1];
		}
	}
}

int main()
{
	n=read();
	for(RG int i=1,u,v;i<n;i++){
		u=read();v=read();add(u,v);add(v,u);
	}
	dfs1(1,0);
	f[1]=pos;pos+=(len[1]+1)<<1;
	g[1]=pos;pos+=len[1]+1;
	dfs2(1);
	printf("%lld\n",ans);
	return 0;
}

posted @ 2018-08-10 20:16  cjfdf  阅读(214)  评论(0编辑  收藏  举报