BZOJ_3159_决战
分析:
我使用树剖+splay维护这个东西。
对每条重链维护一棵splay,链加和查询正常做,剩下的链反转如下。
由于一定是深度递增的一条链,我们树剖将它分成从左到右log个区间,提取出对应子树,插入到一个新的splay中。
然后打标记进行反转,将子树归还给log个区间。
时间复杂度\(O(nlogn^2)\)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
#define N 200050
typedef long long ll;
#define ls ch[p][0]
#define rs ch[p][1]
#define db(x) cerr<<#x<<" = "<<x<<endl
#define Db(x) cerr<<#x<<endl
#define get(x) (ch[f[x]][1]==x)
int ch[N][2],f[N],sz1[N],rev[N];
int head[N],to[N<<1],nxt[N<<1],fa[N],son[N],top[N],sz2[N],dep[N],cnt,root;
int idx[N],pid[N],id[N],n,m,TOT,tot;
ll sum[N],tag[N],mn[N],mx[N],num[N];
struct Splay {
int rt,bg,ed;
int newnode() {
int p=++tot; return p;
}
void init() {
rt=newnode(); int p=newnode();
ch[rt][1]=p; f[p]=rt; pushup(p); pushup(rt);
}
void pushup(int p) {
sum[p]=sum[ls]+sum[rs]+num[p];
mn[p]=min(mn[ls],min(mn[rs],num[p]));
mx[p]=max(mx[ls],max(mx[rs],num[p]));
sz1[p]=sz1[ls]+sz1[rs]+1;
}
void giv1(int p) {
rev[p]^=1; swap(ls,rs);
}
void giv2(int p,ll d) {
tag[p]+=d; sum[p]+=sz1[p]*d; mn[p]+=d; mx[p]+=d; num[p]+=d;
}
void pushdown(int p) {
if(rev[p]) {
if(ls) giv1(ls);
if(rs) giv1(rs);
rev[p]=0;
}
if(tag[p]) {
if(ls) giv2(ls,tag[p]);
if(rs) giv2(rs,tag[p]);
tag[p]=0;
}
}
void UPD(int x) {
if(x!=rt) UPD(f[x]);
pushdown(x);
}
void rotate(int x) {
int y=f[x],z=f[y],k=get(x);
ch[y][k]=ch[x][!k]; f[ch[y][k]]=y;
ch[x][!k]=y; f[y]=x; f[x]=z;
if(z) ch[z][ch[z][1]==y]=x;
if(y==rt) rt=x;
pushup(y); pushup(x);
}
void splay(int x,int y) {
UPD(x);
for(int d;(d=f[x])!=y;rotate(x)) if(f[d]!=y) rotate(get(x)==get(d)?d:x);
}
int find(int x) {
int p=rt;
while(1) {
pushdown(p);
if(sz1[ls]>=x) p=ls;
else {
x-=sz1[ls]+1;
if(!x) return p;
p=rs;
}
}
}
int BUILD(int l,int r,int fa) {
int mid=(l+r)>>1;
int p=newnode();
f[p]=fa;
if(l<mid) ls=BUILD(l,mid-1,p);
if(r>mid) rs=BUILD(mid+1,r,p);
pushup(p);
return p;
}
void build(int x) {
rt=BUILD(1,x+2,0);
}
void update(int x,int y,int z) {
x=x-bg+1,y=y-bg+1;
x=find(x),y=find(y+2);
splay(x,0); splay(y,x);
giv2(ch[y][0],z);
pushup(y); pushup(x);
}
ll qsum(int x,int y) {
x=x-bg+1,y=y-bg+1;
x=find(x),y=find(y+2);
splay(x,0); splay(y,x);
//db(x),db(y);
return sum[ch[y][0]];
}
ll qmin(int x,int y) {
x=x-bg+1,y=y-bg+1;
x=find(x),y=find(y+2);
splay(x,0); splay(y,x);
return mn[ch[y][0]];
}
ll qmax(int x,int y) {
x=x-bg+1,y=y-bg+1;
x=find(x),y=find(y+2);
splay(x,0); splay(y,x);
return mx[ch[y][0]];
}
int split(int x,int y) {
x=x-bg+1,y=y-bg+1;
x=find(x),y=find(y+2);
splay(x,0); splay(y,x);
int p=ch[y][0];
f[p]=0,ch[y][0]=0;
pushup(y); pushup(x);
return p;
}
void insert(int x,int p) {
int y,t=x;
x=find(t+1),y=find(t+2);
splay(x,0); splay(y,x);
ch[y][0]=p; f[p]=y;
pushup(y); pushup(x);
}
void findbug(int p) {
pushdown(p);
if(ls) findbug(ls);
//printf("p=%d num[p]=%lld sz1[p]=%d sum=%lld\n",p,num[p],sz1[p],sum[p]);
if(rs) findbug(rs);
}
}G[N],TMP;
inline void add(int u,int v) {
to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;
}
void df1(int x,int y) {
int i; sz2[x]=1; fa[x]=y;
dep[x]=dep[y]+1;
for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
df1(to[i],x); sz2[x]+=sz2[to[i]];
if(sz2[to[i]]>sz2[son[x]]) son[x]=to[i];
}
}
void df2(int x,int t) {
int i;
top[x]=t;
idx[x]=++idx[0]; pid[idx[0]]=x;
if(son[x]) df2(son[x],t);
for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
char opt[12];
void INCREASE(int x,int y,int z) {
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) swap(x,y);
G[id[idx[y]]].update(idx[top[y]],idx[y],z);
y=fa[top[y]];
}
if(dep[x]<dep[y]) swap(x,y);
G[id[idx[y]]].update(idx[y],idx[x],z);
}
ll SUM(int x,int y) {
ll re=0;
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) swap(x,y);
re+=G[id[idx[y]]].qsum(idx[top[y]],idx[y]);
y=fa[top[y]];
}
if(dep[x]<dep[y]) swap(x,y);
re+=G[id[idx[y]]].qsum(idx[y],idx[x]);
return re;
}
ll MIN(int x,int y) {
ll re=1ll<<60;
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) swap(x,y);
re=min(re,G[id[idx[y]]].qmin(idx[top[y]],idx[y]));
y=fa[top[y]];
}
if(dep[x]<dep[y]) swap(x,y);
re=min(re,G[id[idx[y]]].qmin(idx[y],idx[x]));
return re;
}
ll MAX(int x,int y) {
ll re=0;
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) swap(x,y);
re=max(re,G[id[idx[y]]].qmax(idx[top[y]],idx[y]));
y=fa[top[y]];
}
if(dep[x]<dep[y]) swap(x,y);
re=max(re,G[id[idx[y]]].qmax(idx[y],idx[x]));
return re;
}
int lca(int x,int y) {
while(top[x]!=top[y]) {
if(dep[top[x]]>dep[top[y]]) swap(x,y);
y=fa[top[y]];
}
return dep[x]<dep[y]?x:y;
}
struct A {
int l,r,id,p;
}a[N];
void clr(int p) {
num[p]=tag[p]=rev[p]=mn[p]=mx[p]=sum[p]=ls=rs=f[p]=sz1[p]=0;
}
void INVERSE(int x,int y) {
if(dep[x]>dep[y]) swap(x,y);
int la=0;
while(top[x]!=top[y]) {
a[++la]=(A){idx[top[y]],idx[y],id[idx[y]],G[id[idx[y]]].split(idx[top[y]],idx[y])};
y=fa[top[y]];
}
a[++la]=(A){idx[x],idx[y],id[idx[x]],G[id[idx[x]]].split(idx[x],idx[y])};
int i;
for(i=la;i;i--) {
TMP.insert(sz1[TMP.rt]-2,a[i].p);
}
TMP.splay(1,0); TMP.splay(2,1);
TMP.giv1(ch[2][0]); TMP.pushup(2); TMP.pushup(1);
for(i=la;i;i--) {
int p=TMP.split(1,a[i].r-a[i].l+1);
G[a[i].id].insert(a[i].l-G[a[i].id].bg,p);
}
clr(1),clr(2); ch[1][1]=2; f[2]=1; sz1[1]=2; sz1[2]=1; TMP.rt=1;
}
int main() {
mn[0]=1ll<<60;
scanf("%d%d%d",&n,&m,&root);
int i,x,y,j=0;
for(i=1;i<n;i++) {
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
df1(root,0),df2(root,root);
for(i=1;i<=n;i++) {
if(top[pid[i]]!=top[pid[i-1]]) j++,G[j].bg=i;
id[i]=j; G[j].ed=i;
}
TMP.init(); TMP.bg=1;
TOT=j;
for(i=1;i<=TOT;i++) {
G[i].build(G[i].ed-G[i].bg+1);
}
int z;
for(i=1;i<=m;i++) {
scanf("%s",opt);
if(opt[0]=='I') {
if(opt[2]=='c') {
scanf("%d%d%d",&x,&y,&z);
INCREASE(x,y,z);
}else {
scanf("%d%d",&x,&y);
INVERSE(x,y);
}
}else if(opt[0]=='S') {
scanf("%d%d",&x,&y);
printf("%lld\n",SUM(x,y));
}else {
if(opt[1]=='a') {
scanf("%d%d",&x,&y);
printf("%lld\n",MAX(x,y));
}else {
scanf("%d%d",&x,&y);
printf("%lld\n",MIN(x,y));
}
}
}
}
/*
5 8 1
1 2
2 3
3 4
4 5
Sum 2 4
Increase 3 5 3
Minor 1 4
Sum 4 5
Invert 1 3
Major 1 2
Increase 1 5 2
Sum 1 5
*/