POJ3417 Network (树上差分)
Yixght is a manager of the company called SzqNetwork(SN). Now she's very worried because she has just received a bad news which denotes that DxtNetwork(DN), the SN's business rival, intents to attack the network of SN. More unfortunately, the original network of SN is so weak that we can just treat it as a tree. Formally, there are N nodes in SN's network, N-1 bidirectional channels to connect the nodes, and there always exists a route from any node to another. In order to protect the network from the attack, Yixght builds M new bidirectional channels between some of the nodes.
As the DN's best hacker, you can exactly destory two channels, one in the original network and the other among the M new channels. Now your higher-up wants to know how many ways you can divide the network of SN into at least two parts.
Input
The first line of the input file contains two integers: N (1 ≤ N ≤ 100 000), M (1 ≤ M ≤ 100 000) — the number of the nodes and the number of the new channels.
Following N-1 lines represent the channels in the original network of SN, each pair (a,b) denote that there is a channel between node a and node b.
Following M lines represent the new channels in the network, each pair (a,b) denote that a new channel between node a and node b is added to the network of SN.
Output
Output a single integer — the number of ways to divide the network into at least two parts.
Sample Input
4 1 1 2 2 3 1 4 3 4
Sample Output
3
这题样例太假了,一开始忘了调用bfs都能把样例过了...
第一道树上差分,就跟着蓝书写了。首先注意到这些主要边构成了一棵树的结构,再往上加附加边的时候一定会有环形成(类比化学里的不饱和度2333),这样的话,根据附加边的多少以及位置,对于每一条主要边可以分为三种情况:
对于第一种情况,砍掉主要边以后已经能把原图分为两部分了,所以在剩下的任意一条附加边里选一条砍掉即可;对于第二种情况,只能砍掉这一条与当前主要边成环的附加边;对于第三种情况砍掉任何附加边都不能分为两半。那么对于每条附加边而言,它对于形成的环里的任意一条主要边都有影响,所以可以统计每条主要边被覆盖了多少次。朴素的做法肯定不行,所以可以借鉴差分的思想。这个题用到就是边的树上差分(需要用到LCA)。
在这直接放一篇洛谷日报,讲的比较好QWQhttps://rpdreamer.blog.luogu.org/ci-fen-and-shu-shang-ci-fen 注意一点,先dfs,回溯时自叶子节点往树根节点更新。
#include <iostream> #include <cstdio> #include <cstring> #include <queue> #include <cmath> #include <algorithm> using namespace std; const int SIZE=100005; int f[SIZE][22],d[SIZE],vis[SIZE],lg[SIZE]; int ver[SIZE*2],Next[SIZE*2],head[SIZE]; int diff[SIZE]; int n,m,tot=0; queue<int>q; int ans=0; void add(int x,int y) { ver[++tot]=y,Next[tot]=head[x],head[x]=tot; } void bfs() { q.push(1),d[1]=1; while(q.size()) { int x=q.front();q.pop(); int i; for(i=head[x];i;i=Next[i]) { int y=ver[i]; if(d[y])continue; d[y]=d[x]+1; f[y][0]=x; int j; for(j=1;j<=lg[d[x]];j++) { f[y][j]=f[f[y][j-1]][j-1]; } q.push(y); } } } int lca(int x,int y) { if(d[x]<d[y])swap(x,y); while(d[x]>d[y]) { x=f[x][lg[d[x]-d[y]]-1]; } if(x==y)return x; int i; for(i=lg[d[x]]-1;i>=0;i--) { if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i]; } return f[x][0]; } void dfs(int x,int pre) { int i; for(i=head[x];i;i=Next[i]) { int y=ver[i]; if(y==pre)continue; dfs(y,x);//此时y已经被更新过了 if(diff[y]==0)ans+=m; else if(diff[y]==1)ans++; diff[x]+=diff[y]; } } int main() { cin>>n>>m; int i; for(i=1;i<=n;i++)lg[i]=lg[i-1]+(1<<lg[i-1]==i); memset(diff,0,sizeof(diff)); for(i=1;i<=n-1;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y); add(y,x); } bfs();//千万别忘了加 for(i=1;i<=m;i++) { int x,y; scanf("%d%d",&x,&y); int anc=lca(x,y); diff[x]++,diff[y]++,diff[anc]-=2; } dfs(1,0); cout<<ans; return 0; }