【NOIP 校内模拟】T3 忘了是啥名字了(dfs序+树状数组)
对于当前新加入的一条路径 他产生的贡献分为两种
1.另一条路径的LCA在当前路径上
2.当前路径的LCA在另一条上
对于情况1:
可以维护当前点到根节点有多少个LCA,查询只需查询u,v,-2*lca(u,v),修改需要对lca的子树+1
对于情况2:
显然的树上差分,查询就是lca子树的前缀和,修改u++,v++,lca-2
即开两个树状数组,一个支持单点查询+区间修改,一个支持单点修改+区间查询,不嫌麻烦的话可以尝试线段树。
需要开栈,某OJ栈空间感人。
#include<bits/stdc++.h>
#define N 1000005
#define M 1000005
#define ll long long
using namespace std;
template<class T>
inline void read(T &x)
{
x=0; int f=1;
static char ch=getchar();
while((!isdigit(ch))&&ch!='-') ch=getchar();
if(ch=='-') f=-1,ch=getchar();
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
x*=f;
}
//1e6
struct Edge
{
int to,next;
}edge[2*N];
int n,m,tot,first[N];
inline void addedge(int x,int y)
{
tot++;
edge[tot].to=y; edge[tot].next=first[x]; first[x]=tot;
}
int up[N][27],depth[N],st[N],sign,ed[N];
ll con[N];
void dfs(int now,int fa)
{
up[now][0]=fa;
depth[now]=depth[fa]+1;
st[now]=++sign;
for(int i=1;i<=25;i++) up[now][i]=up[up[now][i-1]][i-1];
for(int u=first[now];u;u=edge[u].next)
{
int vis=edge[u].to;
if(vis==fa) continue;
dfs(vis,now);
}
ed[now]=sign;
}
inline int getlca(int x,int y)
{
if(depth[x]<depth[y]) swap(x,y);
for(int i=25;i>=0;i--) if(depth[up[x][i]]>=depth[y]) x=up[x][i];
if(x==y) return x;
for(int i=25;i>=0;i--) if(up[x][i]!=up[y][i]) x=up[x][i],y=up[y][i];
return up[x][0];
}
inline int lowbit(int x)
{
return x&(-x);
}
struct BIT
{
int n;
ll tree[N];
inline void getn(int x)
{
n=x;
}
inline void update(int x,ll del)
{
for(int i=x;i<=n;i+=lowbit(i)) tree[i]+=del;
}
inline ll query(int x)
{
ll ans=0;
for(int i=x;i;i-=lowbit(i)) ans+=tree[i];
return ans;
}
}bit1,bit2; //区间加单点查 单点加区间查 其实就是差分,普通 bit
int main()
{
ll size=40<<20;//40M
__asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));//提交用这个
read(n),read(m);
for(register int i=1;i<n;i++)
{
int x,y;
read(x),read(y);
addedge(x,y); addedge(y,x);
}
dfs(1,0);
bit1.getn(n); bit2.getn(n);
ll ans=0;
//需要分两种情况讨论:其他的lca在这条路径上 自己的lca在其他路径上
for(int i=1,u,v,lca;i<=m;i++)
{
read(u); read(v); lca=getlca(u,v);
ans=ans+bit1.query(st[u])+bit1.query(st[v])-2*bit1.query(st[lca]);
ans=ans+bit2.query(ed[lca])-bit2.query(st[lca]-1);
ans=ans+con[lca];
con[lca]++;
bit1.update(st[lca],1); bit1.update(ed[lca]+1,-1);
bit2.update(st[u],1); bit2.update(st[v],1); bit2.update(st[lca],-2);
}
cout<<ans;
exit(0);
}
QQ40523591~欢迎一起学习交流~