[USACO19JAN]Exercise Route
这题的数据有点水,暴力合并\(set\)好像过了
分析一下这个题的性质,发现我们一条非树边就会形成一个环,而我们要求选择两个非树边,就会形成两个环,要求不走重复的点,就是说我们需要走一个大环,且必须经过这两个小环
显然需要这两个小环有至少一条公共边
发现问题转化成了求有多少对路径,这对路径有公共边
应该可以用数据结构来维护大力讨论的,当然树上差分的正解思路也非常妙
于是我们写个暴力吧!
对于这道题我第一想法是启发式合并\(set\),我们可以利用如同树上差分那样的方式,对于一条路径\((u,v)\),我们在\(u,v\)处打上一个标记,在\(lca(u,v)\)再打上一个删除标记,每次合并\(set\)的时候,往\(set\)里插入的所有标记都表明这个标记代表的路径经过了这条边,这样这条路径就可以和原来\(set\)里的标记任意组合了,就可以统计答案了
但是会和正解树上差分遇到一样的问题,就是如果两条路径的公共边形成的不是一条链,这对路径我们会计算两次
这里采用和正解一样的方法就好了,就是利用一个\(map\)来找到这样的路径有多少对
到这里复杂度还非常正常,是有点大的\(O(nlog^2n)\),但是我们又发现我们好像没有什么办法去利用删除标记
如果每次合并前都要扫一遍删除集合的话,复杂度显然就不对了,菊花树随便卡掉
于是我们启发式删除,看看维护删除集合小还是待合并集合小,那个小就遍历哪一个
还是过不去怎么办,那就把\(set\)换成unordered_set,信仰一下就可以了
代码
#include<tr1/unordered_set>
#include<tr1/unordered_map>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 200005
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define set_it unordered_set<int>::iterator
# define getchar() (S==T&&(T=(S=BB)+fread(BB,1,1<<15,stdin),S==T)?EOF:*S++)
using namespace std::tr1;
char BB[1 << 18],*S=BB,*T=BB;
inline int read() {
int x=0;char c=getchar();while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[maxn<<1];
int head[maxn],deep[maxn],sum[maxn],fa[maxn],top[maxn],son[maxn],f[maxn][19],st[maxn];
int n,m,num,cnt,md,t;
int rt[maxn];
LL ans;
unordered_set<int> s[maxn],del[maxn];
unordered_map<LL,int> ma;
inline void add(int x,int y) {e[++num].v=y;e[num].nxt=head[x];head[x]=num;}
void dfs1(int x) {
sum[x]=1;int maxx=-1;
for(re int i=head[x];i;i=e[i].nxt) {
if(deep[e[i].v]) continue;
deep[e[i].v]=deep[x]+1;fa[e[i].v]=x;f[e[i].v][0]=x;
for(re int j=1;j<=18;j++) f[e[i].v][j]=f[f[e[i].v][j-1]][j-1];
dfs1(e[i].v);
sum[x]+=sum[e[i].v];
if(sum[e[i].v]>maxx) maxx=sum[e[i].v],son[x]=e[i].v;
}
}
void dfs2(int x,int topf) {
top[x]=topf;
if(!son[x]) return;
dfs2(son[x],topf);
for(re int i=head[x];i;i=e[i].nxt) {
if(top[e[i].v]) continue;
dfs2(e[i].v,e[i].v);
}
}
inline int LCA(int x,int y) {
while(top[x]!=top[y]) {
if(deep[top[x]]<deep[top[y]]) std::swap(x,y);
x=fa[top[x]];
}
if(deep[x]<deep[y]) return x;return y;
}
inline void D(int a,int c) {
if(del[c].size()<s[a].size()) {
for(set_it it=del[c].begin();it!=del[c].end();++it) {
int x=*it;
s[a].erase(x);
}
return;
}
t=0;
for(set_it it=s[a].begin();it!=s[a].end();++it)
if(del[c].find(*it)!=del[c].end()) st[++t]=*it;
while(t) s[a].erase(st[t--]);
}
inline void merge(int a,int b,int c) {
int now=s[a].size();
for(set_it it=s[b].begin();it!=s[b].end();++it) {
int x=*it;
if(del[c].find(x)!=del[c].end()) continue;
ans+=now;s[a].insert(x);
}
}
inline int jump(int x,int y) {
for (re int i=18;i>=0;--i) if(deep[x]-(1<<i)>deep[y]) x=f[x][i];
return x;
}
void dfs(int x) {
D(rt[x],x);
for(re int i=head[x];i;i=e[i].nxt) {
if(deep[e[i].v]<deep[x]) continue;
dfs(e[i].v);
if(s[rt[e[i].v]].size()<=s[rt[x]].size()) {
if(x!=1) merge(rt[x],rt[e[i].v],x);
}
else {
D(rt[e[i].v],x);
if(x!=1) merge(rt[e[i].v],rt[x],x),rt[x]=rt[e[i].v];
}
}
}
int main() {
n=read(),m=read();int x,y,z;
for(re int i=1;i<n;i++) x=read(),y=read(),add(x,y),add(y,x);
deep[1]=1,dfs1(1),dfs2(1,1);
for(re int i=1;i<=n;i++) rt[i]=i;
for(re int i=n;i<=m;i++) {
x=read(),y=read();z=LCA(x,y);
if(x==y) continue;
if(deep[x]<deep[y]) std::swap(x,y);
ans+=s[x].size();
if(z!=y) ans+=s[y].size();
s[x].insert(i);
if(z!=y) s[y].insert(i);
del[z].insert(i);
if(x==z||y==z) continue;
int xx=jump(x,z),yy=jump(y,z);
if(xx>yy) std::swap(xx,yy);
ans-=ma[(LL)xx*(LL)n+yy];ma[(LL)xx*(LL)n+yy]++;
}
dfs(1);
printf("%lld\n",ans);
}