[loj3225]Podatki drogowe
关于距离,使用线段树存储,并维护哈希值以支持比较
建立点分树,并对每一个节点维护(点分树)子树内所有点到其的距离(对应的线段树)
需要将这些线段树(在原树的结构上)可持久化,进而时空复杂度均为$o(n\log^{2}n)$
将这$o(n\log n)$个距离分为$o(n)$组(允许重复),每一组距离两两相加,贡献系数$\in \{1,-1\}$
(参考点分治的容斥,同时每组距离个数和与总距离个数同级)
对于给定的$mid$,在每一组内使用双指针,即可$o(n\log^{2}n)$求出$dis(u,v)\le mid$的点对数
在此基础上,对所有可行的线段树对二分,可以随机选取$mid$,期望二分次数为$o(\log n)$
另外,为了避免大量权值相同,可以加入随机扰动
最终,总复杂度为$o(n\log^{3}n)$,可以通过
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 25005 4 #define M 10000005 5 #define base 47 6 #define mod 1000000007 7 #define ll long long 8 #define ull unsigned ll 9 #define mid (l+r>>1) 10 #define To edge[i].to 11 struct Edge{ 12 int nex,to,len; 13 }edge[N<<1]; 14 vector<int>v[N<<1],L[N<<1],R[N<<1],now[N<<1]; 15 int n,E,k,rt,x,y,z,val[N],head[N],vis[N],sz[N]; 16 int m,V,Rt[N],ls[M],rs[M],sum[M],P[N<<1]; 17 ull mi[N],f[M]; 18 mt19937 rnd(time(0)); 19 struct Node{ 20 int a,b; 21 ull get_f(){ 22 return f[a]+f[b]; 23 } 24 Node get_ls(){ 25 return Node{ls[a],ls[b]}; 26 } 27 Node get_rs(){ 28 return Node{rs[a],rs[b]}; 29 } 30 }; 31 int New(int k){ 32 ls[++V]=ls[k],rs[V]=rs[k],sum[V]=sum[k],f[V]=f[k]; 33 return V; 34 } 35 void update(int &k,int l,int r,int x){ 36 k=New(k); 37 if (l==r){ 38 f[k]++,sum[k]=(sum[k]+val[x])%mod; 39 return; 40 } 41 if (x<=mid)update(ls[k],l,mid,x); 42 else update(rs[k],mid+1,r,x); 43 sum[k]=(sum[ls[k]]+sum[rs[k]])%mod; 44 f[k]=mi[r-mid]*f[ls[k]]+f[rs[k]]; 45 } 46 int Cmp(Node a,Node b){ 47 if (a.get_f()==b.get_f()){ 48 if (a.a+a.b==b.a+b.b)return 0; 49 return 1-((a.a+a.b<b.a+b.b)<<1); 50 } 51 int l=1,r=n; 52 while (l<r){ 53 Node aa=a.get_rs(),bb=b.get_rs(); 54 if (aa.get_f()!=bb.get_f())a=aa,b=bb,l=mid+1; 55 else a=a.get_ls(),b=b.get_ls(),r=mid; 56 } 57 return 1-((a.get_f()<b.get_f())<<1); 58 } 59 void add(int x,int y,int z){ 60 edge[E]=Edge{head[x],y,z}; 61 head[x]=E++; 62 } 63 void get_sz(int k,int fa){ 64 sz[k]=1; 65 for(int i=head[k];i!=-1;i=edge[i].nex) 66 if ((!vis[To])&&(To!=fa))get_sz(To,k),sz[k]+=sz[To]; 67 } 68 void get_rt(int k,int fa,int s){ 69 int mx=s-sz[k]; 70 for(int i=head[k];i!=-1;i=edge[i].nex) 71 if ((!vis[To])&&(To!=fa))get_rt(To,k,s),mx=max(mx,sz[To]); 72 if (mx<=(s>>1))rt=k; 73 } 74 void get_seg(int k,int fa){ 75 if (!fa)Rt[k]=New(0); 76 for(int i=head[k];i!=-1;i=edge[i].nex) 77 if ((!vis[To])&&(To!=fa)){ 78 Rt[To]=Rt[k],update(Rt[To],1,n,edge[i].len); 79 get_seg(To,k); 80 } 81 } 82 void get_group(int k,int fa,int p){ 83 if (!fa)P[++m]=p; 84 v[m].push_back(Rt[k]); 85 for(int i=head[k];i!=-1;i=edge[i].nex) 86 if ((!vis[To])&&(To!=fa))get_group(To,k,p); 87 } 88 void dfs(int k){ 89 get_sz(k,0),get_rt(k,0,sz[k]); 90 vis[rt]=1,get_seg(rt,0),get_group(rt,0,1); 91 for(int i=head[rt];i!=-1;i=edge[i].nex) 92 if (!vis[To])get_group(To,0,-1),dfs(To); 93 } 94 bool cmp(int x,int y){ 95 return Cmp(Node{x,0},Node{y,0})<0; 96 } 97 int query(Node s){ 98 int ans=0; 99 for(int i=1;i<=m;i++){ 100 z=v[i].size(); 101 for(int j=z-1,k=0;j>=0;j--){ 102 while ((k<z)&&(Cmp(s,Node{v[i][j],v[i][k]})>0))k++; 103 now[i][j]=k,ans+=k*P[i]; 104 } 105 } 106 return ans; 107 } 108 int main(){ 109 scanf("%d%d",&n,&k); 110 k=(k<<1)+n,val[0]=mi[0]=1; 111 for(int i=1;i<N;i++){ 112 val[i]=(ll)val[i-1]*n%mod; 113 mi[i]=mi[i-1]*base; 114 } 115 memset(head,-1,sizeof(head)); 116 for(int i=1;i<n;i++){ 117 scanf("%d%d%d",&x,&y,&z); 118 add(x,y,z),add(y,x,z); 119 } 120 dfs(1); 121 for(int i=1;i<=m;i++){ 122 z=v[i].size(),sort(v[i].begin(),v[i].end(),cmp); 123 L[i].resize(z,0),R[i].resize(z,z),now[i].resize(z); 124 } 125 while (1){ 126 ll pos=0,Sum=0; 127 for(int i=1;i<=m;i++) 128 for(int j=0;j<v[i].size();j++)Sum+=R[i][j]-L[i][j]; 129 x=y=0,pos=(rnd()%Sum+Sum)%Sum; 130 for(int i=1;i<=m;i++){ 131 for(int j=0;j<v[i].size();j++) 132 if (R[i][j]-L[i][j]<=pos)pos-=R[i][j]-L[i][j]; 133 else{ 134 x=v[i][j],y=v[i][L[i][j]+pos]; 135 break; 136 } 137 if ((x)&&(y))break; 138 } 139 bool flag=0; 140 for(int i=1;i<=m;i++){ 141 for(int j=0;j<v[i].size();j++) 142 if (L[i][j]<R[i][j]){ 143 if (Cmp(Node{x,y},Node{v[i][j],v[i][L[i][j]]}))flag=1; 144 if (Cmp(Node{x,y},Node{v[i][j],v[i][R[i][j]-1]}))flag=1; 145 if (flag)break; 146 } 147 if (flag)break; 148 } 149 if (!flag)break; 150 if (query(Node{x,y})<k){ 151 for(int i=1;i<=m;i++) 152 for(int j=0;j<v[i].size();j++) 153 if (L[i][j]<R[i][j])L[i][j]=now[i][j]; 154 } 155 else{ 156 for(int i=1;i<=m;i++) 157 for(int j=0;j<v[i].size();j++) 158 if (L[i][j]<R[i][j])R[i][j]=now[i][j]; 159 } 160 } 161 printf("%d\n",(sum[x]+sum[y])%mod); 162 return 0; 163 }