【BZOJ3451】Tyvj1953 Normal - 点分治+FFT

题目来源:NOI2019模拟测试赛(七)

非原题面,题意有略微区别

题意:

吐槽:

心态崩了。

好不容易场上想出一题正解,写了三个小时结果写了个假的点分治,卡成$O(n^2)$

我退役吧。

题解:

原题是求随机树分治的期望深度和,题意相同。

对于一个点$x$,考虑点$y$是否能作为它在点分树上的祖先节点,显然当且仅当$y$在$x$到$y$的路径中第一个被选为分治中心时会对$x$产生1的贡献;

由于路径上所有点被选到的概率都是相等的,所以此时的期望就是$\frac{1}{dis(x,y)}$;

那么总的期望就是$\sum\limits_{x=1}^{n}\sum\limits_{y=1}^{n}\frac{1}{dis(x,y)}$;

在这里写个暴力即可爆踩我的假点分治;

考虑统计每种长度的路径条数,可以用点分治做,并且在点分树里合并时子树的期望是一个卷积的形式,因此可以用FFT来加速;

于是我就快乐的写了个点分治+FFT,获得了60分的好成绩;

为什么?参考这篇博客的证明,我最初的写法就是其中的第一种写法,搜完一个子树就和已经搜过的合并,这样做的话FFT的长度会是$子树中最大深度\times 根节点儿子个数=O(n^2)$的,正确的写法应该搜完再一起合并,或者像里面说的第二种方法一样直接搜当前子树,更新答案然后搜重心的每个儿子的子树,减去不合法的路径,这样子FFT的长度才是$O(n)$的。

代码:

