51NOD 1353:树——题解

http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1353

今天小a在纸上研究树的形态,众所周知的,有芭蕉树,樟树,函树,平衡树,树套树等等。那么小a今天在研究的就是其中的平衡树啦。

小a认为一棵平衡树的定义为一个n个点,从1到n编号,n-1条边,且任意两点间一定存在唯一一条简单路径,且n>=k。

现在小a看到一棵很大很大的树,足足有n个节点,这里n一定大于等于k!为了方便起见,它想把这个树删去某些边,使得剩下的若干个联通块都满足是平衡树。这时,小b走过来,不屑一顾的说,如果我一条边都不删,那么也算一棵平衡树咯。

小a对于小b的不屑感到很不爽,并问小b,你能算出我删边的方案总数使得满足我的条件吗?两个删边的方案A,B不同当且仅当存在某一条边属于集合A且不属于集合B,或者存在某一条边属于集合B且不属于集合A。为了让你方便,你只要告诉我答案对1000000007(1e9+7)取模就行了。

小b犯了难,找到了身为程序猿的你。

(我dp真垃圾)

题解如下:

我们令dp[i][j]表示以i为根且当前联通块大小为j的方案总数,特别的,dp[i][0]表示割点当前点与其父亲是棵平衡树的方案总数。

对于u的一个孩子v可以得到转移方程dp[u][j+k]=dp[u][j]*dp[v][k]

另外dp[u][0]=Σdp[u][j](j>=题目给定的k)

这样乍看是n^3的,有一个技巧可以做到n^2即每次dp时,只枚举当前u所在子树的大小,每当枚举到它的其中孩子时,当前u所在子树的大小加上它孩子为根的子树的大小。可以理解为每一个点对只被枚举到一次。

最后答案即为dp[root][0]

如果你没看懂的话,反正我也没看懂,我讲一遍我的思路。

我们还是按照上一道题通过dfs序来更新dp值降低复杂度。

设dp[i][j]表示以i为根i所在联通块大小为j的方案数,dp[i][0]为符合条件的总方案数。

可以看出一定有dp[u][j+k]=dp[u][j]*dp[v][k],(相当于u和v之间断开),j的大小为当前我们所遍历完的子树大小sz(毕竟更大的你也没更新过。)

当然也会有dp[u][j]=dp[v][0]*dp[u][j];

但是复杂度算起来为什么是O(n^2)的呢?考虑这就相当于在u和v的子树当中找点对使得这些点一下割掉,u和v割掉,变相等于找点对。

#include<cstdio>
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=2500;
const int p=1e9+7;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
struct node{
    int to,nxt;
}e[N*2];
int cnt,head[N];
int n,k,sz[N],dp[N][N];
inline void add(int u,int v){
    e[++cnt].to=v;e[cnt].nxt=head[u];head[u]=cnt;
}
void dfs(int u,int f){
    sz[u]=1;dp[u][1]=1;
    for(int i=head[u];i;i=e[i].nxt){
    int v=e[i].to;
    if(v==f)continue;
    dfs(v,u);
    for(int j=sz[u];j>=1;j--){
        for(int l=sz[v];l>=1;l--){
        dp[u][j+l]=(dp[u][j+l]+(ll)dp[u][j]*dp[v][l])%p;
        }
        dp[u][j]=(ll)dp[v][0]*dp[u][j]%p;
    }
    sz[u]+=sz[v];
    }
    for(int i=k;i<=sz[u];i++)
    dp[u][0]=(dp[u][0]+dp[u][i])%p;
}
int main(){
    n=read(),k=read();
    for(int i=1;i<n;i++){
    int u=read(),v=read();
    add(u,v);add(v,u);
    }
    dfs(1,0);
    printf("%d\n",dp[1][0]);
    return 0;
}

+++++++++++++++++++++++++++++++++++++++++++

+本文作者:luyouqi233。               +

+欢迎访问我的博客:http://www.cnblogs.com/luyouqi233/+

+++++++++++++++++++++++++++++++++++++++++++

posted @ 2018-04-14 14:59  luyouqi233  阅读(269)  评论(0编辑  收藏  举报