【BZOJ3451】Normal-概率期望+点分治+NTT

测试地址:Normal
题目大意:将点分治中找分治重心的过程,变成随机在当前块中取一个点,点分治的每一步骤(即处理一块)消耗的时间为块的大小,问总消耗时间的期望。
做法:本题需要用到概率期望+点分治+NTT。
首先根据期望的线性性,不难想到分开计算每个点被计算的期望次数,累加起来就是答案。而每个点被计算的次数,等于它在点分树上的深度(根深度为1),那么对于一个点x,某点y(可以是点x自己,它自己一定为自己的祖先)作为点分树上它的祖先的概率,等同于在原树中,点y是在路径xy上的点中第一个被选为分治重心的概率,它们是相互独立的,把这些概率累加起来就是点x的期望深度。具体地,因为每个点被第一次选的概率相同,所以点y作为点x祖先的概率为1dis(x,y),其中dis(x,y)xy路径上点的数目。
因此答案就是求i=1nj=1n1dis(i,j),暴力计算是O(n2)的,为了加快这个速度,容易想到计算dis为不同数值时的路径数目,这是一个经典的点分治问题,而在具体计算时,有两种可行的写法:
第一种做法,是在处理某一个分治重心时,将所有分出的子树按大小从小到大排序,然后顺次用FFT/NTT合并信息,显然这样是O(nlog2n)的。
第二种做法,是在处理某一个分治重心时,先直接用一次FFT/NTT算出该块中过分治重心的路径(可能自交)的信息,然后枚举每棵子树去重,显然这样也是O(nlog2n)的。
两种做法都可行,而第二种做法写起来更简单,所以这里我用了第二种做法,于是我们就完成了这一题。至于为什么可以用NTT,因为300002<998244353,所以取模后和原值是相同的,NTT写起来又特别方便,还不用担心精度误差,美滋滋。
我傻逼的地方:TLE,以为是常数写挂,结果是分治重心求错了……简直是太菜了……
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
int n,first[30010]={0},tot=0,top,q[30010],r[120010];
int siz[30010],mxson[30010];
ll now[120010]={0},final[120010]={0};
bool vis[30010]={0};
struct edge
{
    int v,next;
}e[60010];

void insert(int a,int b)
{
    e[++tot].v=b;
    e[tot].next=first[a];
    first[a]=tot;
}

void dp(int v,int fa)
{
    siz[v]=1;
    mxson[v]=0;
    q[++top]=v;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa&&!vis[e[i].v])
        {
            dp(e[i].v,v);
            mxson[v]=max(mxson[v],siz[e[i].v]);
            siz[v]+=siz[e[i].v];
        }
}

int find(int v)
{
    top=0;
    dp(v,-1);
    int mn=1000000000,mni;
    for(int i=1;i<=top;i++)
        if (max(mxson[q[i]],siz[v]-siz[q[i]])<mn)
        {
            mn=max(mxson[q[i]],siz[v]-siz[q[i]]);
            mni=q[i];
        }
    return mni;
}

ll power(ll a,ll b)
{
    ll s=1,ss=a;
    if (b<0) b+=mod-1;
    while(b)
    {
        if (b&1) s=s*ss%mod;
        ss=ss*ss%mod;b>>=1;
    }
    return s;
}

void NTT(ll *a,int type,int n)
{
    for(int i=0;i<=n;i++)
        if (i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1)
    {
        ll W=power(g,type*(mod-1)/(mid<<1));
        for(int l=0;l<n;l+=(mid<<1))
        {
            ll w=1;
            for(int k=0;k<mid;k++,w=w*W%mod)
            {
                ll x=a[l+k],y=w*a[l+mid+k]%mod;
                a[l+k]=(x+y)%mod;
                a[l+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if (type==-1)
    {
        ll inv=power(n,mod-2);
        for(int i=0;i<=n;i++)
            a[i]=a[i]*inv%mod;
    }
}

void calc(int v,int fa,int dis)
{
    now[dis]++;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa&&!vis[e[i].v])
            calc(e[i].v,v,dis+1);
}

void calctot(int v,int d,int siz,ll type)
{
    int x=1,bit=0;
    while(x<=(siz<<1)) x<<=1,bit++;
    r[0]=0;
    for(int i=1;i<=x;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));

    for(int i=0;i<=x;i++) now[i]=0;
    calc(v,-1,d);
    NTT(now,1,x);
    for(int i=0;i<=x;i++) now[i]=now[i]*now[i]%mod;
    NTT(now,-1,x);
    for(int i=0;i<=x;i++) final[i]+=type*now[i];
}

int solve(int v)
{
    int totsiz=1;
    v=find(v);
    vis[v]=1;

    calctot(v,0,siz[v],1);
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v])
        {
            int newsiz=solve(e[i].v);
            calctot(e[i].v,1,newsiz,-1);
            totsiz+=newsiz;
        }

    vis[v]=0;
    return totsiz;
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        insert(a,b),insert(b,a);
    }

    solve(0);
    double ans=0.0;
    for(int i=0;i<=n;i++)
        ans+=(double)final[i]/(double)(i+1);
    printf("%.4lf",ans);

    return 0;
}
posted @ 2018-06-07 22:28  Maxwei_wzj  阅读(115)  评论(0编辑  收藏  举报