返回顶部

树形dp

 

这里是主要的公式,可以这样理解:所有点到父亲节点u的距离和sall[u]已经算出来了,
那么算v这个节点的时候,不在v子树范围内的点到v的距离都多了1,
所以加上n-sz[v],v的子树的点到v的距离都减少了1,所以要减去sz[v].

 

题意:

给一棵树,然后有Q次询问,每次询问给你两个点,问你加一条边之后,这两个点所在的简单环的期望长度是多少

简单环即这两个点在一个环上,这个环是没有重边的。

solution:

两个点u,v,只有3种情况

1.lca(u,v)=v;

这种情况的答案等于v上面的点的距离和除以v上面的点数量+u下面的点距离和除以u下面的点数。

2.lca(u,v)=u;

同上

3.lca(u,v)!=u!=v

这种情况的答案等于v下面的点的距离和除以v下面的点的数量+u下面的点的距离和除以v下面的点的数量。

下面的点的距离和,这个东西,用树形dp去解决就好了。

至于上面的点的距离和,假设lca(u,v)=v这种情况,那么sumUp[v]=sumAll[v]-sumDown[z]-sz[z],z点是v到u的那条路径上的v的儿子。

sumAll[i]是所有点到i点的距离和,sz是这棵子树的大小,sumDown[i]是子树的距离和。

然后这道题就结束了。

  1 /*************************************************************************
  2     > File Name: 树形dp.cpp
  3     > Author: QWX
  4     > Mail: 
  5     > Created Time: 2018/11/6 17:43:14
  6  ************************************************************************/
  7 
  8 
  9 //{{{ #include
 10 //#include<iostream>
 11 #include<cstdio>
 12 #include<algorithm>
 13 #include<iomanip>
 14 #include<vector>
 15 #include<cmath>
 16 #include<queue>
 17 #include<map>
 18 #include<set>
 19 #include<string>
 20 #include<cstring>
 21 #include<complex>
 22 #include<cassert>
 23 //#include<bits/stdc++.h>
 24 #define vi vector<int>
 25 #define pii pair<int,int>
 26 #define mp make_pair
 27 #define pb push_back
 28 #define fi first
 29 #define se second
 30 #define pw(x) (1ll << (x))
 31 #define sz(x) ((int)(x).size())
 32 #define all(x) (x).begin(),(x).end()
 33 #define rep(i,l,r) for(int i=(l);i<(r);i++)
 34 #define per(i,r,l) for(int i=(r);i>=(l);i--)
 35 #define FOR(i,l,r) for(int i=(l);i<=(r);i++)
 36 #define cl(a,b) memset(a,b,sizeof(a))
 37 #define fastio ios::sync_with_stdio(false);cin.tie(0);
 38 #define lson l , mid , ls
 39 #define rson mid + 1 , r , rs
 40 #define INF 0x3f3f3f3f
 41 #define LINF 0x3f3f3f3f3f3f3f3f
 42 #define ll long long
 43 #define ull unsigned long long
 44 #define dd(x) cout << #x << " = " << (x) << "," 
 45 #define de(x) cout << #x << " = " << (x) << "\n" 
 46 #define endl "\n"
 47 using namespace std;
 48 //}}}
 49 
 50 
 51 const int N=1e5+7;
 52 int n,m;
 53 int siz[N],fa[N][25],dep[N];
 54 ll sdown[N],sall[N];
 55 vi G[N];
 56 
 57 void dfs(int u,int f)
 58 {
 59     siz[u]=1;
 60     sdown[u]=0;
 61     for(auto v:G[u]){
 62         if(v==f)continue;
 63         dep[v]=dep[u]+1;
 64         fa[v][0]=u;
 65 //        dd(v),de(fa[v][0]);
 66         dfs(v,u);
 67         siz[u]+=siz[v];
 68         sdown[u]+=sdown[v]+siz[v];
 69     }
 70 }
 71 void dfs2(int u,int f)
 72 {
 73     for(auto v:G[u]){
 74         if(v==f)continue;
 75         sall[v]=sall[u]+n-2*siz[v];
 76         //这里是主要的公式,可以这样理解:所有点到父亲节点u的距离和sall[u]已经算出来了,
 77         //那么算v这个节点的时候,不在v子树范围内的点到v的距离都多了1,    
 78         //所以加上n-sz[v],v的子树的点到v的距离都减少了1,所以要减去sz[v].
 79         dfs2(v,u);    
 80     }
 81                 
 82 }
 83 int up(int u,int d)
 84 {
 85     per(i,20,0)if(d&(1<<i))u=fa[u][i];
 86     return u;
 87 }
 88 int LCA(int u,int v)
 89 {
 90     if(dep[u]<dep[v])swap(u,v);
 91     u=up(u,dep[u]-dep[v]);
 92     if(u==v)return u;
 93     per(i,20,0)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
 94 //    dd(u);de(fa[u][0]);
 95 //    dd(v);de(fa[v][0]); 
 96     return fa[u][0];
 97 }
 98 void Init()
 99 {
100     dfs(1,-1);
101     sall[1]=sdown[1];
102     dfs2(1,-1);
103     FOR(j,1,20)FOR(i,1,n)fa[i][j]=fa[fa[i][j-1]][j-1];
104 //    FOR(i,1,5)dd(i),de(fa[i][0]);
105 }
106 int main()
107 {
108     scanf("%d%d",&n,&m); 
109     rep(i,0,n-1){
110         int a,b; scanf("%d%d",&a,&b);
111         G[a].pb(b);
112         G[b].pb(a);    
113     }
114     Init();
115     rep(i,0,m){
116         int u,v;scanf("%d%d",&u,&v);
117         int lca=LCA(u,v);
118 //        de(lca);
119         double ans=dep[u]+dep[v]-2*dep[lca]+1;
120         if(lca==u)swap(u,v);
121         if(lca==v){
122             int v2=up(u,dep[u]-dep[v]-1);
123 //            de(v2);
124 //              de(sall[2]);
125             double supv=sall[v]-sdown[v2]-siz[v2];
126 //            dd(sall[v]),dd(sdown[v2]),de(siz[v2]);
127             ans+=1.0*supv/(n-siz[v2])+1.0*sdown[u]/siz[u];
128         }else ans+=1.0*sdown[u]/siz[u]+1.0*sdown[v]/siz[v];
129         printf("%.10f\n",ans);
130     }
131 }
View Code

 

posted @ 2018-11-06 21:21  牛奶加咖啡~  阅读(142)  评论(0编辑  收藏  举报