51nod-1322: 关于树的函数

【传送门:51nod-1322


简要题意:

  给出n个点的两棵无根树,编号都是从0到n-1

  现在每棵树任意选出一条边割断,设第一棵树选出的边为e1,第二棵树选出的边为e2

  很显然割断后两棵树各分成了四棵树,设第一棵树分成了A1树和B1树,第二棵树分成了A2树和B2树

  设S(a,b)为a树和b树之间相同编号的点的个数

  那么割断这两条边的价值为S(A1,B1),S(A1,B2),S(A2,B1),S(A2,B2)中的最大值的平方

  求出每对e1,e2的价值和


题解:

  又是一道卡了挺久的题

  我们把节点全部+1,把第二棵树的节点编号设为n+1到2n(方便操作)

  设两棵树的根节点分别为1和n+1

  先在第一棵树中找到一个非根节点,然后在第二棵树中找到另一个非根节点

  对于这两个节点显然有三种情况:

  1.取两个节点的子树

  2.取其中一个点的子树,另一个点的反子树(就是除了子树外的点)

  3.取两个节点的反子树

  然后对于三种情况更新答案就行了

  具体操作有点麻烦,看代码吧(有点小懒)

  PS:不知道该分到哪个专题,看了一下路牌,发现dalao都说是树形计数DP,那我就只好分成树形计数DP


参考代码:

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
struct node
{
    int x,y,next;
}a[21000];int len,last[8100];
void ins(int x,int y)
{
    len++;
    a[len].x=x;a[len].y=y;
    a[len].next=last[x];last[x]=len;
}
int tot[8100],n;
bool v[4100][4100];
void pre(int x,int fa)
{
    tot[x]=1;
    if(x<=n) v[x][x]=true;
    for(int k=last[x];k;k=a[k].next)
    {
        int y=a[k].y;
        if(y==fa) continue;
        pre(y,x);
        if(x<=n) for(int i=1;i<=n;i++) v[x][i]|=v[y][i]; 
        tot[x]+=tot[y];
    }
}
LL ans;int s;
int dou(int x){return x*x;}
void getd(int x,int fa,int p)
{
    if(v[p][x-n]==true) s++;
    int t=s;
    for(int k=last[x];k;k=a[k].next)
    {
        int y=a[k].y;
        if(y==fa) continue;
        getd(y,x,p);
        t=s-t;
        ans+=dou(max(max(t,tot[y]-t),max(tot[p]-t,n-tot[p]-tot[y]+t)));
        t=s;
    }
}
void solve(int x,int fa)
{
    for(int k=last[x];k;k=a[k].next)
    {
        int y=a[k].y;
        if(y==fa) continue;
        s=0;getd(n+1,0,y);
        solve(y,x);
    }
}
int main()
{
    scanf("%d",&n);
    len=0;memset(last,0,sizeof(last));
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);x++;y++;
        ins(x,y);ins(y,x);
    }
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);x++;y++;
        ins(x+n,y+n);ins(y+n,x+n);
    }
    memset(v,false,sizeof(v));
    pre(1,0);pre(n+1,0);
    ans=0;solve(1,0);
    printf("%lld\n",ans);
    return 0;
}

 

posted @ 2018-10-12 20:35  Star_Feel  阅读(240)  评论(0编辑  收藏  举报