假点分治(60pts):

  1 #include<algorithm>
  2 #include<iostream>
  3 #include<cstring>
  4 #include<cstdio>
  5 #include<cmath>
  6 #include<queue>
  7 #define inf 2147483647
  8 #define eps 1e-9
  9 #define mod 1000000007
 10 using namespace std;
 11 typedef long long ll;
 12 typedef double db;
 13 const db pi=acos(-1.0);
 14 
 15 struct edge{
 16     int v,next;
 17 }a[200001];
 18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],s[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001];
 19 bool used[100001];
 20 struct cp{
 21     db a,b;
 22     cp(){}
 23     cp(db _a,db _b){
 24         a=_a,b=_b;
 25     }
 26     friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}
 27     friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}
 28     friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
 29     friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}
 30     friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}
 31 }A[200001],B[200001],W[200001][2];
 32 void _(){
 33     for(int i=1;i<=(1<<17);i<<=1){
 34         W[i][0]=cp(cos(pi/i),sin(pi/i));
 35         W[i][1]=cp(cos(pi/i),-sin(pi/i));
 36     }
 37 }
 38 void fft(cp *s,int op){
 39     for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
 40     for(int i=1;i<bit;i<<=1){
 41         //cp w(cos(pi/i),op*sin(pi/i));
 42         cp w=W[i][op==-1];
 43         for(int p=i<<1,j=0;j<bit;j+=p){
 44             cp wk(1,0);
 45             for(int k=j;k<i+j;k++,wk=wk*w){
 46                 cp x=s[k],y=wk*s[k+i];
 47                 s[k]=x+y;
 48                 s[k+i]=x-y;
 49             }
 50         }
 51     }
 52     if(op==-1){
 53         for(int i=0;i<bit;i++){
 54             s[i]=s[i]/(db)bit;
 55         }
 56     }
 57 }
 58 void add(int u,int v){
 59     a[++tot].v=v;
 60     a[tot].next=head[u];
 61     head[u]=tot;
 62 }
 63 void mul(int *ret,int *a,int *b,int n){
 64     for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;
 65     for(int i=1;i<=bit;i++){
 66         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));
 67     }
 68     for(int i=0;i<bit;i++){
 69         A[i]=cp((db)a[i],0);
 70         B[i]=cp(0,0);
 71     }
 72     for(int i=1;i<=cnt;i++){
 73         a[b[i]]++;
 74         B[b[i]].a+=1;
 75     }
 76     fft(A,1);
 77     fft(B,1);
 78     for(int i=0;i<bit;i++)A[i]=A[i]*B[i];
 79     fft(A,-1);
 80     for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);
 81 }
 82 void getrt(int u,int fa){
 83     mx[u]=0;
 84     siz[u]=1;
 85     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
 86         int v=a[tmp].v;
 87         if(!used[v]&&v!=fa){
 88             getrt(v,u);
 89             siz[u]+=siz[v];
 90             mx[u]=max(mx[u],siz[v]);
 91         }
 92     }
 93     mx[u]=max(mx[u],S-mx[u]);
 94     if(mx[u]<mx[rt])rt=u;
 95 }
 96 void getdep(int u,int fa,int dpt){
 97     mxd=max(mxd,dpt);
 98     s[++cnt]=dpt;
 99     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
100         int v=a[tmp].v;
101         if(!used[v]&&v!=fa){
102             getdep(v,u,dpt+1);
103         }
104     }
105 }
106 void divide(int u){
107     used[u]=true;
108     mxd=0;
109     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
110         int v=a[tmp].v;
111         if(!used[v]){
112             cnt=0;
113             getdep(v,u,1);
114             mul(tp,num,s,mxd);
115             for(int i=0;i<bit;i++)anss[i]+=tp[i];
116         }
117     }
118     for(int i=1;i<=mxd;i++){
119         anss[i]+=num[i];
120         num[i]=0;
121     }
122     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
123         int v=a[tmp].v;
124         if(!used[v]){
125             S=siz[v];
126             rt=0;
127             getrt(v,0);
128             divide(rt);
129         }
130     }
131 }
132 int main(){
133     memset(head,-1,sizeof(head));
134     _();
135     scanf("%d",&n);
136     jc[0]=inv[0]=inv[1]=1;
137     for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
138     for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod;
139     for(int i=1;i<n;i++){
140         scanf("%d%d",&u,&v);
141         add(u,v);
142         add(v,u);
143     }
144     S=n;
145     mx[rt=0]=6666666;
146     getrt(1,-1);
147     divide(rt);
148     ans=n;
149     for(int i=1;i<=n;i++){
150         ans=(ans+(ll)anss[i]*inv[i+1]*2%mod)%mod;
151     }
152     printf("%lld",(ll)ans*jc[n]%mod);
153     return 0;
154 }

