[loj3225]Podatki drogowe

关于距离,使用线段树存储,并维护哈希值以支持比较

建立点分树,并对每一个节点维护(点分树)子树内所有点到其的距离(对应的线段树)

需要将这些线段树(在原树的结构上)可持久化,进而时空复杂度均为$o(n\log^{2}n)$

将这$o(n\log n)$个距离分为$o(n)$组(允许重复),每一组距离两两相加,贡献系数$\in \{1,-1\}$

(参考点分治的容斥,同时每组距离个数和与总距离个数同级)

对于给定的$mid$,在每一组内使用双指针,即可$o(n\log^{2}n)$求出$dis(u,v)\le mid$的点对数

在此基础上,对所有可行的线段树对二分,可以随机选取$mid$,期望二分次数为$o(\log n)$

另外,为了避免大量权值相同,可以加入随机扰动

最终,总复杂度为$o(n\log^{3}n)$,可以通过

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 25005
  4 #define M 10000005
  5 #define base 47
  6 #define mod 1000000007
  7 #define ll long long
  8 #define ull unsigned ll
  9 #define mid (l+r>>1)
 10 #define To edge[i].to
 11 struct Edge{
 12     int nex,to,len;
 13 }edge[N<<1];
 14 vector<int>v[N<<1],L[N<<1],R[N<<1],now[N<<1];
 15 int n,E,k,rt,x,y,z,val[N],head[N],vis[N],sz[N];
 16 int m,V,Rt[N],ls[M],rs[M],sum[M],P[N<<1];
 17 ull mi[N],f[M];
 18 mt19937 rnd(time(0));
 19 struct Node{
 20     int a,b;
 21     ull get_f(){
 22         return f[a]+f[b];
 23     }
 24     Node get_ls(){
 25         return Node{ls[a],ls[b]};
 26     }
 27     Node get_rs(){
 28         return Node{rs[a],rs[b]};
 29     }
 30 };
 31 int New(int k){
 32     ls[++V]=ls[k],rs[V]=rs[k],sum[V]=sum[k],f[V]=f[k];
 33     return V;
 34 }
 35 void update(int &k,int l,int r,int x){
 36     k=New(k);
 37     if (l==r){
 38         f[k]++,sum[k]=(sum[k]+val[x])%mod;
 39         return;
 40     }
 41     if (x<=mid)update(ls[k],l,mid,x);
 42     else update(rs[k],mid+1,r,x);
 43     sum[k]=(sum[ls[k]]+sum[rs[k]])%mod;
 44     f[k]=mi[r-mid]*f[ls[k]]+f[rs[k]];
 45 }
 46 int Cmp(Node a,Node b){
 47     if (a.get_f()==b.get_f()){
 48         if (a.a+a.b==b.a+b.b)return 0;
 49         return 1-((a.a+a.b<b.a+b.b)<<1);
 50     }
 51     int l=1,r=n;
 52     while (l<r){
 53         Node aa=a.get_rs(),bb=b.get_rs();
 54         if (aa.get_f()!=bb.get_f())a=aa,b=bb,l=mid+1;
 55         else a=a.get_ls(),b=b.get_ls(),r=mid;
 56     }
 57     return 1-((a.get_f()<b.get_f())<<1);
 58 }
 59 void add(int x,int y,int z){
 60     edge[E]=Edge{head[x],y,z};
 61     head[x]=E++;
 62 }
 63 void get_sz(int k,int fa){
 64     sz[k]=1;
 65     for(int i=head[k];i!=-1;i=edge[i].nex)
 66         if ((!vis[To])&&(To!=fa))get_sz(To,k),sz[k]+=sz[To];
 67 }
 68 void get_rt(int k,int fa,int s){
 69     int mx=s-sz[k];
 70     for(int i=head[k];i!=-1;i=edge[i].nex)
 71         if ((!vis[To])&&(To!=fa))get_rt(To,k,s),mx=max(mx,sz[To]);
 72     if (mx<=(s>>1))rt=k;
 73 }
 74 void get_seg(int k,int fa){
 75     if (!fa)Rt[k]=New(0);
 76     for(int i=head[k];i!=-1;i=edge[i].nex)
 77         if ((!vis[To])&&(To!=fa)){
 78             Rt[To]=Rt[k],update(Rt[To],1,n,edge[i].len);
 79             get_seg(To,k);
 80         }
 81 }
 82 void get_group(int k,int fa,int p){
 83     if (!fa)P[++m]=p;
 84     v[m].push_back(Rt[k]);
 85     for(int i=head[k];i!=-1;i=edge[i].nex)
 86         if ((!vis[To])&&(To!=fa))get_group(To,k,p);
 87 }
 88 void dfs(int k){
 89     get_sz(k,0),get_rt(k,0,sz[k]);
 90     vis[rt]=1,get_seg(rt,0),get_group(rt,0,1);
 91     for(int i=head[rt];i!=-1;i=edge[i].nex)
 92         if (!vis[To])get_group(To,0,-1),dfs(To);
 93 }
 94 bool cmp(int x,int y){
 95     return Cmp(Node{x,0},Node{y,0})<0;
 96 }
 97 int query(Node s){
 98     int ans=0;
 99     for(int i=1;i<=m;i++){
100         z=v[i].size();
101         for(int j=z-1,k=0;j>=0;j--){
102             while ((k<z)&&(Cmp(s,Node{v[i][j],v[i][k]})>0))k++;
103             now[i][j]=k,ans+=k*P[i];
104         }
105     }
106     return ans;
107 }
108 int main(){
109     scanf("%d%d",&n,&k);
110     k=(k<<1)+n,val[0]=mi[0]=1;
111     for(int i=1;i<N;i++){
112         val[i]=(ll)val[i-1]*n%mod;
113         mi[i]=mi[i-1]*base;
114     }
115     memset(head,-1,sizeof(head));
116     for(int i=1;i<n;i++){
117         scanf("%d%d%d",&x,&y,&z);
118         add(x,y,z),add(y,x,z);
119     }
120     dfs(1);
121     for(int i=1;i<=m;i++){
122         z=v[i].size(),sort(v[i].begin(),v[i].end(),cmp);
123         L[i].resize(z,0),R[i].resize(z,z),now[i].resize(z);
124     }
125     while (1){
126         ll pos=0,Sum=0;
127         for(int i=1;i<=m;i++)
128             for(int j=0;j<v[i].size();j++)Sum+=R[i][j]-L[i][j];
129         x=y=0,pos=(rnd()%Sum+Sum)%Sum;
130         for(int i=1;i<=m;i++){
131             for(int j=0;j<v[i].size();j++)
132                 if (R[i][j]-L[i][j]<=pos)pos-=R[i][j]-L[i][j];
133                 else{
134                     x=v[i][j],y=v[i][L[i][j]+pos];
135                     break;
136                 }
137             if ((x)&&(y))break;
138         }
139         bool flag=0;
140         for(int i=1;i<=m;i++){
141             for(int j=0;j<v[i].size();j++)
142                 if (L[i][j]<R[i][j]){
143                     if (Cmp(Node{x,y},Node{v[i][j],v[i][L[i][j]]}))flag=1;
144                     if (Cmp(Node{x,y},Node{v[i][j],v[i][R[i][j]-1]}))flag=1;
145                     if (flag)break;
146                 }
147             if (flag)break;
148         }
149         if (!flag)break;
150         if (query(Node{x,y})<k){
151             for(int i=1;i<=m;i++)
152                 for(int j=0;j<v[i].size();j++)
153                     if (L[i][j]<R[i][j])L[i][j]=now[i][j];
154         }
155         else{
156             for(int i=1;i<=m;i++)
157                 for(int j=0;j<v[i].size();j++)
158                     if (L[i][j]<R[i][j])R[i][j]=now[i][j];
159         }
160     }
161     printf("%d\n",(sum[x]+sum[y])%mod);
162     return 0;
163 }
View Code

 

posted @ 2022-06-22 15:01  PYWBKTDA  阅读(52)  评论(0编辑  收藏  举报