【BZOJ4675】点对游戏-点分治+概率期望

测试地址:点对游戏
做法:本题需要用到点分治+概率期望。
首先,我们发现每个人选的点数一定是固定的。其次,我们发现一个人选k个点时,选到每种k个点的组合的概率都相等(因为每一步都等概率)。那么根据期望的线性性,我们可以分开考虑每个点对的贡献。
如果一个点对之间的距离不是幸运数,显然不对答案有贡献,否则它就处于Cn2k2个组合中,因为每个组合会做出1Cnk的贡献,所以每个点对就会做出Cn2k2Cnk的贡献,这个式子可以简化成k(k1)n(n1)。所以,我们现在的目的就是求出距离是幸运数的点对数pairs,那么如果一个人能够选k个点,答案就是pairsk(k1)n(n1)
而求距离是某些数的点对数的问题,就是点分治的经典问题了,所以我们可以用点分治O(mnlogn)算出答案,就可以通过此题了。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,num[15],first[50010]={0},tot=0;
int siz[50010],mxson[50010],q[50010],top;
ll ps=0,sum[50010]={0};
bool vis[50010]={0};
struct edge
{
    int v,next;
}e[100010];

void insert(int a,int b)
{
    e[++tot].v=b;
    e[tot].next=first[a];
    first[a]=tot;
}

void init()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++)
        scanf("%d",&num[i]);
    for(int i=1;i<n;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        insert(a,b),insert(b,a);
    }
}

void dp(int v,int f)
{
    q[++top]=v;
    siz[v]=1,mxson[v]=0;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=f&&!vis[e[i].v])
        {
            dp(e[i].v,v);
            siz[v]+=siz[e[i].v];
            mxson[v]=max(mxson[v],siz[e[i].v]);
        }
}

int find(int v)
{
    top=0;
    dp(v,0);
    int mn=1000000000,ans;
    for(int i=1;i<=top;i++)
        if (max(mxson[q[i]],siz[v]-siz[q[i]])<mn)
            mn=max(mxson[q[i]],siz[v]-siz[q[i]]),ans=q[i];
    return ans;
}

void dfs(int v,int f,int dis)
{
    for(int i=1;i<=m;i++)
        if (num[i]>=dis) ps+=sum[num[i]-dis];
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=f&&!vis[e[i].v]) dfs(e[i].v,v,dis+1);
}

void add(int v,int f,int dis)
{
    sum[dis]++;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=f&&!vis[e[i].v]) add(e[i].v,v,dis+1);
}

void clear(int v,int f,int dis)
{
    sum[dis]=0;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=f&&!vis[e[i].v]) clear(e[i].v,v,dis+1);
}

void solve(int v)
{
    v=find(v);
    vis[v]=1;
    sum[0]=1;
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v])
        {
            dfs(e[i].v,0,1);
            add(e[i].v,0,1);
        }
    sum[0]=0;
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v]) clear(e[i].v,0,1);
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v]) solve(e[i].v);
}

void calc()
{
    int k=n/3;
    double k1=(double)k,k2=(double)k,k3=(double)k,tot=(double)n,pairs=(double)ps;
    if (n%3>=1) k1+=1.0;
    if (n%3>=2) k2+=1.0;
    printf("%.2lf\n",pairs*k1*(k1-1)/(tot*(tot-1)));
    printf("%.2lf\n",pairs*k2*(k2-1)/(tot*(tot-1)));
    printf("%.2lf",pairs*k3*(k3-1)/(tot*(tot-1)));
}

int main()
{
    init();
    solve(1);
    calc();

    return 0;
}
posted @ 2018-04-08 09:21  Maxwei_wzj  阅读(129)  评论(0编辑  收藏  举报