树上差分:闇の連鎖 题解
题目描述
传说中的暗之连锁被人们称为 Dark。
Dark 是人类内心的黑暗的产物,古今中外的勇者们都试图打倒它。
经过研究,你发现 Dark 呈现无向图的结构,图中有 N 个节点和两类边,一类边被称为主要边,而另一类被称为附加边。
Dark 有 N – 1 条主要边,并且 Dark 的任意两个节点之间都存在一条只由主要边构成的路径。
另外,Dark 还有 M 条附加边。
你的任务是把 Dark 斩为不连通的两部分。
一开始 Dark 的附加边都处于无敌状态,你只能选择一条主要边切断。
一旦你切断了一条主要边,Dark 就会进入防御模式,主要边会变为无敌的而附加边可以被切断。
但是你的能力只能再切断 Dark 的一条附加边。
现在你想要知道,一共有多少种方案可以击败 Dark。
注意,就算你第一步切断主要边之后就已经把 Dark 斩为两截,你也需要切断一条附加边才算击败了 Dark。
输入格式
第一行包含两个整数 N 和 M。
之后 N – 1 行,每行包括两个整数 A 和 B,表示 A 和 B 之间有一条主要边。
之后 M 行以同样的格式给出附加边。
输出格式
输出一个整数表示答案。
数据范围
N≤100000,M≤200000,数据保证答案不超过231−1
样例
输入样例:
4 1
1 2
2 3
1 4
3 4
输出样例:
3
在没有附加边的情况下,我们发现这是一颗树,那么再添加条附加边(x,y)后,会造成(x,y)之间产生一个环
如果我们第一步截断了(x,y)之间的一条路,那么我们第二次只能截掉(x,y)之间的附加边,才能使其不连通;
我们将每条附加边(x,y)称为将(x,y)之间的路径覆盖了一遍;
因此我们只需要统计出每条主要边被覆盖了几次即可;
对于只被覆盖一次的边,第二次我们只能切断(x,y)边,方法唯一;
如果我们第一步切断了被覆盖0次的边,那么我们已经将其分为两部分,那么第二部只需要在m条附加边中任选一条即可,如果第一步截到被覆盖超过两次的边,将无法将其分为两部分;
运用乘法原理,我们累加答案;
那么怎么标记我们的边(x,y)被覆盖了几次呢,那么我们可以使用树上差分,是解决此类问题的经典套路;
我们想,对于一条边(x,y),我们添加一条边;
那么只会对x到lca(x,y)到y上的边产生影响,对于(x,y)我们将x节点的权值+1,y节点的权值+1,另lca(x,y)的权值-2,画图很好理解,那么我们进行一遍dfs求出每个节点权值,那么这个值就是节点父节点连边被覆盖的次数,按上述方法累加答案即可;
时间复杂度分析:O(N+M)
#include<bits/stdc++.h> using namespace std; #define N 500100 int x,y,n,m,tot,ans,p[N],f1[N],lin[N],f[N][25],d[N],vis[N]; template<typename T>inline void read(T &x) { x=0;T f=1,ch=getchar(); while(!isdigit(ch)) {if(ch=='-') f=-1; ch=getchar();} while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48); ch=getchar();} x*=f; } struct gg { int x,y,next; }a[N<<1]; inline void add(int x,int y) { a[++tot].x=x; a[tot].y=y; a[tot].next=lin[x]; lin[x]=tot; } void bfs(int x) { queue<int> q; q.push(1); d[1]=1; while(q.size()) { int x=q.front(); q.pop(); for(int i=lin[x];i;i=a[i].next) { int y=a[i].y; if(d[y]) continue; d[y]=d[x]+1; f[y][0]=x; for(int j=1;j<=23;j++){ f[y][j]=f[f[y][j-1]][j-1]; } q.push(y); } } } inline int lca(int x,int y) { if(d[x]>d[y]) swap(x,y); for(int i=23;i>=0;i--) { if(d[f[y][i]]>=d[x]) y=f[y][i]; } if(x==y) return x; for(int i=23;i>=0;i--) { if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; } return f[x][0]; } inline void dfs(int x) { vis[x]=1; for(int i=lin[x];i;i=a[i].next) { int y=a[i].y; if(vis[y]) continue; dfs(y); p[x]+=p[y]; } } int main() { read(n); read(m); int x,y; for(int i=1;i<n;i++) { read(x);read(y); add(x,y); add(y,x); } bfs(1); for(int i=1;i<=m;i++) { read(x);read(y); p[x]++,p[y]++; p[lca(x,y)]-=2; } dfs(1); for(int i=2;i<=n;i++) { if(!p[i]) ans+=m; if(p[i]==1) ans+=1; } cout<<ans<<endl; return 0; }