【BZOJ3451】Normal-概率期望+点分治+NTT
测试地址:Normal
题目大意:将点分治中找分治重心的过程,变成随机在当前块中取一个点,点分治的每一步骤(即处理一块)消耗的时间为块的大小,问总消耗时间的期望。
做法:本题需要用到概率期望+点分治+NTT。
首先根据期望的线性性,不难想到分开计算每个点被计算的期望次数,累加起来就是答案。而每个点被计算的次数,等于它在点分树上的深度(根深度为),那么对于一个点,某点(可以是点自己,它自己一定为自己的祖先)作为点分树上它的祖先的概率,等同于在原树中,点是在路径到上的点中第一个被选为分治重心的概率,它们是相互独立的,把这些概率累加起来就是点的期望深度。具体地,因为每个点被第一次选的概率相同,所以点作为点祖先的概率为,其中为到路径上点的数目。
因此答案就是求,暴力计算是的,为了加快这个速度,容易想到计算为不同数值时的路径数目,这是一个经典的点分治问题,而在具体计算时,有两种可行的写法:
第一种做法,是在处理某一个分治重心时,将所有分出的子树按大小从小到大排序,然后顺次用FFT/NTT合并信息,显然这样是的。
第二种做法,是在处理某一个分治重心时,先直接用一次FFT/NTT算出该块中过分治重心的路径(可能自交)的信息,然后枚举每棵子树去重,显然这样也是的。
两种做法都可行,而第二种做法写起来更简单,所以这里我用了第二种做法,于是我们就完成了这一题。至于为什么可以用NTT,因为,所以取模后和原值是相同的,NTT写起来又特别方便,还不用担心精度误差,美滋滋。
我傻逼的地方:TLE,以为是常数写挂,结果是分治重心求错了……简直是太菜了……
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
int n,first[30010]={0},tot=0,top,q[30010],r[120010];
int siz[30010],mxson[30010];
ll now[120010]={0},final[120010]={0};
bool vis[30010]={0};
struct edge
{
int v,next;
}e[60010];
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
void dp(int v,int fa)
{
siz[v]=1;
mxson[v]=0;
q[++top]=v;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa&&!vis[e[i].v])
{
dp(e[i].v,v);
mxson[v]=max(mxson[v],siz[e[i].v]);
siz[v]+=siz[e[i].v];
}
}
int find(int v)
{
top=0;
dp(v,-1);
int mn=1000000000,mni;
for(int i=1;i<=top;i++)
if (max(mxson[q[i]],siz[v]-siz[q[i]])<mn)
{
mn=max(mxson[q[i]],siz[v]-siz[q[i]]);
mni=q[i];
}
return mni;
}
ll power(ll a,ll b)
{
ll s=1,ss=a;
if (b<0) b+=mod-1;
while(b)
{
if (b&1) s=s*ss%mod;
ss=ss*ss%mod;b>>=1;
}
return s;
}
void NTT(ll *a,int type,int n)
{
for(int i=0;i<=n;i++)
if (i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<n;mid<<=1)
{
ll W=power(g,type*(mod-1)/(mid<<1));
for(int l=0;l<n;l+=(mid<<1))
{
ll w=1;
for(int k=0;k<mid;k++,w=w*W%mod)
{
ll x=a[l+k],y=w*a[l+mid+k]%mod;
a[l+k]=(x+y)%mod;
a[l+mid+k]=(x-y+mod)%mod;
}
}
}
if (type==-1)
{
ll inv=power(n,mod-2);
for(int i=0;i<=n;i++)
a[i]=a[i]*inv%mod;
}
}
void calc(int v,int fa,int dis)
{
now[dis]++;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa&&!vis[e[i].v])
calc(e[i].v,v,dis+1);
}
void calctot(int v,int d,int siz,ll type)
{
int x=1,bit=0;
while(x<=(siz<<1)) x<<=1,bit++;
r[0]=0;
for(int i=1;i<=x;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=0;i<=x;i++) now[i]=0;
calc(v,-1,d);
NTT(now,1,x);
for(int i=0;i<=x;i++) now[i]=now[i]*now[i]%mod;
NTT(now,-1,x);
for(int i=0;i<=x;i++) final[i]+=type*now[i];
}
int solve(int v)
{
int totsiz=1;
v=find(v);
vis[v]=1;
calctot(v,0,siz[v],1);
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v])
{
int newsiz=solve(e[i].v);
calctot(e[i].v,1,newsiz,-1);
totsiz+=newsiz;
}
vis[v]=0;
return totsiz;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
insert(a,b),insert(b,a);
}
solve(0);
double ans=0.0;
for(int i=0;i<=n;i++)
ans+=(double)final[i]/(double)(i+1);
printf("%.4lf",ans);
return 0;
}