[bzoj4543] [POI2014]Hotel加强版

Description

有一个树形结构的宾馆,n个房间,n-1条无向边,每条边的长度相同,任意两个房间可以相互到达。吉丽要给他的三个妹子各开(一个)房(间)。三个妹子住的房间要互不相同(否则要打起来了),为了让吉丽满意,你需要让三个房间两两距离相同。
有多少种方案能让吉丽满意?

Input

第一行一个数n。
接下来n-1行,每行两个数x,y,表示x和y之间有一条边相连。

Output

让吉丽满意的方案数。

Sample Input

7
1 2
5 7
2 5
2 3
5 6
4 5

Sample Output

5

Solution

先考虑暴力怎么做。

\(f[x][d]\)表示\(x\)的子树里距离\(x\)\(d\)的点的个数,\(g[x][a]\)表示\(x\)子树内距离\(lca\)\(d\)\(lca\)距离\(x\)\(d-a\)的点对个数。

那么,转移就是:

\[f[x][d]+=f[v][d-1],\\ g[x][d]+=g[v][d+1]+f[x][d]\cdot f[x][d-1] \]

更新答案就是:

\[ans+=f[x][a-1]\cdot g[v][a]+g[v][a]\cdot f[x][a+1] \]

转移顺序注意下,注意这里的\(f[x]\)\(g[x]\)都是不包括当前子树的。

那个关于\(g\)的转移的后半部分是以\(x\)\(lca\)的点数。

然后长链剖分一波,就\(O(n)\)了。

#include<bits/stdc++.h>
using namespace std;

#define int long long 

void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
void print(int x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

const int maxn = 1e6+10;

int n,head[maxn],tot,mxdep[maxn],hs[maxn],ans;
int space[maxn<<2],*f[maxn],*g[maxn],*t=space;
struct edge{int to,nxt;}e[maxn<<1];

void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}

void dfs(int x,int fa) {
	for(int i=head[x];i;i=e[i].nxt)
		if(e[i].to!=fa) {
			dfs(e[i].to,x);
			mxdep[x]=max(mxdep[x],mxdep[e[i].to]);
			if(mxdep[e[i].to]>=mxdep[hs[x]]) hs[x]=e[i].to;
		}
	mxdep[x]++;
}

void solve(int x,int fa) {
	if(hs[x]) f[hs[x]]=f[x]+1,g[hs[x]]=g[x]-1,solve(hs[x],x);
	f[x][0]=1,ans+=g[x][0];
	for(int i=head[x];i;i=e[i].nxt) {
		int v=e[i].to;if(v==hs[x]||v==fa) continue;
		f[v]=t,t+=mxdep[v]*2+3,g[v]=t,t+=mxdep[v]*2+3;
		solve(e[i].to,x);
		for(int j=0;j<mxdep[v];j++) {
			if(j) ans+=f[x][j-1]*g[v][j];
			ans+=f[v][j]*g[x][j+1];
		}
		for(int j=0;j<mxdep[v];j++) g[x][j+1]+=f[x][j+1]*f[v][j];
		for(int j=0;j<mxdep[v];j++) {
			if(j) g[x][j-1]+=g[v][j];
			f[x][j+1]+=f[v][j];
		}
	}
}

signed main() {
	read(n);
	for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,y);
	dfs(1,0);
	f[1]=t,t+=mxdep[1]*2+3,g[1]=t,t+=mxdep[1]*2+3;
	solve(1,0),write(ans);
	return 0;
}
posted @ 2019-01-25 15:27  Hyscere  阅读(170)  评论(0编辑  收藏  举报