[bzoj4543] [POI2014]Hotel加强版
Description
有一个树形结构的宾馆,n个房间,n-1条无向边,每条边的长度相同,任意两个房间可以相互到达。吉丽要给他的三个妹子各开(一个)房(间)。三个妹子住的房间要互不相同(否则要打起来了),为了让吉丽满意,你需要让三个房间两两距离相同。
有多少种方案能让吉丽满意?
Input
第一行一个数n。
接下来n-1行,每行两个数x,y,表示x和y之间有一条边相连。
Output
让吉丽满意的方案数。
Sample Input
7
1 2
5 7
2 5
2 3
5 6
4 5
Sample Output
5
Solution
先考虑暴力怎么做。
设\(f[x][d]\)表示\(x\)的子树里距离\(x\)为\(d\)的点的个数,\(g[x][a]\)表示\(x\)子树内距离\(lca\)为\(d\),\(lca\)距离\(x\)为\(d-a\)的点对个数。
那么,转移就是:
\[f[x][d]+=f[v][d-1],\\
g[x][d]+=g[v][d+1]+f[x][d]\cdot f[x][d-1]
\]
更新答案就是:
\[ans+=f[x][a-1]\cdot g[v][a]+g[v][a]\cdot f[x][a+1]
\]
转移顺序注意下,注意这里的\(f[x]\)和\(g[x]\)都是不包括当前子树的。
那个关于\(g\)的转移的后半部分是以\(x\)为\(lca\)的点数。
然后长链剖分一波,就\(O(n)\)了。
#include<bits/stdc++.h>
using namespace std;
#define int long long
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
const int maxn = 1e6+10;
int n,head[maxn],tot,mxdep[maxn],hs[maxn],ans;
int space[maxn<<2],*f[maxn],*g[maxn],*t=space;
struct edge{int to,nxt;}e[maxn<<1];
void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}
void dfs(int x,int fa) {
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa) {
dfs(e[i].to,x);
mxdep[x]=max(mxdep[x],mxdep[e[i].to]);
if(mxdep[e[i].to]>=mxdep[hs[x]]) hs[x]=e[i].to;
}
mxdep[x]++;
}
void solve(int x,int fa) {
if(hs[x]) f[hs[x]]=f[x]+1,g[hs[x]]=g[x]-1,solve(hs[x],x);
f[x][0]=1,ans+=g[x][0];
for(int i=head[x];i;i=e[i].nxt) {
int v=e[i].to;if(v==hs[x]||v==fa) continue;
f[v]=t,t+=mxdep[v]*2+3,g[v]=t,t+=mxdep[v]*2+3;
solve(e[i].to,x);
for(int j=0;j<mxdep[v];j++) {
if(j) ans+=f[x][j-1]*g[v][j];
ans+=f[v][j]*g[x][j+1];
}
for(int j=0;j<mxdep[v];j++) g[x][j+1]+=f[x][j+1]*f[v][j];
for(int j=0;j<mxdep[v];j++) {
if(j) g[x][j-1]+=g[v][j];
f[x][j+1]+=f[v][j];
}
}
}
signed main() {
read(n);
for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,y);
dfs(1,0);
f[1]=t,t+=mxdep[1]*2+3,g[1]=t,t+=mxdep[1]*2+3;
solve(1,0),write(ans);
return 0;
}