51Nod 1677 treecnt 【树形dp+组合数学+逆元】
51Nod 1677 treecnt
Description:
给定一棵n个节点的树,从1到n标号。选择k个点,你需要选择一些边使得这k个点通过选择的边联通,目标是使得选择的边数最少。
现需要计算对于所有选择k个点的情况最小选择边数的总和为多少。
样例解释:
一共有三种可能:(下列配图蓝色点表示选择的点,红色边表示最优方案中的边)
选择点{1,2}:至少要选择第一条边使得1和2联通。
选择点{1,3}:至少要选择第二条边使得1和3联通。
选择点{2,3}:两条边都要选择才能使2和3联通。
Input
第一行两个数n,k(1<=k<=n<=100000)
接下来n-1行,每行两个数x,y描述一条边(1<=x,y<=n)
Output
一个数,答案对1,000,000,007取模。
Input示例
3 2
1 2
1 3
Output示例
4
题解:
一道非常6的好题,思维性很强!!!值得一做!!!
首先直接做,要选出 k 个点,然后还要选择一些边使他们构成联通块,复杂度很高。
我们换个角度考虑(思维含量贼高):我们考虑一条边对答案的贡献:
设一条边连着 x,y 两侧的节点,那么每次选 k 个点,我们有三种取法:
(以上图片转载)
1:x 一侧的k个点的联通块
2:y 一侧的k个点的联通块
3:既包含 x 又包含 y 的k个点的联通块
显然,只有第三种情况这条边对答案有贡献,所以这条边对答案的贡献即为: C(n,k)-C(x,k)-C(y,k);
这样一来就简单了:我们通过 dfs dp出每个点的子节点数,这样一来就直接带进公式求值就可以了。
(还要注意:这里的组合数要用到阶乘逆元,否则精度会爆)
上代码:
1 #include<bits/stdc++.h> 2 #define ll long long 3 #define mo 1000000007 4 using namespace std; 5 const ll N=100005; 6 const ll M=1e5+5; 7 vector <ll> mp[100005]; 8 ll ans,n,k; 9 ll vis[100500]; 10 ll step[100005]; //阶乘数组 11 ll unstep[100005]; // 阶乘逆元数组 12 13 //快速幂部分 14 ll qpow(ll x,ll n) 15 { 16 ll res=1; 17 for (; n; n>>=1) 18 { 19 if (n&1) res=res*x%mo; 20 x=x*x%mo; 21 } 22 return res; 23 } 24 25 //阶乘逆元部分 26 void init() 27 { 28 step[1]=1; 29 for (int i=2; i<=M; i++) 30 step[i]=step[i-1]*i%mo; 31 unstep[M]=qpow(step[M],mo-2); //根据费马小定理,一个数x模p意义下的逆元就是x^(p-2) 32 for (int i=M-1; i>=0; i--) 33 unstep[i]=unstep[i+1]*(i+1)%mo; 34 } 35 36 //组合数学部分 37 ll C(ll a,ll b) 38 { 39 if (b>a) return 0; 40 if (b==0) return 1; 41 return step[a]*unstep[b]%mo*unstep[a-b]%mo; 42 // 由公式 C(m,n)=m!/(n!*(m-n)!) 转换过来:除变成乘逆元 43 } 44 45 //dfs部分 46 ll dfs(int x) 47 { 48 vis[x]=1; 49 ll count=1; 50 for (ll int i=0; i<mp[x].size(); i++) 51 { 52 ll u=mp[x][i]; 53 if (vis[u]==0) 54 { 55 ll tmp=dfs(u); 56 ans=(ans+(C(n,k)%mo-C(tmp,k)%mo-C(n-tmp,k)%mo)%mo+mo)%mo; 57 count+=tmp; 58 } 59 } 60 return count; 61 } 62 63 64 int main() 65 { 66 while (~scanf("%I64d%I64d",&n,&k)) 67 { 68 init(); 69 memset(vis,0,sizeof(vis)); 70 for (ll i=1; i<=n; i++) mp[i].clear(); 71 for (ll i=1; i<=n-1; i++) 72 { 73 ll x,y; 74 scanf("%I64d%I64d",&x,&y); 75 mp[x].push_back(y); mp[y].push_back(x); 76 } 77 ans=0; 78 dfs(1); 79 printf("%I64d\n",(ans+mo)%mo); 80 } 81 return 0; 82 }
加油加油加油!!! fighting fighting fighting !!!