P8990 [北大集训 2021] 小明的树 题解
妙妙题。
如果一个点被点亮了,那么就称这个点为白点,否则为黑点。
由题意可得点 \(1\) 任意时刻都是黑点,于是“一个树是美丽的当前仅当对于每一个被点亮的节点,这个节点子树内的节点都是点亮的。”便可以转化为:
- 一个树是美丽的当前仅当任意时刻黑点形成一个连通块。
显然有:
-
黑点连通块个数 \(=\) 黑点个数 \(-\) 连接两个黑点的边的个数。
-
白点连通块个数 \(=\) 连接一个黑点和一个白点的边的个数。
对时间线建立线段树维护即可。线段树上每个节点记录节点区间内黑点连通块个数的最小值和对应时刻的白点连通块个数的和。
参考代码:
#include<bits/stdc++.h>
#define ll long long
#define mxn 500003
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define rept(i,a,b) for(int i=a;i<b;++i)
using namespace std;
struct node{
int x,y;
}d[mxn];
int n,m,a[mxn],p[mxn],f[mxn<<2],c[mxn<<2],ad[mxn<<2];
ll t[mxn<<2],d1[mxn<<2];
void push_up(int p){
f[p]=min(f[p<<1],f[p<<1|1]);
t[p]=0,c[p]=0;
if(f[p<<1]==f[p])t[p]+=t[p<<1],c[p]+=c[p<<1];
if(f[p<<1|1]==f[p])t[p]+=t[p<<1|1],c[p]+=c[p<<1|1];
}
void build(int p,int l,int r){
if(l==r){
f[p]=n-l,t[p]=0,c[p]=1;
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
push_up(p);
}
void push_down(int p){
if(ad[p]){
f[p<<1]+=ad[p],ad[p<<1]+=ad[p];
f[p<<1|1]+=ad[p],ad[p<<1|1]+=ad[p];
ad[p]=0;
}
if(d1[p]){
t[p<<1]+=d1[p]*c[p<<1],d1[p<<1]+=d1[p];
t[p<<1|1]+=d1[p]*c[p<<1|1],d1[p<<1|1]+=d1[p];
d1[p]=0;
}
}
void upd(int p,int l,int r,int x,int L,int R){
if(l<=L&&R<=r){
f[p]+=x,ad[p]+=x;
return;
}
push_down(p);
int mid=(L+R)>>1;
if(l<=mid)upd(p<<1,l,r,x,L,mid);
if(r>mid)upd(p<<1|1,l,r,x,mid+1,R);
push_up(p);
}
void add(int p,int l,int r,int x,int L,int R){
if(l<=L&&R<=r){
t[p]+=x*c[p],d1[p]+=x;
return;
}
push_down(p);
int mid=(L+R)>>1;
if(l<=mid)add(p<<1,l,r,x,L,mid);
if(r>mid)add(p<<1|1,l,r,x,mid+1,R);
push_up(p);
}
void add(int a,int b){
int x=min(p[a],p[b]),y=max(p[a],p[b]);
if(x>1)upd(1,1,x-1,-1,1,n-1);
if(x<y)add(1,x,y-1,1,1,n-1);
}
void del(int a,int b){
int x=min(p[a],p[b]),y=max(p[a],p[b]);
if(x>1)upd(1,1,x-1,1,1,n-1);
if(x<y)add(1,x,y-1,-1,1,n-1);
}
signed main(){
scanf("%d%d",&n,&m);
rept(i,1,n)scanf("%d%d",&d[i].x,&d[i].y);
rept(i,1,n)scanf("%d",&a[i]),p[a[i]]=i;
p[1]=n;
build(1,1,n-1);
rept(i,1,n){
int x=min(p[d[i].x],p[d[i].y]),y=max(p[d[i].x],p[d[i].y]);
if(x>1)upd(1,1,x-1,-1,1,n-1);
if(x<y)add(1,x,y-1,1,1,n-1);
}
printf("%lld\n",t[1]);
int x1,y1,x2,y2;
while(m--){
scanf("%d%d%d%d",&x1,&y1,&x2,&y2);
del(x1,y1);
add(x2,y2);
printf("%lld\n",t[1]);
}
return 0;
}