AC代码(100pts):

  1 #include<algorithm>
  2 #include<iostream>
  3 #include<cstring>
  4 #include<cstdio>
  5 #include<cmath>
  6 #include<queue>
  7 #define inf 2147483647
  8 #define eps 1e-9
  9 #define mod 1000000007
 10 using namespace std;
 11 typedef long long ll;
 12 typedef double db;
 13 const db pi=acos(-1.0);
 14 
 15 struct edge{
 16     int v,next;
 17 }a[200001];
 18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001],dps[100001];
 19 bool used[100001];
 20 struct cp{
 21     db a,b;
 22     cp(){}
 23     cp(db _a,db _b){
 24         a=_a,b=_b;
 25     }
 26     friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}
 27     friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}
 28     friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
 29     friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}
 30     friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}
 31 }A[200001],B[200001],W[200001][2];
 32 void _(){
 33     for(int i=1;i<=(1<<17);i<<=1){
 34         W[i][0]=cp(cos(pi/i),sin(pi/i));
 35         W[i][1]=cp(cos(pi/i),-sin(pi/i));
 36     }
 37 }
 38 void fft(cp *s,int op){
 39     for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
 40     for(int i=1;i<bit;i<<=1){
 41         //cp w(cos(pi/i),op*sin(pi/i));
 42         cp w=W[i][op==-1];
 43         for(int p=i<<1,j=0;j<bit;j+=p){
 44             cp wk(1,0);
 45             for(int k=j;k<i+j;k++,wk=wk*w){
 46                 cp x=s[k],y=wk*s[k+i];
 47                 s[k]=x+y;
 48                 s[k+i]=x-y;
 49             }
 50         }
 51     }
 52     if(op==-1){
 53         for(int i=0;i<bit;i++){
 54             s[i]=s[i]/(db)bit;
 55         }
 56     }
 57 }
 58 void add(int u,int v){
 59     a[++tot].v=v;
 60     a[tot].next=head[u];
 61     head[u]=tot;
 62 }
 63 void mul(int *ret,int *a,int *b,int n){
 64     for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;
 65     for(int i=1;i<bit;i++){
 66         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));
 67     }
 68     for(int i=0;i<bit;i++){
 69         A[i]=cp((db)a[i],0);
 70         B[i]=cp((db)b[i],0);
 71     }
 72     fft(A,1);
 73     fft(B,1);
 74     for(int i=0;i<bit;i++)A[i]=A[i]*B[i];
 75     fft(A,-1);
 76     for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);
 77 }
 78 void getrt(int u,int fa){
 79     mx[u]=0;
 80     siz[u]=1;
 81     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
 82         int v=a[tmp].v;
 83         if(!used[v]&&v!=fa){
 84             getrt(v,u);
 85             siz[u]+=siz[v];
 86             mx[u]=max(mx[u],siz[v]);
 87         }
 88     }
 89     mx[u]=max(mx[u],S-mx[u]);
 90     if(mx[u]<mx[rt])rt=u;
 91 }
 92 void getdep(int u,int fa,int dpt){
 93     mxd=max(mxd,dpt);
 94     dps[dpt]++;
 95     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
 96         int v=a[tmp].v;
 97         if(!used[v]&&v!=fa){
 98             getdep(v,u,dpt+1);
 99         }
100     }
101 }
102 void divide(int u){
103     used[u]=true;
104     num[0]=1;
105     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
106         int v=a[tmp].v;
107         if(!used[v]){
108             getdep(v,u,1);
109             for(int i=1;i<=mxd;i++){
110                 num[i]+=dps[i];
111                 tp[i]=dps[i];
112                 dps[i]=0;
113             }
114             cnt=max(cnt,mxd);
115             mul(tp,tp,tp,mxd);
116             for(int i=1;i<=mxd*2;i++){
117                 anss[i]-=tp[i];
118                 tp[i]=0;
119             }
120             mxd=0;
121         }
122     }
123     for(int i=0;i<=cnt;i++){
124         tp[i]=num[i];
125         num[i]=0;
126     }
127     mul(tp,tp,tp,cnt);
128     for(int i=0;i<=cnt*2;i++){
129         anss[i]+=tp[i];
130         tp[i]=0;
131     }
132     cnt=0;
133     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
134         int v=a[tmp].v;
135         if(!used[v]){
136             S=siz[v];
137             rt=0;
138             getrt(v,0);
139             divide(rt);
140         }
141     }
142 }
143 int main(){
144     memset(head,-1,sizeof(head));
145     _();
146     scanf("%d",&n);
147     jc[0]=inv[0]=inv[1]=1;
148     for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
149     for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod;
150     for(int i=1;i<n;i++){
151         scanf("%d%d",&u,&v);
152         add(u,v);
153         add(v,u);
154     }
155     S=n;
156     mx[rt=0]=6666666;
157     getrt(1,-1);
158     divide(rt);
159     ans=n;
160     for(int i=1;i<=n;i++){
161         ans=(ans+(ll)anss[i]*inv[i+1]%mod)%mod;
162     }
163     printf("%lld",(ll)ans*jc[n]%mod);
164     return 0;
165 }
posted @ 2018-12-18 21:03  DCDCBigBig  阅读(257)  评论(0编辑  收藏  举报