【BZOJ3451】Tyvj1953 Normal - 点分治+FFT
题目来源:NOI2019模拟测试赛(七)
非原题面,题意有略微区别
题意:
吐槽:
心态崩了。
好不容易场上想出一题正解,写了三个小时结果写了个假的点分治,卡成$O(n^2)$
我退役吧。
题解:
原题是求随机树分治的期望深度和,题意相同。
对于一个点$x$,考虑点$y$是否能作为它在点分树上的祖先节点,显然当且仅当$y$在$x$到$y$的路径中第一个被选为分治中心时会对$x$产生1的贡献;
由于路径上所有点被选到的概率都是相等的,所以此时的期望就是$\frac{1}{dis(x,y)}$;
那么总的期望就是$\sum\limits_{x=1}^{n}\sum\limits_{y=1}^{n}\frac{1}{dis(x,y)}$;
在这里写个暴力即可爆踩我的假点分治;
考虑统计每种长度的路径条数,可以用点分治做,并且在点分树里合并时子树的期望是一个卷积的形式,因此可以用FFT来加速;
于是我就快乐的写了个点分治+FFT,获得了60分的好成绩;
为什么?参考这篇博客的证明,我最初的写法就是其中的第一种写法,搜完一个子树就和已经搜过的合并,这样做的话FFT的长度会是$子树中最大深度\times 根节点儿子个数=O(n^2)$的,正确的写法应该搜完再一起合并,或者像里面说的第二种方法一样直接搜当前子树,更新答案然后搜重心的每个儿子的子树,减去不合法的路径,这样子FFT的长度才是$O(n)$的。
代码:
假点分治(60pts):
1 #include<algorithm>
2 #include<iostream>
3 #include<cstring>
4 #include<cstdio>
5 #include<cmath>
6 #include<queue>
7 #define inf 2147483647
8 #define eps 1e-9
9 #define mod 1000000007
10 using namespace std;
11 typedef long long ll;
12 typedef double db;
13 const db pi=acos(-1.0);
14
15 struct edge{
16 int v,next;
17 }a[200001];
18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],s[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001];
19 bool used[100001];
20 struct cp{
21 db a,b;
22 cp(){}
23 cp(db _a,db _b){
24 a=_a,b=_b;
25 }
26 friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}
27 friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}
28 friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
29 friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}
30 friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}
31 }A[200001],B[200001],W[200001][2];
32 void _(){
33 for(int i=1;i<=(1<<17);i<<=1){
34 W[i][0]=cp(cos(pi/i),sin(pi/i));
35 W[i][1]=cp(cos(pi/i),-sin(pi/i));
36 }
37 }
38 void fft(cp *s,int op){
39 for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
40 for(int i=1;i<bit;i<<=1){
41 //cp w(cos(pi/i),op*sin(pi/i));
42 cp w=W[i][op==-1];
43 for(int p=i<<1,j=0;j<bit;j+=p){
44 cp wk(1,0);
45 for(int k=j;k<i+j;k++,wk=wk*w){
46 cp x=s[k],y=wk*s[k+i];
47 s[k]=x+y;
48 s[k+i]=x-y;
49 }
50 }
51 }
52 if(op==-1){
53 for(int i=0;i<bit;i++){
54 s[i]=s[i]/(db)bit;
55 }
56 }
57 }
58 void add(int u,int v){
59 a[++tot].v=v;
60 a[tot].next=head[u];
61 head[u]=tot;
62 }
63 void mul(int *ret,int *a,int *b,int n){
64 for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;
65 for(int i=1;i<=bit;i++){
66 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));
67 }
68 for(int i=0;i<bit;i++){
69 A[i]=cp((db)a[i],0);
70 B[i]=cp(0,0);
71 }
72 for(int i=1;i<=cnt;i++){
73 a[b[i]]++;
74 B[b[i]].a+=1;
75 }
76 fft(A,1);
77 fft(B,1);
78 for(int i=0;i<bit;i++)A[i]=A[i]*B[i];
79 fft(A,-1);
80 for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);
81 }
82 void getrt(int u,int fa){
83 mx[u]=0;
84 siz[u]=1;
85 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
86 int v=a[tmp].v;
87 if(!used[v]&&v!=fa){
88 getrt(v,u);
89 siz[u]+=siz[v];
90 mx[u]=max(mx[u],siz[v]);
91 }
92 }
93 mx[u]=max(mx[u],S-mx[u]);
94 if(mx[u]<mx[rt])rt=u;
95 }
96 void getdep(int u,int fa,int dpt){
97 mxd=max(mxd,dpt);
98 s[++cnt]=dpt;
99 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
100 int v=a[tmp].v;
101 if(!used[v]&&v!=fa){
102 getdep(v,u,dpt+1);
103 }
104 }
105 }
106 void divide(int u){
107 used[u]=true;
108 mxd=0;
109 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
110 int v=a[tmp].v;
111 if(!used[v]){
112 cnt=0;
113 getdep(v,u,1);
114 mul(tp,num,s,mxd);
115 for(int i=0;i<bit;i++)anss[i]+=tp[i];
116 }
117 }
118 for(int i=1;i<=mxd;i++){
119 anss[i]+=num[i];
120 num[i]=0;
121 }
122 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
123 int v=a[tmp].v;
124 if(!used[v]){
125 S=siz[v];
126 rt=0;
127 getrt(v,0);
128 divide(rt);
129 }
130 }
131 }
132 int main(){
133 memset(head,-1,sizeof(head));
134 _();
135 scanf("%d",&n);
136 jc[0]=inv[0]=inv[1]=1;
137 for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
138 for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod;
139 for(int i=1;i<n;i++){
140 scanf("%d%d",&u,&v);
141 add(u,v);
142 add(v,u);
143 }
144 S=n;
145 mx[rt=0]=6666666;
146 getrt(1,-1);
147 divide(rt);
148 ans=n;
149 for(int i=1;i<=n;i++){
150 ans=(ans+(ll)anss[i]*inv[i+1]*2%mod)%mod;
151 }
152 printf("%lld",(ll)ans*jc[n]%mod);
153 return 0;
154 }
AC代码(100pts):
1 #include<algorithm>
2 #include<iostream>
3 #include<cstring>
4 #include<cstdio>
5 #include<cmath>
6 #include<queue>
7 #define inf 2147483647
8 #define eps 1e-9
9 #define mod 1000000007
10 using namespace std;
11 typedef long long ll;
12 typedef double db;
13 const db pi=acos(-1.0);
14
15 struct edge{
16 int v,next;
17 }a[200001];
18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001],dps[100001];
19 bool used[100001];
20 struct cp{
21 db a,b;
22 cp(){}
23 cp(db _a,db _b){
24 a=_a,b=_b;
25 }
26 friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}
27 friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}
28 friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
29 friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}
30 friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}
31 }A[200001],B[200001],W[200001][2];
32 void _(){
33 for(int i=1;i<=(1<<17);i<<=1){
34 W[i][0]=cp(cos(pi/i),sin(pi/i));
35 W[i][1]=cp(cos(pi/i),-sin(pi/i));
36 }
37 }
38 void fft(cp *s,int op){
39 for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
40 for(int i=1;i<bit;i<<=1){
41 //cp w(cos(pi/i),op*sin(pi/i));
42 cp w=W[i][op==-1];
43 for(int p=i<<1,j=0;j<bit;j+=p){
44 cp wk(1,0);
45 for(int k=j;k<i+j;k++,wk=wk*w){
46 cp x=s[k],y=wk*s[k+i];
47 s[k]=x+y;
48 s[k+i]=x-y;
49 }
50 }
51 }
52 if(op==-1){
53 for(int i=0;i<bit;i++){
54 s[i]=s[i]/(db)bit;
55 }
56 }
57 }
58 void add(int u,int v){
59 a[++tot].v=v;
60 a[tot].next=head[u];
61 head[u]=tot;
62 }
63 void mul(int *ret,int *a,int *b,int n){
64 for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;
65 for(int i=1;i<bit;i++){
66 rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));
67 }
68 for(int i=0;i<bit;i++){
69 A[i]=cp((db)a[i],0);
70 B[i]=cp((db)b[i],0);
71 }
72 fft(A,1);
73 fft(B,1);
74 for(int i=0;i<bit;i++)A[i]=A[i]*B[i];
75 fft(A,-1);
76 for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);
77 }
78 void getrt(int u,int fa){
79 mx[u]=0;
80 siz[u]=1;
81 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
82 int v=a[tmp].v;
83 if(!used[v]&&v!=fa){
84 getrt(v,u);
85 siz[u]+=siz[v];
86 mx[u]=max(mx[u],siz[v]);
87 }
88 }
89 mx[u]=max(mx[u],S-mx[u]);
90 if(mx[u]<mx[rt])rt=u;
91 }
92 void getdep(int u,int fa,int dpt){
93 mxd=max(mxd,dpt);
94 dps[dpt]++;
95 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
96 int v=a[tmp].v;
97 if(!used[v]&&v!=fa){
98 getdep(v,u,dpt+1);
99 }
100 }
101 }
102 void divide(int u){
103 used[u]=true;
104 num[0]=1;
105 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
106 int v=a[tmp].v;
107 if(!used[v]){
108 getdep(v,u,1);
109 for(int i=1;i<=mxd;i++){
110 num[i]+=dps[i];
111 tp[i]=dps[i];
112 dps[i]=0;
113 }
114 cnt=max(cnt,mxd);
115 mul(tp,tp,tp,mxd);
116 for(int i=1;i<=mxd*2;i++){
117 anss[i]-=tp[i];
118 tp[i]=0;
119 }
120 mxd=0;
121 }
122 }
123 for(int i=0;i<=cnt;i++){
124 tp[i]=num[i];
125 num[i]=0;
126 }
127 mul(tp,tp,tp,cnt);
128 for(int i=0;i<=cnt*2;i++){
129 anss[i]+=tp[i];
130 tp[i]=0;
131 }
132 cnt=0;
133 for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
134 int v=a[tmp].v;
135 if(!used[v]){
136 S=siz[v];
137 rt=0;
138 getrt(v,0);
139 divide(rt);
140 }
141 }
142 }
143 int main(){
144 memset(head,-1,sizeof(head));
145 _();
146 scanf("%d",&n);
147 jc[0]=inv[0]=inv[1]=1;
148 for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
149 for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod;
150 for(int i=1;i<n;i++){
151 scanf("%d%d",&u,&v);
152 add(u,v);
153 add(v,u);
154 }
155 S=n;
156 mx[rt=0]=6666666;
157 getrt(1,-1);
158 divide(rt);
159 ans=n;
160 for(int i=1;i<=n;i++){
161 ans=(ans+(ll)anss[i]*inv[i+1]%mod)%mod;
162 }
163 printf("%lld",(ll)ans*jc[n]%mod);
164 return 0;
165 }