bzoj 4543 HOTEL 加强版
题目大意:
求树上取三个点这三个点两两距离相等的方案数
思路:
远古时候的$n^2$做法是换根 但那样无法继续优化了
学习了一波长链剖分
考虑如何在一棵树上进行dp
设$f[i][j]$表示以$i$为根的子树内与$i$的距离为$j$的点数量
$g[i][j]$表示以$i$为根的子树内满足与lca距离为$d$且lca与$i$的距离为$d-j$的点对数(lca在子树内)
对于每个子树 对答案的贡献为$g[x][0]$与对每一个新进来的子树$\sum _{i=0} ^{mxd[v]} f[x][i-1]*g[v][i]+g[x][i+1]*f[v][i]$
更新$f,g$的时候除了正常的继承的子树 $g[x][i+1]+=f[x][i+1]*f[v][i]$
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<cstdlib> 5 #include<cmath> 6 #include<algorithm> 7 #include<queue> 8 #include<vector> 9 #include<map> 10 #include<set> 11 #define ll long long 12 #define inf 2139062143 13 #define MAXN 100100 14 #define MOD 998244353 15 #define rep(i,s,t) for(register int i=(s),i##__end=(t);i<=i##__end;++i) 16 #define dwn(i,s,t) for(register int i=(s),i##__end=(t);i>=i##__end;--i) 17 #define ren for(register int i=fst[x];i;i=nxt[i]) 18 #define pb(i,x) vec[i].push_back(x) 19 #define pls(a,b) (a+b)%MOD 20 #define mns(a,b) (a-b+MOD)%MOD 21 #define mul(a,b) (1LL*(a)*(b))%MOD 22 using namespace std; 23 inline int read() 24 { 25 int x=0,f=1;char ch=getchar(); 26 while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();} 27 while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();} 28 return x*f; 29 } 30 int n,*f[MAXN],*g[MAXN],tmp[MAXN<<2],*id=tmp,ans; 31 int fst[MAXN],nxt[MAXN<<1],to[MAXN<<1],cnt,mxd[MAXN],son[MAXN]; 32 void add(int u,int v) {nxt[++cnt]=fst[u],fst[u]=cnt,to[cnt]=v;} 33 void dfs(int x,int pa) 34 { 35 ren if(to[i]^pa) {dfs(to[i],x);if(mxd[to[i]]>mxd[son[x]]) son[x]=to[i];} 36 mxd[x]=mxd[son[x]]+1; 37 } 38 void New(int x) {f[x]=id,id+=mxd[x]<<1,g[x]=id,id+=mxd[x]<<1;} 39 void dp(int x,int pa) 40 { 41 if(son[x]) {f[son[x]]=f[x]+1,g[son[x]]=g[x]-1;dp(son[x],x);} 42 f[x][0]=1,ans+=g[x][0]; 43 ren if(to[i]^pa&&to[i]^son[x]) 44 { 45 New(to[i]);dp(to[i],x); 46 rep(j,0,mxd[to[i]]) 47 { 48 ans+=g[x][j+1]*f[to[i]][j]; 49 if(j) ans+=f[x][j-1]*g[to[i]][j]; 50 } 51 rep(j,0,mxd[to[i]]) 52 { 53 g[x][j+1]+=f[x][j+1]*f[to[i]][j],f[x][j+1]+=f[to[i]][j]; 54 if(j) g[x][j-1]+=g[to[i]][j]; 55 } 56 } 57 } 58 int main() 59 { 60 n=read();int a,b;rep(i,2,n) a=read(),b=read(),add(a,b),add(b,a); 61 dfs(1,0);New(1);dp(1,0);printf("%d\n",ans); 62 }