bzoj 4326: NOIP2015 运输计划(树链剖分+二分+差分 or 树上差分+LCA+二分)
解法一:树链剖分+二分+差分
树链剖分快速求解任意两点间的路径的权值和;
然后,二分答案;
此题的难点是如何快速求解重合路径?
差分数组可以否???
在此之前先介绍一下相关变量:
1 int fa[maxn]; 2 int siz[maxn];//siz[i]:i子树的节点个数 3 int dep[maxn];//dep[i]:节点i在树中的深度 4 int son[maxn];//son[i]:节点i的重儿子 5 int w[maxn];//w[i]:i节点与其父节点的权值 6 int tid[maxn];//tid[i]:节点i的新编号 7 int top[maxn];//top[i]:节点i所在重链的祖先 8 int s[maxn];//s[i]:新编号中,[1,i]的权值之和,s[i]=w[1]+w[2]+...+w[i];
如何用差分数组求解重合路径呢?
考虑一下化简得题型,给定 n 个数;
给出 m 个区间 [L,R] ,如何求着 m 个区间公共覆盖的点?
例如 n = 10 , m = 4;
[2,8],[3,7],[4,6],[5,6]
这四个区间得公共覆盖得区间点为 5,6;
定义diff[ ]为差分数组。
对于给出得区间[L,R] , diff[L]++,diff[R+1]--;
最后扫一遍diff[] ,求出diff[]得前缀和sum[];
如果存在点 x 使得 sum[x] = m,那么点 x 就是这 m 个区间公共覆盖的点;
此题同理,求出不满足二分条件得 cnt 个区间[L,R],通过diff[]求出这 cnt 个区间公共覆盖的点;
代码如下:
1 int diff[maxn]; 2 void Update(int u,int v) 3 { 4 int x=top[u]; 5 int y=top[v]; 6 if(x == y)//如果 u,v 在同一条重链上 7 { 8 if(dep[u] > dep[v]) 9 swap(u,v); 10 //节点son[u]到节点v的新编号是连续的 11 diff[tid[son[u]]]++; 12 diff[tid[v]+1]--; 13 return ; 14 } 15 else//如果不在 16 { 17 if(dep[x] > dep[y]) 18 { 19 swap(x,y); 20 swap(u,v); 21 } 22 //节点 top[v]与节点v的新编号是连续的 23 diff[tid[y]]++; 24 diff[tid[v]+1]--; 25 Update(u,fa[y]); 26 } 27 }
AC代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 using namespace std; 6 #define ls(x) (x<<1) 7 #define rs(x) (x<<1|1) 8 #define mem(a,b) memset(a,b,sizeof(a)) 9 const int maxn=3e5+50; 10 11 int n,m; 12 int fa[maxn]; 13 int siz[maxn];//siz[i]:i子树的节点个数 14 int dep[maxn]; 15 int son[maxn];//son[i]:节点i的重儿子 16 int w[maxn];//w[i]:i节点与其父节点的权值 17 int tid[maxn];//tid[i]:节点i的新编号 18 int top[maxn];//top[i]:节点i所在重链的祖先 19 int s[maxn];//s[i]:新编号中,[1,i]的权值之和,s[i]=w[1]+w[2]+...+w[i]; 20 int num; 21 int head[maxn]; 22 struct Edge 23 { 24 int to; 25 int w; 26 int next; 27 }G[2*maxn]; 28 void addEdge(int u,int v,int w) 29 { 30 G[num].to=v; 31 G[num].w=w; 32 G[num].next=head[u]; 33 head[u]=num++; 34 } 35 struct Que 36 { 37 int u,v; 38 int w; 39 }que[maxn]; 40 void DFS1(int u,int f,int d) 41 { 42 fa[u]=f; 43 dep[u]=d; 44 siz[u]=1; 45 for(int i=head[u];~i;i=G[i].next) 46 { 47 int v=G[i].to; 48 if(v == f) 49 continue; 50 51 w[v]=G[i].w; 52 DFS1(v,u,d+1); 53 54 siz[u] += siz[v]; 55 if(son[u] == -1 || siz[v] > siz[son[u]]) 56 son[u]=v; 57 } 58 } 59 void DFS2(int u,int a,int &k) 60 { 61 top[u]=a; 62 tid[u]=++k; 63 s[k]=s[k-1]+w[u]; 64 if(son[u] == -1) 65 return ; 66 DFS2(son[u],a,k); 67 for(int i=head[u];~i;i=G[i].next) 68 { 69 int v=G[i].to; 70 if(v != son[u] && v != fa[u]) 71 DFS2(v,v,k); 72 } 73 } 74 int Find(int u,int v)//求解节点u到节点v的路径权值和 75 { 76 int x=top[u]; 77 int y=top[v]; 78 int ans=0; 79 80 while(x != y) 81 { 82 if(dep[x] > dep[y]) 83 { 84 swap(u,v); 85 swap(x,y); 86 } 87 ans += s[tid[v]]-s[tid[y]-1]; 88 v=fa[y]; 89 y=top[v]; 90 } 91 if(u != v) 92 { 93 if(dep[u] > dep[v]) 94 swap(u,v); 95 ans += s[tid[v]]-s[tid[u]]; 96 } 97 98 return ans; 99 } 100 101 int diff[maxn]; 102 void Update(int u,int v) 103 { 104 int x=top[u]; 105 int y=top[v]; 106 if(x == y)//如果 u,v 在同一条重链上 107 { 108 if(dep[u] > dep[v]) 109 swap(u,v); 110 //节点son[u]到节点v的新编号是连续的 111 diff[tid[son[u]]]++; 112 diff[tid[v]+1]--; 113 return ; 114 } 115 else//如果不在 116 { 117 if(dep[x] > dep[y]) 118 { 119 swap(x,y); 120 swap(u,v); 121 } 122 //节点 top[v]与节点v的新编号是连续的 123 diff[tid[y]]++; 124 diff[tid[v]+1]--; 125 Update(u,fa[y]); 126 } 127 } 128 /** 129 cnt:一共有cnt个权值和 > mid 130 ans1:这cnt个权值和最大的比mid大多少 131 ans2:这cnt个路径中权值最大的公共路径 132 */ 133 bool Check(int mid) 134 { 135 int cnt=0; 136 int ans1=0; 137 for(int i=1;i <= m;++i) 138 { 139 int u=que[i].u; 140 int v=que[i].v; 141 if(que[i].w > mid) 142 { 143 cnt++; 144 ans1=max(ans1,que[i].w-mid); 145 Update(u,v); 146 } 147 } 148 int ans2=0; 149 int tot=0; 150 for(int i=1;i <= n;++i) 151 { 152 tot += diff[i]; 153 diff[i]=0; 154 if(tot == cnt) 155 ans2=max(ans2,s[i]-s[i-1]); 156 } 157 //只有ans2 >= ans1 才能够使最大的权值和小于等于mid 158 return ans2 >= ans1; 159 } 160 int Solve() 161 { 162 DFS1(1,1,1); 163 int k=0; 164 DFS2(1,1,k); 165 166 for(int i=1;i <= m;++i) 167 { 168 int u=que[i].u; 169 int v=que[i].v; 170 que[i].w=Find(u,v); 171 } 172 173 int l=-1,r=300000000+50; 174 while(r-l > 1)//二分答案 175 { 176 int mid=l+((r-l)>>1); 177 if(Check(mid)) 178 r=mid; 179 else 180 l=mid; 181 } 182 return r; 183 } 184 void Init() 185 { 186 num=0; 187 mem(head,-1); 188 mem(diff,0); 189 mem(son,-1); 190 mem(s,0); 191 } 192 int main() 193 { 194 // freopen("C:\\Users\\hyacinthLJP\\Desktop\\in&&out\\BZOJ\\4326.in","r",stdin); 195 while(~scanf("%d%d",&n,&m)) 196 { 197 Init(); 198 for(int i=1;i < n;++i) 199 { 200 int u,v,w; 201 scanf("%d%d%d",&u,&v,&w); 202 addEdge(u,v,w); 203 addEdge(v,u,w); 204 } 205 for(int i=1;i <= m;++i) 206 scanf("%d%d",&que[i].u,&que[i].v); 207 208 printf("%d\n",Solve()); 209 } 210 return 0; 211 }
解法二:树上差分+LCA+二分
差分详解,戳这里👉
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define mem(a,b) memset(a,b,sizeof(a)) 4 const int maxn=3e5+50; 5 6 int n,m; 7 int num; 8 int head[maxn]; 9 struct Edge 10 { 11 int to; 12 int next; 13 int w; 14 }G[maxn<<1]; 15 void addEdge(int u,int v,int w) 16 { 17 G[num]=Edge{v,head[u],w}; 18 head[u]=num++; 19 } 20 int t[maxn];///t[u]:u与其父节点的权值 21 int dis[maxn];///dis[u]:u距根节点1的路径权值和 22 struct LCA 23 { 24 int cnt; 25 int dep[maxn<<1]; 26 int vs[maxn<<1]; 27 int pos[maxn]; 28 int dp[maxn<<1][30]; 29 int lb[maxn<<1]; 30 31 void DFS(int u,int f,int d,int dist) 32 { 33 vs[++cnt]=u; 34 dep[cnt]=d; 35 pos[u]=cnt; 36 dis[u]=dist; 37 for(int i=head[u];~i;i=G[i].next) 38 { 39 int v=G[i].to; 40 int w=G[i].w; 41 if(v == f) 42 continue; 43 t[v]=w; 44 DFS(v,u,d+1,dist+w); 45 46 vs[++cnt]=u; 47 dep[cnt]=d; 48 } 49 } 50 void ST() 51 { 52 lb[0]=-1; 53 for(int i=1;i <= cnt;++i) 54 { 55 dp[i][0]=i; 56 lb[i]=lb[i-1]+((i&(i-1)) == 0 ? 1:0); 57 } 58 for(int k=1;k <= lb[cnt];++k) 59 for(int i=1;i+(1<<k)-1 <= cnt;++i) 60 if(dep[dp[i][k-1]] > dep[dp[i+(1<<(k-1))][k-1]]) 61 dp[i][k]=dp[i+(1<<(k-1))][k-1]; 62 else 63 dp[i][k]=dp[i][k-1]; 64 } 65 void lcaInit() 66 { 67 cnt=0; 68 DFS(1,1,0,0); 69 ST(); 70 } 71 int lca(int u,int v) 72 { 73 u=pos[u]; 74 v=pos[v]; 75 if(u > v) 76 swap(u,v); 77 int k=lb[v-u+1]; 78 if(dep[dp[u][k]] > dep[dp[v-(1<<k)+1][k]]) 79 return vs[dp[v-(1<<k)+1][k]]; 80 else 81 return vs[dp[u][k]]; 82 } 83 void debug() 84 { 85 printf(" i:"); 86 for(int i=1;i <= cnt;++i) 87 printf(" %2d",i); 88 printf("\n"); 89 printf(" vs:"); 90 for(int i=1;i <= cnt;++i) 91 printf(" %2d",vs[i]); 92 printf("\n"); 93 printf("dep:"); 94 for(int i=1;i <= cnt;++i) 95 printf(" %2d",dep[i]); 96 printf("\n"); 97 printf("pos:"); 98 for(int i=1;i <= n;++i) 99 printf(" %2d",pos[i]); 100 printf("\n"); 101 printf("log:"); 102 for(int i=1;i <= cnt;++i) 103 printf(" %2d",lb[i]); 104 printf("\n"); 105 } 106 }_lca; 107 struct Que 108 { 109 int u,v,w; 110 }que[maxn]; 111 int diff[maxn]; 112 void DFS(int u,int f) 113 { 114 for(int i=head[u];~i;i=G[i].next) 115 { 116 int v=G[i].to; 117 if(v == f) 118 continue; 119 DFS(v,u); 120 121 diff[u] += diff[v]; 122 } 123 124 } 125 bool Check(int mid) 126 { 127 mem(diff,0); 128 int cnt=0; 129 int need=0; 130 for(int i=1;i <= m;++i) 131 { 132 if(que[i].w <= mid) 133 continue; 134 int u=que[i].u; 135 int v=que[i].v; 136 int lca=_lca.lca(u,v); 137 diff[u]++; 138 diff[v]++; 139 diff[lca] -= 2; 140 cnt++; 141 need=max(need,que[i].w-mid); 142 } 143 DFS(1,1); 144 145 for(int i=1;i <= n;++i) 146 { 147 if(diff[i] != cnt) 148 continue; 149 if(t[i] >= need) 150 return true; 151 } 152 return false; 153 } 154 int Solve() 155 { 156 _lca.lcaInit(); 157 // _lca.debug(); 158 for(int i=1;i <= m;++i) 159 { 160 int u,v,w; 161 scanf("%d%d",&u,&v); 162 w=dis[u]+dis[v]-2*dis[_lca.lca(u,v)]; 163 que[i]=Que{u,v,w}; 164 // printf("(%d,%d),lca=%d,%d\n",u,v,_lca.lca(u,v),w); 165 } 166 167 int l=-1,r=(int)3e8+50; 168 while(r-l > 1) 169 { 170 int mid=l+((r-l)>>1); 171 if(Check(mid)) 172 r=mid; 173 else 174 l=mid; 175 } 176 return r; 177 } 178 void Init() 179 { 180 num=0; 181 mem(head,-1); 182 } 183 int main() 184 { 185 // freopen("C:\\Users\\hyacinthLJP\\Desktop\\in&&out\\contest","r",stdin); 186 scanf("%d%d",&n,&m); 187 Init(); 188 for(int i=1;i < n;++i) 189 { 190 int u,v,w; 191 scanf("%d%d%d",&u,&v,&w); 192 addEdge(u,v,w); 193 addEdge(v,u,w); 194 } 195 printf("%d\n",Solve()); 196 197 return 0; 198 }